| @@ -1,11 +1,42 @@ | |||
| import binascii | |||
| class InvalidEncodingException(Exception): pass | |||
| class NotOnCurveException(Exception): pass | |||
| class SpecException(Exception): pass | |||
| def lobit(x): return int(x) & 1 | |||
| def hibit(x): return lobit(2*x) | |||
| def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | |||
| def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | |||
| def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) | |||
| def optimized_version_of(spec): | |||
| def decorator(f): | |||
| def wrapper(self,*args,**kwargs): | |||
| try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None | |||
| except Exception as e: spec_ans = None,e | |||
| try: opt_ans = f(self,*args,**kwargs),None | |||
| except Exception as e: opt_ans = None,e | |||
| if spec_ans[1] is None and opt_ans[1] is not None: | |||
| raise SpecException("Mismatch in %s: spec returned %s but opt threw %s" | |||
| % (f.__name__,str(spec_ans[0]),str(opt_ans[1]))) | |||
| if spec_ans[1] is not None and opt_ans[1] is None: | |||
| raise SpecException("Mismatch in %s: spec threw %s but opt returned %s" | |||
| % (f.__name__,str(spec_ans[1]),str(opt_ans[0]))) | |||
| if spec_ans[0] != opt_ans[0]: | |||
| raise SpecException("Mismatch in %s: %s != %s" | |||
| % (f.__name__,str(spec_ans[0]),str(opt_ans[0]))) | |||
| if opt_ans[1] is not None: raise opt_ans[1] | |||
| else: return opt_ans[0] | |||
| wrapper.__name__ = f.__name__ | |||
| return wrapper | |||
| return decorator | |||
| def xsqrt(x,exn=InvalidEncodingException("Not on curve")): | |||
| """Return sqrt(x)""" | |||
| if not is_square(x): raise exn | |||
| s = sqrt(x) | |||
| if lobit(s): s=-s | |||
| return s | |||
| def isqrt(x,exn=InvalidEncodingException("Not on curve")): | |||
| """Return 1/sqrt(x)""" | |||
| @@ -13,13 +44,21 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")): | |||
| if not is_square(x): raise exn | |||
| return 1/sqrt(x) | |||
| def isqrt_i(x): | |||
| """Return 1/sqrt(x) or 1/sqrt(zeta * x)""" | |||
| if x==0: return 0 | |||
| gen = x.parent(-1) | |||
| while is_square(gen): gen = sqrt(gen) | |||
| if is_square(x): return True,1/sqrt(x) | |||
| else: return False,1/sqrt(x*gen) | |||
| class EdwardsPoint(object): | |||
| """Abstract class for point an an Edwards curve; needs F,a,d to work""" | |||
| def __init__(self,x=0,y=1): | |||
| x = self.x = self.F(x) | |||
| y = self.y = self.F(y) | |||
| if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: | |||
| raise NotOnCurveException() | |||
| raise NotOnCurveException(str(self)) | |||
| def __repr__(self): | |||
| return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) | |||
| @@ -57,96 +96,97 @@ class EdwardsPoint(object): | |||
| x,y = self | |||
| z = self.F.random_element() | |||
| return x*z,y*z,z,x*y*z | |||
| def torque(self): | |||
| """Apply cofactor group, except keeping the point even""" | |||
| if self.cofactor == 8: | |||
| return self.__class__(self.y*self.i, self.x*self.i) | |||
| else: | |||
| return self.__class__(-self.x, -self.y) | |||
| class RistrettoPoint(EdwardsPoint): | |||
| """Like current decaf but tweaked for simplicity""" | |||
| def __eq__(self,other): | |||
| x,y = self | |||
| X,Y = other | |||
| return x*Y == X*y or x*X == y*Y | |||
| @staticmethod | |||
| def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")): | |||
| if not is_square(x): raise exn | |||
| s = sqrt(x) | |||
| if negative(s): s=-s | |||
| return s | |||
| @classmethod | |||
| def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False): | |||
| """Convert little-endian bytes to field element, sanity check length""" | |||
| if len(bytes) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(bytes)) | |||
| s = dec_le(bytes) | |||
| if mustBeProper and s >= cls.F.modulus(): | |||
| raise InvalidEncodingException("%d out of range!" % s) | |||
| if mustBePositive and lobit(s): | |||
| raise InvalidEncodingException("%d is negative!" % s) | |||
| return cls.F(s) | |||
| def encodeSpec(self): | |||
| """Unoptimized specification for encoding""" | |||
| x,y = self | |||
| if self.cofactor==8 and (lobit(x*y) or x==0): | |||
| if self.cofactor==8 and (lobit(x*y) or y==0): | |||
| (x,y) = (self.i*y,self.i*x) | |||
| elif self.cofactor==4 and y==-1: | |||
| y = 1 # Doesn't affect impl | |||
| if lobit(x): y=-y | |||
| s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even")) | |||
| if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl | |||
| if lobit(x): x,y = -x,-y | |||
| s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self))) | |||
| return enc_le(s,self.encLen) | |||
| @classmethod | |||
| def decodeSpec(cls,s): | |||
| """Unoptimized specification for decoding""" | |||
| if len(s) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||
| s = dec_le(s) | |||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||
| raise InvalidEncodingException("%d out of range!" % s) | |||
| s = cls.F(s) | |||
| s = cls.bytesToGf(s,mustBePositive=True) | |||
| x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2)) | |||
| y = (1-s^2) / (1+s^2) | |||
| a,d = cls.a,cls.d | |||
| x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2)) | |||
| y = (1+a*s^2) / (1-a*s^2) | |||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||
| if cls.cofactor==8 and (lobit(x*y) or y==0): | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| return cls(x,y) | |||
| @optimized_version_of("encodeSpec") | |||
| def encode(self): | |||
| """Encode, optimized version""" | |||
| a,d = self.a,self.d | |||
| x,y,z,t = self.xyzt() | |||
| u1 = (z+y)*(z-y) | |||
| u1 = a*(y+z)*(y-z) | |||
| u2 = x*y # = t*z | |||
| isr = isqrt(u1 * u2^2) | |||
| isr = isqrt(u1*u2^2) | |||
| i1 = isr*u1 | |||
| i2 = isr*u2 | |||
| z_inv = i1*i2*t | |||
| rotate = self.cofactor==8 and lobit(t*z_inv) | |||
| if rotate: | |||
| magic = isqrt(-self.d-1) | |||
| x,y = y*self.i,x*self.i | |||
| den_inv = magic * i1 | |||
| den_inv = self.magic * i1 | |||
| else: | |||
| den_inv = i2 | |||
| if lobit(x*z_inv): y = -y | |||
| s = (z-y) * den_inv | |||
| if self.cofactor==8 and s==0: s += 1 | |||
| if lobit(s): s=-s | |||
| ret = enc_le(s,self.encLen) | |||
| assert ret == self.encodeSpec() | |||
| return ret | |||
| return enc_le(s,self.encLen) | |||
| @classmethod | |||
| @optimized_version_of("decodeSpec") | |||
| def decode(cls,s): | |||
| right_answer = cls.decodeSpec(s) | |||
| """Decode, optimized version""" | |||
| s = cls.bytesToGf(s,mustBePositive=True) | |||
| # Sanity check s | |||
| if len(s) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||
| s = dec_le(s) | |||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||
| raise InvalidEncodingException("%d out of range!" % s) | |||
| s = cls.F(s) | |||
| yden = 1+s^2 | |||
| ynum = 1-s^2 | |||
| a,d = cls.a,cls.d | |||
| yden = 1-a*s^2 | |||
| ynum = 1+a*s^2 | |||
| yden_sqr = yden^2 | |||
| xden_sqr = -cls.d*ynum^2 - yden_sqr | |||
| xden_sqr = a*d*ynum^2 - yden_sqr | |||
| isr = isqrt(xden_sqr * yden_sqr) | |||
| @@ -157,25 +197,60 @@ class RistrettoPoint(EdwardsPoint): | |||
| if lobit(x): x = -x | |||
| y = ynum * yden_inv | |||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| if cls.cofactor==8 and (lobit(x*y) or y==0): | |||
| raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) | |||
| ret = cls(x,y) | |||
| assert ret == right_answer | |||
| return ret | |||
| def torque(self): | |||
| if self.cofactor == 8: | |||
| return self.__class__(self.y*self.i, self.x*self.i) | |||
| return cls(x,y) | |||
| @classmethod | |||
| def fromJacobiQuartic(cls,s,t,sgn=1): | |||
| """Convert point from its Jacobi Quartic representation""" | |||
| a,d = cls.a,cls.d | |||
| assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2 | |||
| x = 2*s*cls.magic / t | |||
| if lobit(x): x = -x # TODO: doesn't work without resolving x | |||
| y = (1+a*s^2) / (1-a*s^2) | |||
| return cls(sgn*x,y) | |||
| @classmethod | |||
| def elligatorSpec(cls,r0): | |||
| a,d = cls.a,cls.d | |||
| r = cls.qnr * cls.bytesToGf(r0)^2 | |||
| den = (d*r-a)*(a*r-d) | |||
| n1 = cls.a*(r+1)*(a+d)*(d-a)/den | |||
| n2 = r*n1 | |||
| if is_square(n1): | |||
| sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1 | |||
| else: | |||
| return self.__class__(-self.x, -self.y) | |||
| sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 | |||
| ret = cls.fromJacobiQuartic(s,t,sgn) | |||
| return ret | |||
| @classmethod | |||
| @optimized_version_of("elligatorSpec") | |||
| def elligator(cls,r0): | |||
| a,d = cls.a,cls.d | |||
| r0 = cls.bytesToGf(r0) | |||
| r = cls.qnr * r0^2 | |||
| den = (d*r-a)*(a*r-d) | |||
| num = cls.a*(r+1)*(a+d)*(d-a) | |||
| iss,isri = isqrt_i(num*den) | |||
| if iss: sgn,twiddle = 1,1 | |||
| else: sgn,twiddle = -1,r0*cls.qnr | |||
| isri *= twiddle | |||
| s = isri*num | |||
| t = isri*s*(r-1)*(d+a)^2 + sgn | |||
| return cls.fromJacobiQuartic(s,t,sgn) | |||
| class Ed25519Point(RistrettoPoint): | |||
| F = GF(2^255-19) | |||
| d = F(-121665/121666) | |||
| a = F(-1) | |||
| i = sqrt(F(-1)) | |||
| qnr = i | |||
| magic = isqrt(a*d-1) | |||
| cofactor = 8 | |||
| encLen = 32 | |||
| @@ -186,30 +261,72 @@ class Ed25519Point(RistrettoPoint): | |||
| if lobit(x): x = -x | |||
| return cls(x,y) | |||
| class Ed448Point(RistrettoPoint): | |||
| class TwistedEd448GoldilocksPoint(RistrettoPoint): | |||
| F = GF(2^448-2^224-1) | |||
| d = F(-39082) | |||
| a = F(-1) | |||
| qnr = -1 | |||
| magic = isqrt(a*d-1) | |||
| cofactor = 4 | |||
| encLen = 56 | |||
| @classmethod | |||
| def base(cls): | |||
| y = cls.F(6) # FIXME: no it isn't | |||
| y = cls.F(6) # TODO: no it isn't | |||
| x = sqrt((y^2-1)/(cls.d*y^2+1)) | |||
| if lobit(x): x = -x | |||
| return cls(x,y) | |||
| class Ed448GoldilocksPoint(RistrettoPoint): | |||
| # TODO: decaf vs ristretto | |||
| F = GF(2^448-2^224-1) | |||
| d = F(-39081) | |||
| a = F(1) | |||
| qnr = -1 | |||
| magic = isqrt(a*d-1) | |||
| cofactor = 4 | |||
| encLen = 56 | |||
| @classmethod | |||
| def base(cls): | |||
| return cls( | |||
| 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555, | |||
| 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed | |||
| ) | |||
| class IsoEd448Point(RistrettoPoint): | |||
| F = GF(2^448-2^224-1) | |||
| d = F(1/39081+1) | |||
| a = F(1) | |||
| qnr = -1 | |||
| magic = isqrt(a*d-1) | |||
| cofactor = 4 | |||
| encLen = 56 | |||
| @classmethod | |||
| def base(cls): | |||
| # = ..., -3/2 | |||
| return cls.decodeSpec(bytearray(binascii.unhexlify( | |||
| "00000000000000000000000000000000000000000000000000000000"+ | |||
| "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) | |||
| class TestFailedException(Exception): pass | |||
| def test(cls,n): | |||
| # TODO: test corner cases like 0,1,i | |||
| P = cls.base() | |||
| Q = cls() | |||
| for i in xrange(n): | |||
| #print binascii.hexlify(Q.encode()) | |||
| QQ = cls.decode(Q.encode()) | |||
| if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) | |||
| if Q.encode() != Q.torque().encode(): | |||
| raise TestFailedException("Can't torque %s" % str(Q)) | |||
| QT = Q | |||
| QE = Q.encode() | |||
| for h in xrange(cls.cofactor): | |||
| QT = QT.torque() | |||
| if QT.encode() != QE: | |||
| raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1)) | |||
| Q0 = Q + P | |||
| if Q0 == Q: raise TestFailedException("Addition doesn't work") | |||
| @@ -220,5 +337,16 @@ def test(cls,n): | |||
| Q2 = Q0*(r+1) | |||
| if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") | |||
| Q = Q1 | |||
| test(Ed25519Point,100) | |||
| test(TwistedEd448GoldilocksPoint,100) | |||
| test(Ed448GoldilocksPoint,100) | |||
| test(IsoEd448Point,100) | |||
| def testElligator(cls,n): | |||
| for i in xrange(n): | |||
| cls.elligator(randombytes(cls.encLen)) | |||
| testElligator(Ed25519Point,100) | |||
| testElligator(Ed448GoldilocksPoint,100) | |||
| testElligator(TwistedEd448GoldilocksPoint,100) | |||
| testElligator(IsoEd448Point,100) | |||