From 4a5aad2241be040bc9cf901dbbf0ceaa90eedc81 Mon Sep 17 00:00:00 2001 From: Michael Hamburg Date: Wed, 16 Aug 2017 18:25:29 -0700 Subject: [PATCH] working on ristretto sage files --- aux/ristretto.sage | 252 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 190 insertions(+), 62 deletions(-) diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 6d032b4..388f5d3 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -1,11 +1,42 @@ - +import binascii class InvalidEncodingException(Exception): pass class NotOnCurveException(Exception): pass +class SpecException(Exception): pass def lobit(x): return int(x) & 1 def hibit(x): return lobit(2*x) def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) +def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) + +def optimized_version_of(spec): + def decorator(f): + def wrapper(self,*args,**kwargs): + try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None + except Exception as e: spec_ans = None,e + try: opt_ans = f(self,*args,**kwargs),None + except Exception as e: opt_ans = None,e + if spec_ans[1] is None and opt_ans[1] is not None: + raise SpecException("Mismatch in %s: spec returned %s but opt threw %s" + % (f.__name__,str(spec_ans[0]),str(opt_ans[1]))) + if spec_ans[1] is not None and opt_ans[1] is None: + raise SpecException("Mismatch in %s: spec threw %s but opt returned %s" + % (f.__name__,str(spec_ans[1]),str(opt_ans[0]))) + if spec_ans[0] != opt_ans[0]: + raise SpecException("Mismatch in %s: %s != %s" + % (f.__name__,str(spec_ans[0]),str(opt_ans[0]))) + if opt_ans[1] is not None: raise opt_ans[1] + else: return opt_ans[0] + wrapper.__name__ = f.__name__ + return wrapper + return decorator + +def xsqrt(x,exn=InvalidEncodingException("Not on curve")): + """Return sqrt(x)""" + if not is_square(x): raise exn + s = sqrt(x) + if lobit(s): s=-s + return s def isqrt(x,exn=InvalidEncodingException("Not on curve")): """Return 1/sqrt(x)""" @@ -13,13 +44,21 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")): if not is_square(x): raise exn return 1/sqrt(x) +def isqrt_i(x): + """Return 1/sqrt(x) or 1/sqrt(zeta * x)""" + if x==0: return 0 + gen = x.parent(-1) + while is_square(gen): gen = sqrt(gen) + if is_square(x): return True,1/sqrt(x) + else: return False,1/sqrt(x*gen) + class EdwardsPoint(object): """Abstract class for point an an Edwards curve; needs F,a,d to work""" def __init__(self,x=0,y=1): x = self.x = self.F(x) y = self.y = self.F(y) if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: - raise NotOnCurveException() + raise NotOnCurveException(str(self)) def __repr__(self): return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) @@ -57,96 +96,97 @@ class EdwardsPoint(object): x,y = self z = self.F.random_element() return x*z,y*z,z,x*y*z + + def torque(self): + """Apply cofactor group, except keeping the point even""" + if self.cofactor == 8: + return self.__class__(self.y*self.i, self.x*self.i) + else: + return self.__class__(-self.x, -self.y) class RistrettoPoint(EdwardsPoint): """Like current decaf but tweaked for simplicity""" - def __eq__(self,other): x,y = self X,Y = other return x*Y == X*y or x*X == y*Y - @staticmethod - def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")): - if not is_square(x): raise exn - s = sqrt(x) - if negative(s): s=-s - return s + @classmethod + def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False): + """Convert little-endian bytes to field element, sanity check length""" + if len(bytes) != cls.encLen: + raise InvalidEncodingException("wrong length %d" % len(bytes)) + s = dec_le(bytes) + if mustBeProper and s >= cls.F.modulus(): + raise InvalidEncodingException("%d out of range!" % s) + if mustBePositive and lobit(s): + raise InvalidEncodingException("%d is negative!" % s) + return cls.F(s) def encodeSpec(self): """Unoptimized specification for encoding""" x,y = self - if self.cofactor==8 and (lobit(x*y) or x==0): + if self.cofactor==8 and (lobit(x*y) or y==0): (x,y) = (self.i*y,self.i*x) - elif self.cofactor==4 and y==-1: - y = 1 # Doesn't affect impl - if lobit(x): y=-y - s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even")) + if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl + + if lobit(x): x,y = -x,-y + s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self))) return enc_le(s,self.encLen) @classmethod def decodeSpec(cls,s): """Unoptimized specification for decoding""" - if len(s) != cls.encLen: - raise InvalidEncodingException("wrong length %d" % len(s)) - s = dec_le(s) - if s < 0 or s >= cls.F.modulus() or lobit(s): - raise InvalidEncodingException("%d out of range!" % s) - s = cls.F(s) + s = cls.bytesToGf(s,mustBePositive=True) - x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2)) - y = (1-s^2) / (1+s^2) + a,d = cls.a,cls.d + x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2)) + y = (1+a*s^2) / (1-a*s^2) - if cls.cofactor==8 and (lobit(x*y) or x==0): + if cls.cofactor==8 and (lobit(x*y) or y==0): raise InvalidEncodingException("x*y has high bit") return cls(x,y) - + + @optimized_version_of("encodeSpec") def encode(self): + """Encode, optimized version""" + a,d = self.a,self.d x,y,z,t = self.xyzt() - - u1 = (z+y)*(z-y) + + u1 = a*(y+z)*(y-z) u2 = x*y # = t*z - isr = isqrt(u1 * u2^2) + isr = isqrt(u1*u2^2) i1 = isr*u1 i2 = isr*u2 z_inv = i1*i2*t rotate = self.cofactor==8 and lobit(t*z_inv) if rotate: - magic = isqrt(-self.d-1) x,y = y*self.i,x*self.i - den_inv = magic * i1 + den_inv = self.magic * i1 else: den_inv = i2 if lobit(x*z_inv): y = -y s = (z-y) * den_inv - if self.cofactor==8 and s==0: s += 1 if lobit(s): s=-s - ret = enc_le(s,self.encLen) - assert ret == self.encodeSpec() - return ret + return enc_le(s,self.encLen) @classmethod + @optimized_version_of("decodeSpec") def decode(cls,s): - right_answer = cls.decodeSpec(s) + """Decode, optimized version""" + s = cls.bytesToGf(s,mustBePositive=True) - # Sanity check s - if len(s) != cls.encLen: - raise InvalidEncodingException("wrong length %d" % len(s)) - s = dec_le(s) - if s < 0 or s >= cls.F.modulus() or lobit(s): - raise InvalidEncodingException("%d out of range!" % s) - s = cls.F(s) - - yden = 1+s^2 - ynum = 1-s^2 + a,d = cls.a,cls.d + yden = 1-a*s^2 + ynum = 1+a*s^2 yden_sqr = yden^2 - xden_sqr = -cls.d*ynum^2 - yden_sqr + xden_sqr = a*d*ynum^2 - yden_sqr isr = isqrt(xden_sqr * yden_sqr) @@ -157,25 +197,60 @@ class RistrettoPoint(EdwardsPoint): if lobit(x): x = -x y = ynum * yden_inv - if cls.cofactor==8 and (lobit(x*y) or x==0): - raise InvalidEncodingException("x*y has high bit") + if cls.cofactor==8 and (lobit(x*y) or y==0): + raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) - ret = cls(x,y) - assert ret == right_answer - return ret - - def torque(self): - if self.cofactor == 8: - return self.__class__(self.y*self.i, self.x*self.i) + return cls(x,y) + + @classmethod + def fromJacobiQuartic(cls,s,t,sgn=1): + """Convert point from its Jacobi Quartic representation""" + a,d = cls.a,cls.d + assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2 + x = 2*s*cls.magic / t + if lobit(x): x = -x # TODO: doesn't work without resolving x + y = (1+a*s^2) / (1-a*s^2) + return cls(sgn*x,y) + + @classmethod + def elligatorSpec(cls,r0): + a,d = cls.a,cls.d + r = cls.qnr * cls.bytesToGf(r0)^2 + den = (d*r-a)*(a*r-d) + n1 = cls.a*(r+1)*(a+d)*(d-a)/den + n2 = r*n1 + if is_square(n1): + sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1 else: - return self.__class__(-self.x, -self.y) + sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 + + ret = cls.fromJacobiQuartic(s,t,sgn) + return ret - + @classmethod + @optimized_version_of("elligatorSpec") + def elligator(cls,r0): + a,d = cls.a,cls.d + r0 = cls.bytesToGf(r0) + r = cls.qnr * r0^2 + den = (d*r-a)*(a*r-d) + num = cls.a*(r+1)*(a+d)*(d-a) + + iss,isri = isqrt_i(num*den) + if iss: sgn,twiddle = 1,1 + else: sgn,twiddle = -1,r0*cls.qnr + isri *= twiddle + s = isri*num + t = isri*s*(r-1)*(d+a)^2 + sgn + return cls.fromJacobiQuartic(s,t,sgn) + class Ed25519Point(RistrettoPoint): F = GF(2^255-19) d = F(-121665/121666) a = F(-1) i = sqrt(F(-1)) + qnr = i + magic = isqrt(a*d-1) cofactor = 8 encLen = 32 @@ -186,30 +261,72 @@ class Ed25519Point(RistrettoPoint): if lobit(x): x = -x return cls(x,y) -class Ed448Point(RistrettoPoint): +class TwistedEd448GoldilocksPoint(RistrettoPoint): F = GF(2^448-2^224-1) d = F(-39082) a = F(-1) + qnr = -1 + magic = isqrt(a*d-1) cofactor = 4 encLen = 56 - + @classmethod def base(cls): - y = cls.F(6) # FIXME: no it isn't + y = cls.F(6) # TODO: no it isn't x = sqrt((y^2-1)/(cls.d*y^2+1)) if lobit(x): x = -x return cls(x,y) +class Ed448GoldilocksPoint(RistrettoPoint): + # TODO: decaf vs ristretto + F = GF(2^448-2^224-1) + d = F(-39081) + a = F(1) + qnr = -1 + magic = isqrt(a*d-1) + cofactor = 4 + encLen = 56 + + @classmethod + def base(cls): + return cls( + 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555, + 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed + ) + +class IsoEd448Point(RistrettoPoint): + F = GF(2^448-2^224-1) + d = F(1/39081+1) + a = F(1) + qnr = -1 + magic = isqrt(a*d-1) + cofactor = 4 + encLen = 56 + + @classmethod + def base(cls): + # = ..., -3/2 + return cls.decodeSpec(bytearray(binascii.unhexlify( + "00000000000000000000000000000000000000000000000000000000"+ + "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) + class TestFailedException(Exception): pass + def test(cls,n): # TODO: test corner cases like 0,1,i P = cls.base() Q = cls() for i in xrange(n): + #print binascii.hexlify(Q.encode()) QQ = cls.decode(Q.encode()) if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) - if Q.encode() != Q.torque().encode(): - raise TestFailedException("Can't torque %s" % str(Q)) + + QT = Q + QE = Q.encode() + for h in xrange(cls.cofactor): + QT = QT.torque() + if QT.encode() != QE: + raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1)) Q0 = Q + P if Q0 == Q: raise TestFailedException("Addition doesn't work") @@ -220,5 +337,16 @@ def test(cls,n): Q2 = Q0*(r+1) if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") Q = Q1 +test(Ed25519Point,100) +test(TwistedEd448GoldilocksPoint,100) +test(Ed448GoldilocksPoint,100) +test(IsoEd448Point,100) - \ No newline at end of file + +def testElligator(cls,n): + for i in xrange(n): + cls.elligator(randombytes(cls.encLen)) +testElligator(Ed25519Point,100) +testElligator(Ed448GoldilocksPoint,100) +testElligator(TwistedEd448GoldilocksPoint,100) +testElligator(IsoEd448Point,100)