diff --git a/aux/ristretto.sage b/aux/ristretto.sage index 501704c..6d032b4 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -22,7 +22,7 @@ class EdwardsPoint(object): raise NotOnCurveException() def __repr__(self): - return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y) + return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) def __iter__(self): yield self.x @@ -58,25 +58,8 @@ class EdwardsPoint(object): z = self.F.random_element() return x*z,y*z,z,x*y*z -class Ed25519Point(EdwardsPoint): - F = GF(2^255-19) - d = F(-121665/121666) - a = F(-1) - i = sqrt(F(-1)) - - @classmethod - def base(cls): - y = cls.F(4/5) - x = sqrt((y^2-1)/(cls.d*y^2+1)) - if lobit(x): x = -x - return cls(x,y) - - def torque(self): - return self.__class__(self.y*self.i, self.x*self.i) - -class RistrettoPoint(Ed25519Point): +class RistrettoPoint(EdwardsPoint): """Like current decaf but tweaked for simplicity""" - encLen = 32 def __eq__(self,other): x,y = self @@ -90,15 +73,22 @@ class RistrettoPoint(Ed25519Point): if negative(s): s=-s return s - def encode(self): + def encodeSpec(self): + """Unoptimized specification for encoding""" x,y = self - if lobit(x*y) or x==0: (x,y) = (self.i*y,self.i*x) - if lobit(x): x,y = -x,-y + if self.cofactor==8 and (lobit(x*y) or x==0): + (x,y) = (self.i*y,self.i*x) + elif self.cofactor==4 and y==-1: + y = 1 # Doesn't affect impl + + if lobit(x): y=-y s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even")) + return enc_le(s,self.encLen) @classmethod - def decode(cls,s): + def decodeSpec(cls,s): + """Unoptimized specification for decoding""" if len(s) != cls.encLen: raise InvalidEncodingException("wrong length %d" % len(s)) s = dec_le(s) @@ -109,24 +99,13 @@ class RistrettoPoint(Ed25519Point): x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2)) y = (1-s^2) / (1+s^2) - if lobit(x*y) or x==0: + if cls.cofactor==8 and (lobit(x*y) or x==0): raise InvalidEncodingException("x*y has high bit") return cls(x,y) -class OptimizedRistrettoPoint(RistrettoPoint): - magic = isqrt(RistrettoPoint.d+1) - - """Like Ristretto but uses isqrt instead""" - @classmethod - def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): - s = isqrt(isqrt*inv^2) - return s*inv, s^2*isqrt*inv - def encode(self): - right_answer = super(OptimizedRistrettoPoint,self).encode() x,y,z,t = self.xyzt() - x *= self.i u1 = (z+y)*(z-y) u2 = x*y # = t*z @@ -135,25 +114,26 @@ class OptimizedRistrettoPoint(RistrettoPoint): i2 = isr*u2 z_inv = i1*i2*t - rotate = lobit(t*self.i*z_inv) + rotate = self.cofactor==8 and lobit(t*z_inv) if rotate: - x,y = y,x - den_inv = self.magic * i1 + magic = isqrt(-self.d-1) + x,y = y*self.i,x*self.i + den_inv = magic * i1 else: den_inv = i2 - if rotate ^^ lobit(x*z_inv): y = -y + if lobit(x*z_inv): y = -y s = (z-y) * den_inv - if s==0: s = F(1) + if self.cofactor==8 and s==0: s += 1 if lobit(s): s=-s ret = enc_le(s,self.encLen) - assert ret == right_answer + assert ret == self.encodeSpec() return ret @classmethod def decode(cls,s): - right_answer = super(cls,OptimizedRistrettoPoint).decode(s) + right_answer = cls.decodeSpec(s) # Sanity check s if len(s) != cls.encLen: @@ -177,65 +157,47 @@ class OptimizedRistrettoPoint(RistrettoPoint): if lobit(x): x = -x y = ynum * yden_inv - if lobit(x*y) or x==0: + if cls.cofactor==8 and (lobit(x*y) or x==0): raise InvalidEncodingException("x*y has high bit") ret = cls(x,y) assert ret == right_answer return ret + + def torque(self): + if self.cofactor == 8: + return self.__class__(self.y*self.i, self.x*self.i) + else: + return self.__class__(-self.x, -self.y) + -class DecafPoint(Ed25519Point): - """Works like current decaf""" - dMont = Ed25519Point.F(-121665) - magic = sqrt(dMont-1) +class Ed25519Point(RistrettoPoint): + F = GF(2^255-19) + d = F(-121665/121666) + a = F(-1) + i = sqrt(F(-1)) + cofactor = 8 encLen = 32 - def __eq__(self,other): - x,y = self - X,Y = other - return x*Y == X*y or x*X == y*Y - - def encode(self): - x,y = self - a,d = self.a,self.d - - if x*y == 0: - # This will happen anyway with straightforward square root trick - return enc_le(0,self.encLen) - - if not is_square((1-y)/(1+y)): - raise Exception("Unimplemented: odd point in RistrettoPoint.encode") - - # Choose representative in 4-torsion group - if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x) - if hibit(2*self.magic/x): x,y = -x,-y - - s = sqrt((1-y)/(1+y)) - if hibit(s): s = -s - return enc_le(s,self.encLen) - @classmethod - def decode(cls,s): - if len(s) != cls.encLen: - raise InvalidEncodingException("wrong length %d" % len(s)) - s = dec_le(s) - if s == 0: return cls(0,1) - if s < 0 or s >= (cls.F.modulus()+1)/2: - raise InvalidEncodingException("%d out of range!" % s) - s = cls.F(s) - - if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1): - raise InvalidEncodingException("Not on curve") - - t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s - if hibit(t): t = -t - - y = (1-s^2)/(1+s^2) - x = 2*cls.magic/t + def base(cls): + y = cls.F(4/5) + x = sqrt((y^2-1)/(cls.d*y^2+1)) + if lobit(x): x = -x + return cls(x,y) + +class Ed448Point(RistrettoPoint): + F = GF(2^448-2^224-1) + d = F(-39082) + a = F(-1) + cofactor = 4 + encLen = 56 - if y == 0 or lobit(t/y): - raise InvalidEncodingException("t/y has high bit") - + @classmethod + def base(cls): + y = cls.F(6) # FIXME: no it isn't + x = sqrt((y^2-1)/(cls.d*y^2+1)) + if lobit(x): x = -x return cls(x,y) class TestFailedException(Exception): pass