mldsa: add benchmark for Verify

This commit is contained in:
Sun Yimin 2025-05-30 15:25:37 +08:00 committed by GitHub
parent 8fc001fb45
commit b218e76328
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 37 deletions

View File

@ -7,7 +7,6 @@
package mldsa
import (
"fmt"
"math/big"
mathrand "math/rand/v2"
"testing"
@ -110,10 +109,6 @@ func TestFieldMul(t *testing.T) {
}
}
}
for _, z := range zetasMontgomery {
fmt.Printf("%v, ", fieldReduce(uint64(z)))
}
fmt.Println()
}
func TestFieldBarrettMul(t *testing.T) {
@ -138,67 +133,54 @@ func randomRingElement() ringElement {
func TestNTT(t *testing.T) {
r := randomRingElement()
r1 := r
r2 := ntt(r)
r3 := barrettNTT(r1)
for i, v := range r3 {
if v != r2[i] {
t.Errorf("expected %v, got %v", v, r2[i])
rNTT := ntt(r)
rBarretNTT := barrettNTT(r)
for i, v := range rNTT {
if v != rBarretNTT[i] {
t.Errorf("expected %v, got %v", v, rBarretNTT[i])
}
}
}
func TestInverseNTT(t *testing.T) {
r := randomRingElement()
r1 := r
r2 := ntt(r1)
r3 := inverseNTT(r2)
ret := inverseNTT(ntt(r))
for i, v := range r {
if v != fieldReduce(uint64(r3[i])) {
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(r3[i])))
if v != fieldReduce(uint64(ret[i])) {
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(ret[i])))
}
}
}
func TestInverseBarrettNTT(t *testing.T) {
r := randomRingElement()
r1 := r
r2 := barrettNTT(r1)
r3 := inverseBarrettNTT(r2)
ret := inverseBarrettNTT(barrettNTT(r))
for i, v := range r {
if v != r3[i] {
t.Errorf("expected %v, got %v", v, r3[i])
if v != ret[i] {
t.Errorf("expected %v, got %v", v, ret[i])
}
}
}
// this is the real use case for NTT:
//
// - convert to NTT
// - multiply in NTT
// - inverse NTT
// - convert to NTT
// - multiply in NTT
// - inverse NTT
func TestInverseNTTWithMultiply(t *testing.T) {
r1 := randomRingElement()
r2 := randomRingElement()
// Montgomery Method
r11 := r1
r111 := ntt(r11)
r22 := r2
r222 := ntt(r22)
r31 := nttMul(r111, r222)
r32 := inverseNTT(r31)
ret1 := inverseNTT(nttMul(ntt(r1), ntt(r2)))
// Barrett Method
b11 := barrettNTT(r1)
b22 := barrettNTT(r2)
r33 := nttBarrettMul(b11, b22)
r34 := inverseBarrettNTT(r33)
ret2 := inverseBarrettNTT(nttBarrettMul(barrettNTT(r1), barrettNTT(r2)))
// Check if the results are equal
for i := range r32 {
if r32[i] != r34[i] {
t.Errorf("expected %v, got %v", r34[i], r32[i])
for i := range ret1 {
if ret1[i] != ret2[i] {
t.Errorf("expected %v, got %v", ret2[i], ret1[i])
}
}
}

View File

@ -339,3 +339,23 @@ func BenchmarkSign44(b *testing.B) {
}
}
}
func BenchmarkVerify44(b *testing.B) {
c := sigVer44InternalProjectionCases[0]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey44(pk)
if err != nil {
b.Fatalf("NewPublicKey44 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}

View File

@ -329,3 +329,23 @@ func BenchmarkSign65(b *testing.B) {
}
}
}
func BenchmarkVerify65(b *testing.B) {
c := sigVer65InternalProjectionCases[1]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey65(pk)
if err != nil {
b.Fatalf("NewPublicKey65 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}

View File

@ -289,3 +289,23 @@ func BenchmarkSign87(b *testing.B) {
}
}
}
func BenchmarkVerify87(b *testing.B) {
c := sigVer87InternalProjectionCases[2]
pk, _ := hex.DecodeString(c.pk)
sig, _ := hex.DecodeString(c.sig)
msg, _ := hex.DecodeString(c.message)
ctx, _ := hex.DecodeString(c.context)
pub, err := NewPublicKey87(pk)
if err != nil {
b.Fatalf("NewPublicKey87 failed: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
if !pub.Verify(sig, msg, ctx) {
b.Errorf("Verify failed")
}
}
}