| @@ -7,6 +7,12 @@ def hibit(x): return lobit(2*x) | |||||
| def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | ||||
| def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | ||||
| def isqrt(x,exn=InvalidEncodingException("Not on curve")): | |||||
| """Return 1/sqrt(x)""" | |||||
| if x==0: return 0 | |||||
| if not is_square(x): raise exn | |||||
| return 1/sqrt(x) | |||||
| class EdwardsPoint(object): | class EdwardsPoint(object): | ||||
| """Abstract class for point an an Edwards curve; needs F,a,d to work""" | """Abstract class for point an an Edwards curve; needs F,a,d to work""" | ||||
| def __init__(self,x=0,y=1): | def __init__(self,x=0,y=1): | ||||
| @@ -46,12 +52,17 @@ class EdwardsPoint(object): | |||||
| work += work | work += work | ||||
| exp >>= 1 | exp >>= 1 | ||||
| return total | return total | ||||
| def xyzt(self): | |||||
| x,y = self | |||||
| z = self.F.random_element() | |||||
| return x*z,y*z,z,x*y*z | |||||
| class Ed25519Point(EdwardsPoint): | class Ed25519Point(EdwardsPoint): | ||||
| F = GF(2^255-19) | F = GF(2^255-19) | ||||
| d = F(-121665/121666) | d = F(-121665/121666) | ||||
| a = F(-1) | a = F(-1) | ||||
| i = sqrt(Ed25519Point.F(-1)) | |||||
| i = sqrt(F(-1)) | |||||
| @classmethod | @classmethod | ||||
| def base(cls): | def base(cls): | ||||
| @@ -100,8 +111,78 @@ class RistrettoPoint(Ed25519Point): | |||||
| if lobit(x*y) or x==0: | if lobit(x*y) or x==0: | ||||
| raise InvalidEncodingException("x*y has high bit") | raise InvalidEncodingException("x*y has high bit") | ||||
| return cls(x,y) | return cls(x,y) | ||||
| class OptimizedRistrettoPoint(RistrettoPoint): | |||||
| magic = isqrt(RistrettoPoint.d+1) | |||||
| """Like Ristretto but uses isqrt instead""" | |||||
| @classmethod | |||||
| def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): | |||||
| s = isqrt(isqrt*inv^2) | |||||
| return s*inv, s^2*isqrt*inv | |||||
| def encode(self): | |||||
| right_answer = super(OptimizedRistrettoPoint,self).encode() | |||||
| x,y,z,t = self.xyzt() | |||||
| x *= self.i | |||||
| u1 = (z+y)*(z-y) | |||||
| u2 = x*y # = t*z | |||||
| isr = isqrt(u1 * u2^2) | |||||
| i1 = isr*u1 | |||||
| i2 = isr*u2 | |||||
| z_inv = i1*i2*t | |||||
| rotate = lobit(t*self.i*z_inv) | |||||
| if rotate: | |||||
| x,y = y,x | |||||
| den_inv = self.magic * i1 | |||||
| else: | |||||
| den_inv = i2 | |||||
| if rotate ^^ lobit(x*z_inv): y = -y | |||||
| s = (z-y) * den_inv | |||||
| if s==0: s = F(1) | |||||
| if lobit(s): s=-s | |||||
| ret = enc_le(s,self.encLen) | |||||
| assert ret == right_answer | |||||
| return ret | |||||
| @classmethod | |||||
| def decode(cls,s): | |||||
| right_answer = super(cls,OptimizedRistrettoPoint).decode(s) | |||||
| # Sanity check s | |||||
| if len(s) != cls.encLen: | |||||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||||
| s = dec_le(s) | |||||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||||
| raise InvalidEncodingException("%d out of range!" % s) | |||||
| s = cls.F(s) | |||||
| yden = 1+s^2 | |||||
| ynum = 1-s^2 | |||||
| yden_sqr = yden^2 | |||||
| xden_sqr = -cls.d*ynum^2 - yden_sqr | |||||
| isr = isqrt(xden_sqr * yden_sqr) | |||||
| xden_inv = isr * yden | |||||
| yden_inv = xden_inv * isr * xden_sqr | |||||
| x = 2*s*xden_inv | |||||
| if lobit(x): x = -x | |||||
| y = ynum * yden_inv | |||||
| if lobit(x*y) or x==0: | |||||
| raise InvalidEncodingException("x*y has high bit") | |||||
| ret = cls(x,y) | |||||
| assert ret == right_answer | |||||
| return ret | |||||
| class DecafPoint(Ed25519Point): | class DecafPoint(Ed25519Point): | ||||
| """Works like current decaf""" | """Works like current decaf""" | ||||