|
|
@@ -7,6 +7,12 @@ def hibit(x): return lobit(2*x) |
|
|
|
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) |
|
|
|
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) |
|
|
|
|
|
|
|
def isqrt(x,exn=InvalidEncodingException("Not on curve")): |
|
|
|
"""Return 1/sqrt(x)""" |
|
|
|
if x==0: return 0 |
|
|
|
if not is_square(x): raise exn |
|
|
|
return 1/sqrt(x) |
|
|
|
|
|
|
|
class EdwardsPoint(object): |
|
|
|
"""Abstract class for point an an Edwards curve; needs F,a,d to work""" |
|
|
|
def __init__(self,x=0,y=1): |
|
|
@@ -46,12 +52,17 @@ class EdwardsPoint(object): |
|
|
|
work += work |
|
|
|
exp >>= 1 |
|
|
|
return total |
|
|
|
|
|
|
|
def xyzt(self): |
|
|
|
x,y = self |
|
|
|
z = self.F.random_element() |
|
|
|
return x*z,y*z,z,x*y*z |
|
|
|
|
|
|
|
class Ed25519Point(EdwardsPoint): |
|
|
|
F = GF(2^255-19) |
|
|
|
d = F(-121665/121666) |
|
|
|
a = F(-1) |
|
|
|
i = sqrt(Ed25519Point.F(-1)) |
|
|
|
i = sqrt(F(-1)) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def base(cls): |
|
|
@@ -100,8 +111,78 @@ class RistrettoPoint(Ed25519Point): |
|
|
|
|
|
|
|
if lobit(x*y) or x==0: |
|
|
|
raise InvalidEncodingException("x*y has high bit") |
|
|
|
|
|
|
|
|
|
|
|
return cls(x,y) |
|
|
|
|
|
|
|
class OptimizedRistrettoPoint(RistrettoPoint): |
|
|
|
magic = isqrt(RistrettoPoint.d+1) |
|
|
|
|
|
|
|
"""Like Ristretto but uses isqrt instead""" |
|
|
|
@classmethod |
|
|
|
def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs): |
|
|
|
s = isqrt(isqrt*inv^2) |
|
|
|
return s*inv, s^2*isqrt*inv |
|
|
|
|
|
|
|
def encode(self): |
|
|
|
right_answer = super(OptimizedRistrettoPoint,self).encode() |
|
|
|
x,y,z,t = self.xyzt() |
|
|
|
x *= self.i |
|
|
|
|
|
|
|
u1 = (z+y)*(z-y) |
|
|
|
u2 = x*y # = t*z |
|
|
|
isr = isqrt(u1 * u2^2) |
|
|
|
i1 = isr*u1 |
|
|
|
i2 = isr*u2 |
|
|
|
z_inv = i1*i2*t |
|
|
|
|
|
|
|
rotate = lobit(t*self.i*z_inv) |
|
|
|
if rotate: |
|
|
|
x,y = y,x |
|
|
|
den_inv = self.magic * i1 |
|
|
|
else: |
|
|
|
den_inv = i2 |
|
|
|
|
|
|
|
if rotate ^^ lobit(x*z_inv): y = -y |
|
|
|
s = (z-y) * den_inv |
|
|
|
if s==0: s = F(1) |
|
|
|
if lobit(s): s=-s |
|
|
|
|
|
|
|
ret = enc_le(s,self.encLen) |
|
|
|
assert ret == right_answer |
|
|
|
return ret |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def decode(cls,s): |
|
|
|
right_answer = super(cls,OptimizedRistrettoPoint).decode(s) |
|
|
|
|
|
|
|
# Sanity check s |
|
|
|
if len(s) != cls.encLen: |
|
|
|
raise InvalidEncodingException("wrong length %d" % len(s)) |
|
|
|
s = dec_le(s) |
|
|
|
if s < 0 or s >= cls.F.modulus() or lobit(s): |
|
|
|
raise InvalidEncodingException("%d out of range!" % s) |
|
|
|
s = cls.F(s) |
|
|
|
|
|
|
|
yden = 1+s^2 |
|
|
|
ynum = 1-s^2 |
|
|
|
yden_sqr = yden^2 |
|
|
|
xden_sqr = -cls.d*ynum^2 - yden_sqr |
|
|
|
|
|
|
|
isr = isqrt(xden_sqr * yden_sqr) |
|
|
|
|
|
|
|
xden_inv = isr * yden |
|
|
|
yden_inv = xden_inv * isr * xden_sqr |
|
|
|
|
|
|
|
x = 2*s*xden_inv |
|
|
|
if lobit(x): x = -x |
|
|
|
y = ynum * yden_inv |
|
|
|
|
|
|
|
if lobit(x*y) or x==0: |
|
|
|
raise InvalidEncodingException("x*y has high bit") |
|
|
|
|
|
|
|
ret = cls(x,y) |
|
|
|
assert ret == right_answer |
|
|
|
return ret |
|
|
|
|
|
|
|
class DecafPoint(Ed25519Point): |
|
|
|
"""Works like current decaf""" |
|
|
|