| @@ -7,6 +7,12 @@ def hibit(x): return lobit(2*x) | |||
| def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) | |||
| def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) | |||
| def isqrt(x,exn=InvalidEncodingException("Not on curve")): | |||
| """Return 1/sqrt(x)""" | |||
| if x==0: return 0 | |||
| if not is_square(x): raise exn | |||
| return 1/sqrt(x) | |||
| class EdwardsPoint(object): | |||
| """Abstract class for point an an Edwards curve; needs F,a,d to work""" | |||
| def __init__(self,x=0,y=1): | |||
| @@ -46,12 +52,17 @@ class EdwardsPoint(object): | |||
| work += work | |||
| exp >>= 1 | |||
| return total | |||
| def xyzt(self): | |||
| x,y = self | |||
| z = self.F.random_element() | |||
| return x*z,y*z,z,x*y*z | |||
| class Ed25519Point(EdwardsPoint): | |||
| F = GF(2^255-19) | |||
| d = F(-121665/121666) | |||
| a = F(-1) | |||
| i = sqrt(Ed25519Point.F(-1)) | |||
| i = sqrt(F(-1)) | |||
| @classmethod | |||
| def base(cls): | |||
| @@ -100,8 +111,78 @@ class RistrettoPoint(Ed25519Point): | |||
| if lobit(x*y) or x==0: | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| return cls(x,y) | |||
| class OptimizedRistrettoPoint(RistrettoPoint): | |||
| magic = isqrt(RistrettoPoint.d+1) | |||
| """Like Ristretto but uses isqrt instead""" | |||
| @classmethod | |||
| def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): | |||
| s = isqrt(isqrt*inv^2) | |||
| return s*inv, s^2*isqrt*inv | |||
| def encode(self): | |||
| right_answer = super(OptimizedRistrettoPoint,self).encode() | |||
| x,y,z,t = self.xyzt() | |||
| x *= self.i | |||
| u1 = (z+y)*(z-y) | |||
| u2 = x*y # = t*z | |||
| isr = isqrt(u1 * u2^2) | |||
| i1 = isr*u1 | |||
| i2 = isr*u2 | |||
| z_inv = i1*i2*t | |||
| rotate = lobit(t*self.i*z_inv) | |||
| if rotate: | |||
| x,y = y,x | |||
| den_inv = self.magic * i1 | |||
| else: | |||
| den_inv = i2 | |||
| if rotate ^^ lobit(x*z_inv): y = -y | |||
| s = (z-y) * den_inv | |||
| if s==0: s = F(1) | |||
| if lobit(s): s=-s | |||
| ret = enc_le(s,self.encLen) | |||
| assert ret == right_answer | |||
| return ret | |||
| @classmethod | |||
| def decode(cls,s): | |||
| right_answer = super(cls,OptimizedRistrettoPoint).decode(s) | |||
| # Sanity check s | |||
| if len(s) != cls.encLen: | |||
| raise InvalidEncodingException("wrong length %d" % len(s)) | |||
| s = dec_le(s) | |||
| if s < 0 or s >= cls.F.modulus() or lobit(s): | |||
| raise InvalidEncodingException("%d out of range!" % s) | |||
| s = cls.F(s) | |||
| yden = 1+s^2 | |||
| ynum = 1-s^2 | |||
| yden_sqr = yden^2 | |||
| xden_sqr = -cls.d*ynum^2 - yden_sqr | |||
| isr = isqrt(xden_sqr * yden_sqr) | |||
| xden_inv = isr * yden | |||
| yden_inv = xden_inv * isr * xden_sqr | |||
| x = 2*s*xden_inv | |||
| if lobit(x): x = -x | |||
| y = ynum * yden_inv | |||
| if lobit(x*y) or x==0: | |||
| raise InvalidEncodingException("x*y has high bit") | |||
| ret = cls(x,y) | |||
| assert ret == right_answer | |||
| return ret | |||
| class DecafPoint(Ed25519Point): | |||
| """Works like current decaf""" | |||