| @@ -1,11 +1,42 @@ | |||||
| import binascii | |||||
| class InvalidEncodingException(Exception): pass | class InvalidEncodingException(Exception): pass | ||||
| class NotOnCurveException(Exception): pass | class NotOnCurveException(Exception): pass | ||||
| class SpecException(Exception): pass | |||||
| def lobit(x): return int(x) & 1 | def lobit(x): return int(x) & 1 | ||||
| def hibit(x): return lobit(2*x) | def hibit(x): return lobit(2*x) | ||||
| def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | ||||
| def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | ||||
| def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) | |||||
| def optimized_version_of(spec): | |||||
| def decorator(f): | |||||
| def wrapper(self,*args,**kwargs): | |||||
| try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None | |||||
| except Exception as e: spec_ans = None,e | |||||
| try: opt_ans = f(self,*args,**kwargs),None | |||||
| except Exception as e: opt_ans = None,e | |||||
| if spec_ans[1] is None and opt_ans[1] is not None: | |||||
| raise SpecException("Mismatch in %s: spec returned %s but opt threw %s" | |||||
| % (f.__name__,str(spec_ans[0]),str(opt_ans[1]))) | |||||
| if spec_ans[1] is not None and opt_ans[1] is None: | |||||
| raise SpecException("Mismatch in %s: spec threw %s but opt returned %s" | |||||
| % (f.__name__,str(spec_ans[1]),str(opt_ans[0]))) | |||||
| if spec_ans[0] != opt_ans[0]: | |||||
| raise SpecException("Mismatch in %s: %s != %s" | |||||
| % (f.__name__,str(spec_ans[0]),str(opt_ans[0]))) | |||||
| if opt_ans[1] is not None: raise opt_ans[1] | |||||
| else: return opt_ans[0] | |||||
| wrapper.__name__ = f.__name__ | |||||
| return wrapper | |||||
| return decorator | |||||
| def xsqrt(x,exn=InvalidEncodingException("Not on curve")): | |||||
| """Return sqrt(x)""" | |||||
| if not is_square(x): raise exn | |||||
| s = sqrt(x) | |||||
| if lobit(s): s=-s | |||||
| return s | |||||
| def isqrt(x,exn=InvalidEncodingException("Not on curve")): | def isqrt(x,exn=InvalidEncodingException("Not on curve")): | ||||
| """Return 1/sqrt(x)""" | """Return 1/sqrt(x)""" | ||||
| @@ -13,13 +44,21 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")): | |||||
| if not is_square(x): raise exn | if not is_square(x): raise exn | ||||
| return 1/sqrt(x) | return 1/sqrt(x) | ||||
| def isqrt_i(x): | |||||
| """Return 1/sqrt(x) or 1/sqrt(zeta * x)""" | |||||
| if x==0: return 0 | |||||
| gen = x.parent(-1) | |||||
| while is_square(gen): gen = sqrt(gen) | |||||
| if is_square(x): return True,1/sqrt(x) | |||||
| else: return False,1/sqrt(x*gen) | |||||
| class EdwardsPoint(object): | class EdwardsPoint(object): | ||||
| """Abstract class for point an an Edwards curve; needs F,a,d to work""" | """Abstract class for point an an Edwards curve; needs F,a,d to work""" | ||||
| def __init__(self,x=0,y=1): | def __init__(self,x=0,y=1): | ||||
| x = self.x = self.F(x) | x = self.x = self.F(x) | ||||
| y = self.y = self.F(y) | y = self.y = self.F(y) | ||||
| if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: | if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: | ||||
| raise NotOnCurveException() | |||||
| raise NotOnCurveException(str(self)) | |||||
| def __repr__(self): | def __repr__(self): | ||||
| return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) | return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y) | ||||
| @@ -57,96 +96,97 @@ class EdwardsPoint(object): | |||||
| x,y = self | x,y = self | ||||
| z = self.F.random_element() | z = self.F.random_element() | ||||
| return x*z,y*z,z,x*y*z | return x*z,y*z,z,x*y*z | ||||
| def torque(self): | |||||
| """Apply cofactor group, except keeping the point even""" | |||||
| if self.cofactor == 8: | |||||
| return self.__class__(self.y*self.i, self.x*self.i) | |||||
| else: | |||||
| return self.__class__(-self.x, -self.y) | |||||
| class RistrettoPoint(EdwardsPoint): | class RistrettoPoint(EdwardsPoint): | ||||
| """Like current decaf but tweaked for simplicity""" | """Like current decaf but tweaked for simplicity""" | ||||
| def __eq__(self,other): | def __eq__(self,other): | ||||
| x,y = self | x,y = self | ||||
| X,Y = other | X,Y = other | ||||
| return x*Y == X*y or x*X == y*Y | return x*Y == X*y or x*X == y*Y | ||||
| @staticmethod | |||||
| def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")): | |||||
| if not is_square(x): raise exn | |||||
| s = sqrt(x) | |||||
| if negative(s): s=-s | |||||
| return s | |||||
| @classmethod | |||||
| def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False): | |||||
| """Convert little-endian bytes to field element, sanity check length""" | |||||
| if len(bytes) != cls.encLen: | |||||
| raise InvalidEncodingException("wrong length %d" % len(bytes)) | |||||
| s = dec_le(bytes) | |||||
| if mustBeProper and s >= cls.F.modulus(): | |||||
| raise InvalidEncodingException("%d out of range!" % s) | |||||
| if mustBePositive and lobit(s): | |||||
| raise InvalidEncodingException("%d is negative!" % s) | |||||
| return cls.F(s) | |||||
| def encodeSpec(self): | def encodeSpec(self): | ||||
| """Unoptimized specification for encoding""" | """Unoptimized specification for encoding""" | ||||
| x,y = self | x,y = self | ||||
| if self.cofactor==8 and (lobit(x*y) or x==0): | |||||
| if self.cofactor==8 and (lobit(x*y) or y==0): | |||||
| (x,y) = (self.i*y,self.i*x) | (x,y) = (self.i*y,self.i*x) | ||||
| elif self.cofactor==4 and y==-1: | |||||
| y = 1 # Doesn't affect impl | |||||
| if lobit(x): y=-y | |||||
| s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even")) | |||||
| if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl | |||||
| if lobit(x): x,y = -x,-y | |||||
| s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self))) | |||||
| return enc_le(s,self.encLen) | return enc_le(s,self.encLen) | ||||
| @classmethod | @classmethod | ||||
| def decodeSpec(cls,s): | def decodeSpec(cls,s): | ||||
| """Unoptimized specification for decoding""" | """Unoptimized specification for decoding""" | ||||
| if len(s) != cls.encLen: | |||||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||||
| s = dec_le(s) | |||||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||||
| raise InvalidEncodingException("%d out of range!" % s) | |||||
| s = cls.F(s) | |||||
| s = cls.bytesToGf(s,mustBePositive=True) | |||||
| x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2)) | |||||
| y = (1-s^2) / (1+s^2) | |||||
| a,d = cls.a,cls.d | |||||
| x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2)) | |||||
| y = (1+a*s^2) / (1-a*s^2) | |||||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||||
| if cls.cofactor==8 and (lobit(x*y) or y==0): | |||||
| raise InvalidEncodingException("x*y has high bit") | raise InvalidEncodingException("x*y has high bit") | ||||
| return cls(x,y) | return cls(x,y) | ||||
| @optimized_version_of("encodeSpec") | |||||
| def encode(self): | def encode(self): | ||||
| """Encode, optimized version""" | |||||
| a,d = self.a,self.d | |||||
| x,y,z,t = self.xyzt() | x,y,z,t = self.xyzt() | ||||
| u1 = (z+y)*(z-y) | |||||
| u1 = a*(y+z)*(y-z) | |||||
| u2 = x*y # = t*z | u2 = x*y # = t*z | ||||
| isr = isqrt(u1 * u2^2) | |||||
| isr = isqrt(u1*u2^2) | |||||
| i1 = isr*u1 | i1 = isr*u1 | ||||
| i2 = isr*u2 | i2 = isr*u2 | ||||
| z_inv = i1*i2*t | z_inv = i1*i2*t | ||||
| rotate = self.cofactor==8 and lobit(t*z_inv) | rotate = self.cofactor==8 and lobit(t*z_inv) | ||||
| if rotate: | if rotate: | ||||
| magic = isqrt(-self.d-1) | |||||
| x,y = y*self.i,x*self.i | x,y = y*self.i,x*self.i | ||||
| den_inv = magic * i1 | |||||
| den_inv = self.magic * i1 | |||||
| else: | else: | ||||
| den_inv = i2 | den_inv = i2 | ||||
| if lobit(x*z_inv): y = -y | if lobit(x*z_inv): y = -y | ||||
| s = (z-y) * den_inv | s = (z-y) * den_inv | ||||
| if self.cofactor==8 and s==0: s += 1 | |||||
| if lobit(s): s=-s | if lobit(s): s=-s | ||||
| ret = enc_le(s,self.encLen) | |||||
| assert ret == self.encodeSpec() | |||||
| return ret | |||||
| return enc_le(s,self.encLen) | |||||
| @classmethod | @classmethod | ||||
| @optimized_version_of("decodeSpec") | |||||
| def decode(cls,s): | def decode(cls,s): | ||||
| right_answer = cls.decodeSpec(s) | |||||
| """Decode, optimized version""" | |||||
| s = cls.bytesToGf(s,mustBePositive=True) | |||||
| # Sanity check s | |||||
| if len(s) != cls.encLen: | |||||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||||
| s = dec_le(s) | |||||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||||
| raise InvalidEncodingException("%d out of range!" % s) | |||||
| s = cls.F(s) | |||||
| yden = 1+s^2 | |||||
| ynum = 1-s^2 | |||||
| a,d = cls.a,cls.d | |||||
| yden = 1-a*s^2 | |||||
| ynum = 1+a*s^2 | |||||
| yden_sqr = yden^2 | yden_sqr = yden^2 | ||||
| xden_sqr = -cls.d*ynum^2 - yden_sqr | |||||
| xden_sqr = a*d*ynum^2 - yden_sqr | |||||
| isr = isqrt(xden_sqr * yden_sqr) | isr = isqrt(xden_sqr * yden_sqr) | ||||
| @@ -157,25 +197,60 @@ class RistrettoPoint(EdwardsPoint): | |||||
| if lobit(x): x = -x | if lobit(x): x = -x | ||||
| y = ynum * yden_inv | y = ynum * yden_inv | ||||
| if cls.cofactor==8 and (lobit(x*y) or x==0): | |||||
| raise InvalidEncodingException("x*y has high bit") | |||||
| if cls.cofactor==8 and (lobit(x*y) or y==0): | |||||
| raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) | |||||
| ret = cls(x,y) | |||||
| assert ret == right_answer | |||||
| return ret | |||||
| def torque(self): | |||||
| if self.cofactor == 8: | |||||
| return self.__class__(self.y*self.i, self.x*self.i) | |||||
| return cls(x,y) | |||||
| @classmethod | |||||
| def fromJacobiQuartic(cls,s,t,sgn=1): | |||||
| """Convert point from its Jacobi Quartic representation""" | |||||
| a,d = cls.a,cls.d | |||||
| assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2 | |||||
| x = 2*s*cls.magic / t | |||||
| if lobit(x): x = -x # TODO: doesn't work without resolving x | |||||
| y = (1+a*s^2) / (1-a*s^2) | |||||
| return cls(sgn*x,y) | |||||
| @classmethod | |||||
| def elligatorSpec(cls,r0): | |||||
| a,d = cls.a,cls.d | |||||
| r = cls.qnr * cls.bytesToGf(r0)^2 | |||||
| den = (d*r-a)*(a*r-d) | |||||
| n1 = cls.a*(r+1)*(a+d)*(d-a)/den | |||||
| n2 = r*n1 | |||||
| if is_square(n1): | |||||
| sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1 | |||||
| else: | else: | ||||
| return self.__class__(-self.x, -self.y) | |||||
| sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1 | |||||
| ret = cls.fromJacobiQuartic(s,t,sgn) | |||||
| return ret | |||||
| @classmethod | |||||
| @optimized_version_of("elligatorSpec") | |||||
| def elligator(cls,r0): | |||||
| a,d = cls.a,cls.d | |||||
| r0 = cls.bytesToGf(r0) | |||||
| r = cls.qnr * r0^2 | |||||
| den = (d*r-a)*(a*r-d) | |||||
| num = cls.a*(r+1)*(a+d)*(d-a) | |||||
| iss,isri = isqrt_i(num*den) | |||||
| if iss: sgn,twiddle = 1,1 | |||||
| else: sgn,twiddle = -1,r0*cls.qnr | |||||
| isri *= twiddle | |||||
| s = isri*num | |||||
| t = isri*s*(r-1)*(d+a)^2 + sgn | |||||
| return cls.fromJacobiQuartic(s,t,sgn) | |||||
| class Ed25519Point(RistrettoPoint): | class Ed25519Point(RistrettoPoint): | ||||
| F = GF(2^255-19) | F = GF(2^255-19) | ||||
| d = F(-121665/121666) | d = F(-121665/121666) | ||||
| a = F(-1) | a = F(-1) | ||||
| i = sqrt(F(-1)) | i = sqrt(F(-1)) | ||||
| qnr = i | |||||
| magic = isqrt(a*d-1) | |||||
| cofactor = 8 | cofactor = 8 | ||||
| encLen = 32 | encLen = 32 | ||||
| @@ -186,30 +261,72 @@ class Ed25519Point(RistrettoPoint): | |||||
| if lobit(x): x = -x | if lobit(x): x = -x | ||||
| return cls(x,y) | return cls(x,y) | ||||
| class Ed448Point(RistrettoPoint): | |||||
| class TwistedEd448GoldilocksPoint(RistrettoPoint): | |||||
| F = GF(2^448-2^224-1) | F = GF(2^448-2^224-1) | ||||
| d = F(-39082) | d = F(-39082) | ||||
| a = F(-1) | a = F(-1) | ||||
| qnr = -1 | |||||
| magic = isqrt(a*d-1) | |||||
| cofactor = 4 | cofactor = 4 | ||||
| encLen = 56 | encLen = 56 | ||||
| @classmethod | @classmethod | ||||
| def base(cls): | def base(cls): | ||||
| y = cls.F(6) # FIXME: no it isn't | |||||
| y = cls.F(6) # TODO: no it isn't | |||||
| x = sqrt((y^2-1)/(cls.d*y^2+1)) | x = sqrt((y^2-1)/(cls.d*y^2+1)) | ||||
| if lobit(x): x = -x | if lobit(x): x = -x | ||||
| return cls(x,y) | return cls(x,y) | ||||
| class Ed448GoldilocksPoint(RistrettoPoint): | |||||
| # TODO: decaf vs ristretto | |||||
| F = GF(2^448-2^224-1) | |||||
| d = F(-39081) | |||||
| a = F(1) | |||||
| qnr = -1 | |||||
| magic = isqrt(a*d-1) | |||||
| cofactor = 4 | |||||
| encLen = 56 | |||||
| @classmethod | |||||
| def base(cls): | |||||
| return cls( | |||||
| 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555, | |||||
| 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed | |||||
| ) | |||||
| class IsoEd448Point(RistrettoPoint): | |||||
| F = GF(2^448-2^224-1) | |||||
| d = F(1/39081+1) | |||||
| a = F(1) | |||||
| qnr = -1 | |||||
| magic = isqrt(a*d-1) | |||||
| cofactor = 4 | |||||
| encLen = 56 | |||||
| @classmethod | |||||
| def base(cls): | |||||
| # = ..., -3/2 | |||||
| return cls.decodeSpec(bytearray(binascii.unhexlify( | |||||
| "00000000000000000000000000000000000000000000000000000000"+ | |||||
| "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) | |||||
| class TestFailedException(Exception): pass | class TestFailedException(Exception): pass | ||||
| def test(cls,n): | def test(cls,n): | ||||
| # TODO: test corner cases like 0,1,i | # TODO: test corner cases like 0,1,i | ||||
| P = cls.base() | P = cls.base() | ||||
| Q = cls() | Q = cls() | ||||
| for i in xrange(n): | for i in xrange(n): | ||||
| #print binascii.hexlify(Q.encode()) | |||||
| QQ = cls.decode(Q.encode()) | QQ = cls.decode(Q.encode()) | ||||
| if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) | if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) | ||||
| if Q.encode() != Q.torque().encode(): | |||||
| raise TestFailedException("Can't torque %s" % str(Q)) | |||||
| QT = Q | |||||
| QE = Q.encode() | |||||
| for h in xrange(cls.cofactor): | |||||
| QT = QT.torque() | |||||
| if QT.encode() != QE: | |||||
| raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1)) | |||||
| Q0 = Q + P | Q0 = Q + P | ||||
| if Q0 == Q: raise TestFailedException("Addition doesn't work") | if Q0 == Q: raise TestFailedException("Addition doesn't work") | ||||
| @@ -220,5 +337,16 @@ def test(cls,n): | |||||
| Q2 = Q0*(r+1) | Q2 = Q0*(r+1) | ||||
| if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") | if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") | ||||
| Q = Q1 | Q = Q1 | ||||
| test(Ed25519Point,100) | |||||
| test(TwistedEd448GoldilocksPoint,100) | |||||
| test(Ed448GoldilocksPoint,100) | |||||
| test(IsoEd448Point,100) | |||||
| def testElligator(cls,n): | |||||
| for i in xrange(n): | |||||
| cls.elligator(randombytes(cls.encLen)) | |||||
| testElligator(Ed25519Point,100) | |||||
| testElligator(Ed448GoldilocksPoint,100) | |||||
| testElligator(TwistedEd448GoldilocksPoint,100) | |||||
| testElligator(IsoEd448Point,100) | |||||