From b9b855f172eb78b8d66700f6fea3deb4e4726818 Mon Sep 17 00:00:00 2001 From: Michael Hamburg Date: Tue, 22 Aug 2017 10:11:22 -0700 Subject: [PATCH] passes gang tests. ship it? --- aux/ristretto.sage | 85 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 1df7e4d..24ac5e1 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -308,6 +308,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): x,y,z,t = self.xyzt() if self.cofactor == 8: + # Cofactor 8 version num = (z+y)*(z-y) den = x*y tmp = isqrt(num*(a-d)*den^2) @@ -327,6 +328,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): s = tmp*den*yisr*(tiisr*z - 1) else: + # Much simpler cofactor 4 version num = (x+t)*(x-t) isr = isqrt(num*(a-d)*x^2) ratio = isr*num @@ -339,8 +341,55 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): @optimized_version_of("decodeSpec") def decode(cls,s): """Decode, optimized version""" - return cls.decodeSpec(s) # TODO + a,d = cls.a,cls.d + s = cls.bytesToGf(s,mustBePositive=True) + + if s==0: return cls() + s2 = s^2 + den = 1+a*s2 + num = den^2 - 4*d*s2 + isr = isqrt(num*den^2) + altx = 2*s*isr*den*cls.isoMagic + if negative(altx): isr = -isr + x = 2*s *isr^2*den*num + y = (1-a*s^2) * isr*den + + if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0): + raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) + + return cls(x,y) + @classmethod + def fromJacobiQuartic(cls,s,t,sgn=1): + """Convert point from its Jacobi Quartic representation""" + a,d = cls.a,cls.d + if s==0: return cls() + altx = 2*s*cls.isoMagic / t + if negative(altx): t = -t + x = 2*s / (1+a*s^2) + y = (1-a*s^2) / t + return cls(x,sgn*y) + + @classmethod + def elligatorSpec(cls,r0): + a,d = cls.a,cls.d + r = cls.qnr * cls.bytesToGf(r0)^2 + + den = (d*r-(d-a))*((d-a)*r-d) + n1 = (r+1)*(a-2*d)/den + n2 = r*n1 + if is_square(n1): + sgn,s,t = 1,xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1 + else: + sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1 + + return cls.fromJacobiQuartic(s,t,sgn) + + @classmethod + @optimized_version_of("elligatorSpec") + def elligator(cls,r0): + return cls.elligatorSpec(r0) + class Ed25519Point(RistrettoPoint): F = GF(2^255-19) d = F(-121665/121666) @@ -427,6 +476,7 @@ class IsoEd25519Point(Decaf_1_1_Point): class TestFailedException(Exception): pass def test(cls,n): + print "Testing curve %s" % cls.__name__ # TODO: test corner cases like 0,1,i P = cls.base() Q = cls() @@ -451,30 +501,39 @@ def test(cls,n): Q2 = Q0*(r+1) if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") Q = Q1 + test(Ed25519Point,100) test(IsoEd25519Point,100) test(IsoEd448Point,100) test(TwistedEd448GoldilocksPoint,100) test(Ed448GoldilocksPoint,100) + + +def testElligator(cls,n): + print "Testing elligator on %s" % cls.__name__ + for i in xrange(n): + cls.elligator(randombytes(cls.encLen)) +testElligator(Ed25519Point,100) +testElligator(IsoEd448Point,100) +testElligator(Ed448GoldilocksPoint,100) +testElligator(TwistedEd448GoldilocksPoint,100) def gangtest(classes,n): + print "Gang test",[cls.__name__ for cls in classes] for i in xrange(n): rets = [bytes((cls.base()*i).encode()) for cls in classes] if len(set(rets)) != 1: - print "Divergence at %d" % i + print "Divergence in encode at %d" % i + for c,ret in zip(classes,rets): + print c,binascii.hexlify(ret) + print + + r0 = randombytes(classes[0].encLen) + rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes] + if len(set(rets)) != 1: + print "Divergence in elligator at %d" % i for c,ret in zip(classes,rets): print c,binascii.hexlify(ret) print gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100) gangtest([Ed25519Point,IsoEd25519Point],100) - - - - -def testElligator(cls,n): - for i in xrange(n): - cls.elligator(randombytes(cls.encLen)) -testElligator(Ed25519Point,100) -testElligator(IsoEd448Point,100) -# testElligator(Ed448GoldilocksPoint,100) -# testElligator(TwistedEd448GoldilocksPoint,100)