mldsa: add benchmark for Verify

This commit is contained in:
Sun Yimin 2025-05-30 15:25:37 +08:00 committed by GitHub
parent 8fc001fb45
commit b218e76328
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 37 deletions

View File

@ -7,7 +7,6 @@
package mldsa package mldsa
import ( import (
"fmt"
"math/big" "math/big"
mathrand "math/rand/v2" mathrand "math/rand/v2"
"testing" "testing"
@ -110,10 +109,6 @@ func TestFieldMul(t *testing.T) {
} }
} }
} }
for _, z := range zetasMontgomery {
fmt.Printf("%v, ", fieldReduce(uint64(z)))
}
fmt.Println()
} }
func TestFieldBarrettMul(t *testing.T) { func TestFieldBarrettMul(t *testing.T) {
@ -138,67 +133,54 @@ func randomRingElement() ringElement {
func TestNTT(t *testing.T) { func TestNTT(t *testing.T) {
r := randomRingElement() r := randomRingElement()
r1 := r rNTT := ntt(r)
r2 := ntt(r) rBarretNTT := barrettNTT(r)
r3 := barrettNTT(r1) for i, v := range rNTT {
for i, v := range r3 { if v != rBarretNTT[i] {
if v != r2[i] { t.Errorf("expected %v, got %v", v, rBarretNTT[i])
t.Errorf("expected %v, got %v", v, r2[i])
} }
} }
} }
func TestInverseNTT(t *testing.T) { func TestInverseNTT(t *testing.T) {
r := randomRingElement() r := randomRingElement()
r1 := r ret := inverseNTT(ntt(r))
r2 := ntt(r1)
r3 := inverseNTT(r2)
for i, v := range r { for i, v := range r {
if v != fieldReduce(uint64(r3[i])) { if v != fieldReduce(uint64(ret[i])) {
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(r3[i]))) t.Errorf("expected %v, got %v", v, fieldReduce(uint64(ret[i])))
} }
} }
} }
func TestInverseBarrettNTT(t *testing.T) { func TestInverseBarrettNTT(t *testing.T) {
r := randomRingElement() r := randomRingElement()
r1 := r ret := inverseBarrettNTT(barrettNTT(r))
r2 := barrettNTT(r1)
r3 := inverseBarrettNTT(r2)
for i, v := range r { for i, v := range r {
if v != r3[i] { if v != ret[i] {
t.Errorf("expected %v, got %v", v, r3[i]) t.Errorf("expected %v, got %v", v, ret[i])
} }
} }
} }
// this is the real use case for NTT: // this is the real use case for NTT:
// //
// - convert to NTT // - convert to NTT
// - multiply in NTT // - multiply in NTT
// - inverse NTT // - inverse NTT
func TestInverseNTTWithMultiply(t *testing.T) { func TestInverseNTTWithMultiply(t *testing.T) {
r1 := randomRingElement() r1 := randomRingElement()
r2 := randomRingElement() r2 := randomRingElement()
// Montgomery Method // Montgomery Method
r11 := r1 ret1 := inverseNTT(nttMul(ntt(r1), ntt(r2)))
r111 := ntt(r11)
r22 := r2
r222 := ntt(r22)
r31 := nttMul(r111, r222)
r32 := inverseNTT(r31)
// Barrett Method // Barrett Method
b11 := barrettNTT(r1) ret2 := inverseBarrettNTT(nttBarrettMul(barrettNTT(r1), barrettNTT(r2)))
b22 := barrettNTT(r2)
r33 := nttBarrettMul(b11, b22)
r34 := inverseBarrettNTT(r33)
// Check if the results are equal // Check if the results are equal
for i := range r32 { for i := range ret1 {
if r32[i] != r34[i] { if ret1[i] != ret2[i] {
t.Errorf("expected %v, got %v", r34[i], r32[i]) t.Errorf("expected %v, got %v", ret2[i], ret1[i])
} }
} }
} }

View File

@ -339,3 +339,23 @@ func BenchmarkSign44(b *testing.B) {
} }
} }
} }
func BenchmarkVerify44(b *testing.B) {
c := sigVer44InternalProjectionCases[0]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey44(pk)
if err != nil {
b.Fatalf("NewPublicKey44 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}

View File

@ -329,3 +329,23 @@ func BenchmarkSign65(b *testing.B) {
} }
} }
} }
func BenchmarkVerify65(b *testing.B) {
c := sigVer65InternalProjectionCases[1]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey65(pk)
if err != nil {
b.Fatalf("NewPublicKey65 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}

View File

@ -289,3 +289,23 @@ func BenchmarkSign87(b *testing.B) {
} }
} }
} }
func BenchmarkVerify87(b *testing.B) {
c := sigVer87InternalProjectionCases[2]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey87(pk)
if err != nil {
b.Fatalf("NewPublicKey87 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}