| @@ -22,7 +22,7 @@ class EdwardsPoint(object): | |||
| raise NotOnCurveException() | |||
| def __repr__(self): | |||
| return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y) | |||
| return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) | |||
| def __iter__(self): | |||
| yield self.x | |||
| @@ -58,25 +58,8 @@ class EdwardsPoint(object): | |||
| z = self.F.random_element() | |||
| return x*z,y*z,z,x*y*z | |||
| class Ed25519Point(EdwardsPoint): | |||
| F = GF(2^255-19) | |||
| d = F(-121665/121666) | |||
| a = F(-1) | |||
| i = sqrt(F(-1)) | |||
| @classmethod | |||
| def base(cls): | |||
| y = cls.F(4/5) | |||
| x = sqrt((y^2-1)/(cls.d*y^2+1)) | |||
| if lobit(x): x = -x | |||
| return cls(x,y) | |||
| def torque(self): | |||
| return self.__class__(self.y*self.i, self.x*self.i) | |||
| class RistrettoPoint(Ed25519Point): | |||
| class RistrettoPoint(EdwardsPoint): | |||
| """Like current decaf but tweaked for simplicity""" | |||
| encLen = 32 | |||
| def __eq__(self,other): | |||
| x,y = self | |||
| @@ -90,15 +73,22 @@ class RistrettoPoint(Ed25519Point): | |||
| if negative(s): s=-s | |||
| return s | |||
| def encode(self): | |||
| def encodeSpec(self): | |||
| """Unoptimized specification for encoding""" | |||
| x,y = self | |||
| if lobit(x*y) or x==0: (x,y) = (self.i*y,self.i*x) | |||
| if lobit(x): x,y = -x,-y | |||
| if self.cofactor==8 and (lobit(x*y) or x==0): | |||
| (x,y) = (self.i*y,self.i*x) | |||
| elif self.cofactor==4 and y==-1: | |||
| y = 1 # Doesn't affect impl | |||
| if lobit(x): y=-y | |||
| s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even")) | |||
| return enc_le(s,self.encLen) | |||
| @classmethod | |||
| def decode(cls,s): | |||
| def decodeSpec(cls,s): | |||
| """Unoptimized specification for decoding""" | |||
| if len(s) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||
| s = dec_le(s) | |||
| @@ -109,24 +99,13 @@ class RistrettoPoint(Ed25519Point): | |||
| x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2)) | |||
| y = (1-s^2) / (1+s^2) | |||
| if lobit(x*y) or x==0: | |||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| return cls(x,y) | |||
| class OptimizedRistrettoPoint(RistrettoPoint): | |||
| magic = isqrt(RistrettoPoint.d+1) | |||
| """Like Ristretto but uses isqrt instead""" | |||
| @classmethod | |||
| def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): | |||
| s = isqrt(isqrt*inv^2) | |||
| return s*inv, s^2*isqrt*inv | |||
| def encode(self): | |||
| right_answer = super(OptimizedRistrettoPoint,self).encode() | |||
| x,y,z,t = self.xyzt() | |||
| x *= self.i | |||
| u1 = (z+y)*(z-y) | |||
| u2 = x*y # = t*z | |||
| @@ -135,25 +114,26 @@ class OptimizedRistrettoPoint(RistrettoPoint): | |||
| i2 = isr*u2 | |||
| z_inv = i1*i2*t | |||
| rotate = lobit(t*self.i*z_inv) | |||
| rotate = self.cofactor==8 and lobit(t*z_inv) | |||
| if rotate: | |||
| x,y = y,x | |||
| den_inv = self.magic * i1 | |||
| magic = isqrt(-self.d-1) | |||
| x,y = y*self.i,x*self.i | |||
| den_inv = magic * i1 | |||
| else: | |||
| den_inv = i2 | |||
| if rotate ^^ lobit(x*z_inv): y = -y | |||
| if lobit(x*z_inv): y = -y | |||
| s = (z-y) * den_inv | |||
| if s==0: s = F(1) | |||
| if self.cofactor==8 and s==0: s += 1 | |||
| if lobit(s): s=-s | |||
| ret = enc_le(s,self.encLen) | |||
| assert ret == right_answer | |||
| assert ret == self.encodeSpec() | |||
| return ret | |||
| @classmethod | |||
| def decode(cls,s): | |||
| right_answer = super(cls,OptimizedRistrettoPoint).decode(s) | |||
| right_answer = cls.decodeSpec(s) | |||
| # Sanity check s | |||
| if len(s) != cls.encLen: | |||
| @@ -177,65 +157,47 @@ class OptimizedRistrettoPoint(RistrettoPoint): | |||
| if lobit(x): x = -x | |||
| y = ynum * yden_inv | |||
| if lobit(x*y) or x==0: | |||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| ret = cls(x,y) | |||
| assert ret == right_answer | |||
| return ret | |||
| def torque(self): | |||
| if self.cofactor == 8: | |||
| return self.__class__(self.y*self.i, self.x*self.i) | |||
| else: | |||
| return self.__class__(-self.x, -self.y) | |||
| class DecafPoint(Ed25519Point): | |||
| """Works like current decaf""" | |||
| dMont = Ed25519Point.F(-121665) | |||
| magic = sqrt(dMont-1) | |||
| class Ed25519Point(RistrettoPoint): | |||
| F = GF(2^255-19) | |||
| d = F(-121665/121666) | |||
| a = F(-1) | |||
| i = sqrt(F(-1)) | |||
| cofactor = 8 | |||
| encLen = 32 | |||
| def __eq__(self,other): | |||
| x,y = self | |||
| X,Y = other | |||
| return x*Y == X*y or x*X == y*Y | |||
| def encode(self): | |||
| x,y = self | |||
| a,d = self.a,self.d | |||
| if x*y == 0: | |||
| # This will happen anyway with straightforward square root trick | |||
| return enc_le(0,self.encLen) | |||
| if not is_square((1-y)/(1+y)): | |||
| raise Exception("Unimplemented: odd point in RistrettoPoint.encode") | |||
| # Choose representative in 4-torsion group | |||
| if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x) | |||
| if hibit(2*self.magic/x): x,y = -x,-y | |||
| s = sqrt((1-y)/(1+y)) | |||
| if hibit(s): s = -s | |||
| return enc_le(s,self.encLen) | |||
| @classmethod | |||
| def decode(cls,s): | |||
| if len(s) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||
| s = dec_le(s) | |||
| if s == 0: return cls(0,1) | |||
| if s < 0 or s >= (cls.F.modulus()+1)/2: | |||
| raise InvalidEncodingException("%d out of range!" % s) | |||
| s = cls.F(s) | |||
| if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1): | |||
| raise InvalidEncodingException("Not on curve") | |||
| t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s | |||
| if hibit(t): t = -t | |||
| y = (1-s^2)/(1+s^2) | |||
| x = 2*cls.magic/t | |||
| def base(cls): | |||
| y = cls.F(4/5) | |||
| x = sqrt((y^2-1)/(cls.d*y^2+1)) | |||
| if lobit(x): x = -x | |||
| return cls(x,y) | |||
| class Ed448Point(RistrettoPoint): | |||
| F = GF(2^448-2^224-1) | |||
| d = F(-39082) | |||
| a = F(-1) | |||
| cofactor = 4 | |||
| encLen = 56 | |||
| if y == 0 or lobit(t/y): | |||
| raise InvalidEncodingException("t/y has high bit") | |||
| @classmethod | |||
| def base(cls): | |||
| y = cls.F(6) # FIXME: no it isn't | |||
| x = sqrt((y^2-1)/(cls.d*y^2+1)) | |||
| if lobit(x): x = -x | |||
| return cls(x,y) | |||
| class TestFailedException(Exception): pass | |||