mirror of
https://github.com/emmansun/gmsm.git
synced 2025-06-03 01:44:54 +00:00
mldsa: add benchmark for Verify
This commit is contained in:
parent
8fc001fb45
commit
b218e76328
@ -7,7 +7,6 @@
|
||||
package mldsa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
mathrand "math/rand/v2"
|
||||
"testing"
|
||||
@ -110,10 +109,6 @@ func TestFieldMul(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, z := range zetasMontgomery {
|
||||
fmt.Printf("%v, ", fieldReduce(uint64(z)))
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func TestFieldBarrettMul(t *testing.T) {
|
||||
@ -138,67 +133,54 @@ func randomRingElement() ringElement {
|
||||
|
||||
func TestNTT(t *testing.T) {
|
||||
r := randomRingElement()
|
||||
r1 := r
|
||||
r2 := ntt(r)
|
||||
r3 := barrettNTT(r1)
|
||||
for i, v := range r3 {
|
||||
if v != r2[i] {
|
||||
t.Errorf("expected %v, got %v", v, r2[i])
|
||||
rNTT := ntt(r)
|
||||
rBarretNTT := barrettNTT(r)
|
||||
for i, v := range rNTT {
|
||||
if v != rBarretNTT[i] {
|
||||
t.Errorf("expected %v, got %v", v, rBarretNTT[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInverseNTT(t *testing.T) {
|
||||
r := randomRingElement()
|
||||
r1 := r
|
||||
r2 := ntt(r1)
|
||||
r3 := inverseNTT(r2)
|
||||
ret := inverseNTT(ntt(r))
|
||||
for i, v := range r {
|
||||
if v != fieldReduce(uint64(r3[i])) {
|
||||
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(r3[i])))
|
||||
if v != fieldReduce(uint64(ret[i])) {
|
||||
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(ret[i])))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInverseBarrettNTT(t *testing.T) {
|
||||
r := randomRingElement()
|
||||
r1 := r
|
||||
r2 := barrettNTT(r1)
|
||||
r3 := inverseBarrettNTT(r2)
|
||||
ret := inverseBarrettNTT(barrettNTT(r))
|
||||
for i, v := range r {
|
||||
if v != r3[i] {
|
||||
t.Errorf("expected %v, got %v", v, r3[i])
|
||||
if v != ret[i] {
|
||||
t.Errorf("expected %v, got %v", v, ret[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// this is the real use case for NTT:
|
||||
//
|
||||
// - convert to NTT
|
||||
// - multiply in NTT
|
||||
// - inverse NTT
|
||||
// - convert to NTT
|
||||
// - multiply in NTT
|
||||
// - inverse NTT
|
||||
func TestInverseNTTWithMultiply(t *testing.T) {
|
||||
r1 := randomRingElement()
|
||||
r2 := randomRingElement()
|
||||
|
||||
// Montgomery Method
|
||||
r11 := r1
|
||||
r111 := ntt(r11)
|
||||
r22 := r2
|
||||
r222 := ntt(r22)
|
||||
r31 := nttMul(r111, r222)
|
||||
r32 := inverseNTT(r31)
|
||||
ret1 := inverseNTT(nttMul(ntt(r1), ntt(r2)))
|
||||
|
||||
// Barrett Method
|
||||
b11 := barrettNTT(r1)
|
||||
b22 := barrettNTT(r2)
|
||||
r33 := nttBarrettMul(b11, b22)
|
||||
r34 := inverseBarrettNTT(r33)
|
||||
ret2 := inverseBarrettNTT(nttBarrettMul(barrettNTT(r1), barrettNTT(r2)))
|
||||
|
||||
// Check if the results are equal
|
||||
for i := range r32 {
|
||||
if r32[i] != r34[i] {
|
||||
t.Errorf("expected %v, got %v", r34[i], r32[i])
|
||||
for i := range ret1 {
|
||||
if ret1[i] != ret2[i] {
|
||||
t.Errorf("expected %v, got %v", ret2[i], ret1[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -339,3 +339,23 @@ func BenchmarkSign44(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkVerify44(b *testing.B) {
|
||||
c := sigVer44InternalProjectionCases[0]
|
||||
pk, _ := hex.DecodeString(c.pk)
|
||||
sig, _ := hex.DecodeString(c.sig)
|
||||
msg, _ := hex.DecodeString(c.message)
|
||||
ctx, _ := hex.DecodeString(c.context)
|
||||
|
||||
pub, err := NewPublicKey44(pk)
|
||||
if err != nil {
|
||||
b.Fatalf("NewPublicKey44 failed: %v", err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
if !pub.Verify(sig, msg, ctx) {
|
||||
b.Errorf("Verify failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -329,3 +329,23 @@ func BenchmarkSign65(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkVerify65(b *testing.B) {
|
||||
c := sigVer65InternalProjectionCases[1]
|
||||
pk, _ := hex.DecodeString(c.pk)
|
||||
sig, _ := hex.DecodeString(c.sig)
|
||||
msg, _ := hex.DecodeString(c.message)
|
||||
ctx, _ := hex.DecodeString(c.context)
|
||||
|
||||
pub, err := NewPublicKey65(pk)
|
||||
if err != nil {
|
||||
b.Fatalf("NewPublicKey65 failed: %v", err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
if !pub.Verify(sig, msg, ctx) {
|
||||
b.Errorf("Verify failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -289,3 +289,23 @@ func BenchmarkSign87(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkVerify87(b *testing.B) {
|
||||
c := sigVer87InternalProjectionCases[2]
|
||||
pk, _ := hex.DecodeString(c.pk)
|
||||
sig, _ := hex.DecodeString(c.sig)
|
||||
msg, _ := hex.DecodeString(c.message)
|
||||
ctx, _ := hex.DecodeString(c.context)
|
||||
|
||||
pub, err := NewPublicKey87(pk)
|
||||
if err != nil {
|
||||
b.Fatalf("NewPublicKey87 failed: %v", err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
if !pub.Verify(sig, msg, ctx) {
|
||||
b.Errorf("Verify failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user