diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 24ac5e1..68cfaa4 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -50,7 +50,7 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")): def isqrt_i(x): """Return 1/sqrt(x) or 1/sqrt(zeta * x)""" - if x==0: return 0 + if x==0: return True,0 gen = x.parent(-1) while is_square(gen): gen = sqrt(gen) if is_square(x): return True,1/sqrt(x) @@ -225,7 +225,6 @@ class RistrettoPoint(QuotientEdwardsPoint): a,d = cls.a,cls.d assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2 x = 2*s*cls.magic / t - if negative(x): x = -x # TODO: doesn't work without resolving x y = (1+a*s^2) / (1-a*s^2) return cls(sgn*x,y) @@ -237,11 +236,11 @@ class RistrettoPoint(QuotientEdwardsPoint): n1 = cls.a*(r+1)*(a+d)*(d-a)/den n2 = r*n1 if is_square(n1): - sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1 + sgn,s,t = 1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1 else: - sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 + sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 - return cls.fromJacobiQuartic(s,t,sgn) + return cls.fromJacobiQuartic(s,t) @classmethod @optimized_version_of("elligatorSpec") @@ -257,8 +256,9 @@ class RistrettoPoint(QuotientEdwardsPoint): else: sgn,twiddle = -1,r0*cls.qnr isri *= twiddle s = isri*num - t = isri*s*(r-1)*(d+a)^2 + sgn - return cls.fromJacobiQuartic(s,t,sgn) + t = -sgn*isri*s*(r-1)*(d+a)^2 - 1 + if negative(s) == iss: s = -s + return cls.fromJacobiQuartic(s,t) class Decaf_1_1_Point(QuotientEdwardsPoint): @@ -364,8 +364,6 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): """Convert point from its Jacobi Quartic representation""" a,d = cls.a,cls.d if s==0: return cls() - altx = 2*s*cls.isoMagic / t - if negative(altx): t = -t x = 2*s / (1+a*s^2) y = (1-a*s^2) / t return cls(x,sgn*y) @@ -379,16 +377,29 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): n1 = (r+1)*(a-2*d)/den n2 = r*n1 if is_square(n1): - sgn,s,t = 1,xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1 + sgn,s,t = 1, xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1 else: - sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1 + sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1 - return cls.fromJacobiQuartic(s,t,sgn) + return cls.fromJacobiQuartic(s,t) @classmethod @optimized_version_of("elligatorSpec") def elligator(cls,r0): - return cls.elligatorSpec(r0) + a,d = cls.a,cls.d + r0 = cls.bytesToGf(r0) + r = cls.qnr * r0^2 + den = (d*r-(d-a))*((d-a)*r-d) + num = (r+1)*(a-2*d) + + iss,isri = isqrt_i(num*den) + if iss: sgn,twiddle = 1,1 + else: sgn,twiddle = -1,r0*cls.qnr + isri *= twiddle + s = isri*num + t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1 + if negative(s) == iss: s = -s + return cls.fromJacobiQuartic(s,t) class Ed25519Point(RistrettoPoint): F = GF(2^255-19) @@ -477,7 +488,26 @@ class TestFailedException(Exception): pass def test(cls,n): print "Testing curve %s" % cls.__name__ - # TODO: test corner cases like 0,1,i + + specials = [1] + ii = cls.F(-1) + while is_square(ii): + specials.append(ii) + ii = sqrt(ii) + specials.append(ii) + for i in specials: + if negative(cls.F(i)): i = -i + i = enc_le(i,cls.encLen) + try: + Q = cls.decode(i) + QE = Q.encode() + if QE != i: + raise TestFailedException("Round trip special %s != %s" % + (binascii.hexlify(QE),binascii.hexlify(i))) + except NotOnCurveException: pass + except InvalidEncodingException: pass + + P = cls.base() Q = cls() for i in xrange(n): @@ -520,6 +550,13 @@ testElligator(TwistedEd448GoldilocksPoint,100) def gangtest(classes,n): print "Gang test",[cls.__name__ for cls in classes] + specials = [1] + ii = classes[0].F(-1) + while is_square(ii): + specials.append(ii) + ii = sqrt(ii) + specials.append(ii) + for i in xrange(n): rets = [bytes((cls.base()*i).encode()) for cls in classes] if len(set(rets)) != 1: @@ -528,7 +565,9 @@ def gangtest(classes,n): print c,binascii.hexlify(ret) print - r0 = randombytes(classes[0].encLen) + if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen) + else: r0 = randombytes(classes[0].encLen) + rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes] if len(set(rets)) != 1: print "Divergence in elligator at %d" % i