diff --git a/zk/qndleq/internal_test.go b/zk/qndleq/internal_test.go index 69381492..3ce172a9 100644 --- a/zk/qndleq/internal_test.go +++ b/zk/qndleq/internal_test.go @@ -9,8 +9,8 @@ import ( ) func TestForgedProofSecParamZero(t *testing.T) { - p := big.NewInt(1019) - q := big.NewInt(1021) + // Safe primes: https://oeis.org/A005385 + p, q := big.NewInt(1019), big.NewInt(1187) N := new(big.Int).Mul(p, q) g, err := SampleQn(rand.Reader, N) @@ -49,3 +49,55 @@ func TestForgedProofSecParamZero(t *testing.T) { test.CheckOk(!forged.Verify(g, gx, h, hx, N), "forged proof must be rejected", t) } + +func TestOutOfBounds(t *testing.T) { + // Safe primes: https://oeis.org/A005385 + p, q := big.NewInt(1019), big.NewInt(1187) + N := new(big.Int).Mul(p, q) + + x := big.NewInt(2) + g, gx := big.NewInt(4), big.NewInt(16) + h, hx := big.NewInt(9), big.NewInt(81) + + invalidValues := []*big.Int{ + new(big.Int).Neg(g), // Negative + big.NewInt(0), // Zero + new(big.Int).Set(N), // N + new(big.Int).Add(N, N), // bigger than N + } + + t.Run("prove", func(t *testing.T) { + for _, invalidValue := range invalidValues { + p, err := Prove(rand.Reader, x, invalidValue, gx, h, hx, N, 128) + test.CheckIsErr(t, err, "Prove must fail") + test.CheckOk(p == nil, "proof must be nil", t) + } + }) + + t.Run("verify", func(t *testing.T) { + for _, invalidValue := range invalidValues { + p, err := Prove(rand.Reader, x, g, gx, h, hx, N, 128) + test.CheckNoErr(t, err, "Prove must succeed") + + isValid := p.Verify(invalidValue, gx, h, hx, N) + test.CheckOk(isValid == false, "proof verification must return false", t) + } + }) +} + +func TestChallengeZero(t *testing.T) { + // Safe primes: https://oeis.org/A005385 + p, q := big.NewInt(1019), big.NewInt(1187) + N := new(big.Int).Mul(p, q) + g, gx := big.NewInt(4), big.NewInt(16) // 4^2 == 16 mod N + h, hx := big.NewInt(9), big.NewInt(81) // 9^2 == 81 mod N + + // Proof must fail as challenge is congruent to zero modulo m = (p-1)(q-1)/4 = 509*593. + // c = m * 1079334709418571583321702591065767 + c, _ := new(big.Int).SetString("325783150686773390995072744979517913979", 10) + z, _ := new(big.Int).SetString("909208770437996720153744987183938443953758507571077689625761350579371458087594451128", 10) + invalidProof := Proof{z, c, 128} + + isValid := invalidProof.Verify(g, gx, h, hx, N) + test.CheckOk(isValid == false, "proof verification must fail", t) +} diff --git a/zk/qndleq/qndleq.go b/zk/qndleq/qndleq.go index bafb5d73..446d223b 100644 --- a/zk/qndleq/qndleq.go +++ b/zk/qndleq/qndleq.go @@ -70,32 +70,69 @@ func SampleQn(random io.Reader, N *big.Int) (*big.Int, error) { // Note: this function does not run in constant time because it uses // big.Int arithmetic. func Prove(random io.Reader, x, g, gx, h, hx, N *big.Int, secParam uint) (*Proof, error) { - rSizeBits := uint(N.BitLen()) + 2*secParam - rSizeBytes := (rSizeBits + 7) / 8 - - rBytes := make([]byte, rSizeBytes) - _, err := io.ReadFull(random, rBytes) + err := checkBounds(N, g, gx, h, hx) if err != nil { return nil, err } - r := new(big.Int).SetBytes(rBytes) - gP := new(big.Int).Exp(g, r, N) - hP := new(big.Int).Exp(h, r, N) + rSizeBits := uint(N.BitLen()) + 2*secParam + rSizeBytes := (rSizeBits + 7) / 8 + rBytes := make([]byte, rSizeBytes) - c, err := doChallenge(g, gx, h, hx, gP, hP, N, secParam) - if err != nil { - return nil, err - } + ONE := big.NewInt(1) + const NUM_TRIES = 10 + var r, gP, hP, gc, hc big.Int + for i := 0; i < NUM_TRIES; i++ { + _, err := io.ReadFull(random, rBytes) + if err != nil { + return nil, err + } - z := new(big.Int) - z.Mul(c, x).Add(z, r) + r.SetBytes(rBytes) + gP.Exp(g, &r, N) + hP.Exp(h, &r, N) - return &Proof{z, c, secParam}, nil + c, err := doChallenge(g, gx, h, hx, &gP, &hP, N, secParam) + if err != nil { + return nil, err + } + + // Challenge must not be congruent to zero. + // c != 0 mod m, where m = (p-1)(q-1)/4, and N = p*q. + // Check this by doing an Exp because m is unknown. + // + // This is valid assuming N is the product of two safe prime numbers. + // In the verification equation, c multiplies the witness. + // When c is zero, it removes the witness allowing to trivially + // pass the verification check. + gc.Exp(g, c, N) + hc.Exp(h, c, N) + if gc.Cmp(ONE) != 0 && hc.Cmp(ONE) != 0 { + z := new(big.Int).Mul(c, x) + z.Add(z, &r) + return &Proof{z, c, secParam}, nil + } + } + + return nil, ErrProve } // Verify checks whether x = Log_g(g^x) = Log_h(h^x). func (p Proof) Verify(g, gx, h, hx, N *big.Int) bool { + err := checkBounds(N, g, gx, h, hx) + if err != nil { + return false + } + + // Check c != 0 (mod m), where m = (p-1)(q-1)/4, + // by doing an Exp as m is unknown. + ONE := big.NewInt(1) + gc := new(big.Int).Exp(g, p.c, N) + hc := new(big.Int).Exp(h, p.c, N) + if gc.Cmp(ONE) == 0 || hc.Cmp(ONE) == 0 { + return false + } + gPNum := new(big.Int).Exp(g, p.z, N) gPDen := new(big.Int).Exp(gx, p.c, N) ok := gPDen.ModInverse(gPDen, N) @@ -171,5 +208,27 @@ func doChallenge(g, gx, h, hx, gP, hP, N *big.Int, secParam uint) (*big.Int, err return new(big.Int).SetBytes(cBytes), nil } -// ErrSecParam is returned when the security parameter is less than 128. -var ErrSecParam = errors.New("zk/qndleq: the security parameter must be greater than 128") +// checkBounds returns nil if 0 < x[i] < N for all 0 <= i < len(x); +// otherwise, returns ErrBounds. +func checkBounds(N *big.Int, x ...*big.Int) error { + if N.Sign() <= 0 { + return ErrBounds + } + + for _, xi := range x { + if !(0 < xi.Sign() && xi.Cmp(N) < 0) { + return ErrBounds + } + } + + return nil +} + +var ( + // ErrSecParam is returned when the security parameter is less than 128. + ErrSecParam = errors.New("zk/qndleq: the security parameter must be greater than 128") + // ErrBounds is returned when a value is not in the range 0 to N. + ErrBounds = errors.New("zk/qndleq: input must be greater than 0 and less than N") + // ErrProve is returned when Prove exhausted the number of proof tries. + ErrProve = errors.New("zk/qndleq: exhausted the number of proof tries") +) diff --git a/zk/qndleq/qndleq_test.go b/zk/qndleq/qndleq_test.go index a8af777c..4eb9096d 100644 --- a/zk/qndleq/qndleq_test.go +++ b/zk/qndleq/qndleq_test.go @@ -6,20 +6,24 @@ import ( "testing" "github.com/cloudflare/circl/internal/test" + cmath "github.com/cloudflare/circl/math" "github.com/cloudflare/circl/zk/qndleq" ) func TestProve(t *testing.T) { - const testTimes = 1 << 8 - const SecParam = 128 - one := big.NewInt(1) - max := new(big.Int).Lsh(one, 256) + const ( + testTimes = 1 << 8 + SecParam = 128 + BitLength = 128 // [Warning]: this is only for tests, use a secure bit length above 2048 bits. + ) + + p, err := cmath.SafePrime(rand.Reader, BitLength) + test.CheckNoErr(t, err, "failed to generate a safe prime") + q, err := cmath.SafePrime(rand.Reader, BitLength) + test.CheckNoErr(t, err, "failed to generate a safe prime") + N := new(big.Int).Mul(p, q) for i := 0; i < testTimes; i++ { - N, _ := rand.Int(rand.Reader, max) - if N.Bit(0) == 0 { - N.Add(N, one) - } x, _ := rand.Int(rand.Reader, N) g, err := qndleq.SampleQn(rand.Reader, N) test.CheckNoErr(t, err, "failed to sampleQn") @@ -34,6 +38,20 @@ func TestProve(t *testing.T) { } } +func TestInvalidStatement(t *testing.T) { + // Safe primes: https://oeis.org/A005385 + p, q := big.NewInt(1019), big.NewInt(1187) + N := new(big.Int).Mul(p, q) + g, gx := big.NewInt(4), big.NewInt(16) // 4^2 == 16 mod N + h, hx := big.NewInt(9), big.NewInt(81) // 9^2 == 81 mod N + incorrectX := big.NewInt(3) + + proof, err := qndleq.Prove(rand.Reader, incorrectX, g, gx, h, hx, N, 128) + test.CheckNoErr(t, err, "an alleged proof must be computed") + isValid := proof.Verify(g, gx, h, hx, N) + test.CheckOk(isValid == false, "proof verification must fail", t) +} + func TestSampleQn(t *testing.T) { const testTimes = 1 << 7 one := big.NewInt(1)