|
|
@@ -0,0 +1,193 @@ |
|
|
|
|
|
|
|
class InvalidEncodingException(Exception): pass |
|
|
|
class NotOnCurveException(Exception): pass |
|
|
|
|
|
|
|
def lobit(x): return int(x) & 1 |
|
|
|
def hibit(x): return lobit(2*x) |
|
|
|
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) |
|
|
|
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) |
|
|
|
|
|
|
|
class EdwardsPoint(object): |
|
|
|
"""Abstract class for point an an Edwards curve; needs F,a,d to work""" |
|
|
|
def __init__(self,x=0,y=1): |
|
|
|
x = self.x = self.F(x) |
|
|
|
y = self.y = self.F(y) |
|
|
|
if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2: |
|
|
|
raise NotOnCurveException() |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y) |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
yield self.x |
|
|
|
yield self.y |
|
|
|
|
|
|
|
def __add__(self,other): |
|
|
|
x,y = self |
|
|
|
X,Y = other |
|
|
|
a,d = self.a,self.d |
|
|
|
return self.__class__( |
|
|
|
(x*Y+y*X)/(1+d*x*y*X*Y), |
|
|
|
(y*Y-a*x*X)/(1-d*x*y*X*Y) |
|
|
|
) |
|
|
|
|
|
|
|
def __neg__(self): return self.__class__(-self.x,self.y) |
|
|
|
def __sub__(self,other): return self + (-other) |
|
|
|
def __rmul__(self,other): return self*other |
|
|
|
def __eq__(self,other): return tuple(self) == tuple(other) |
|
|
|
def __ne__(self,other): return not (self==other) |
|
|
|
|
|
|
|
def __mul__(self,exp): |
|
|
|
exp = int(exp) |
|
|
|
total = self.__class__() |
|
|
|
work = self |
|
|
|
while exp != 0: |
|
|
|
if exp & 1: total += work |
|
|
|
work += work |
|
|
|
exp >>= 1 |
|
|
|
return total |
|
|
|
|
|
|
|
class Ed25519Point(EdwardsPoint): |
|
|
|
F = GF(2^255-19) |
|
|
|
d = F(-121665/121666) |
|
|
|
a = F(-1) |
|
|
|
i = sqrt(Ed25519Point.F(-1)) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def base(cls): |
|
|
|
y = cls.F(4/5) |
|
|
|
x = sqrt((y^2-1)/(cls.d*y^2+1)) |
|
|
|
if lobit(x): x = -x |
|
|
|
return cls(x,y) |
|
|
|
|
|
|
|
def torque(self): |
|
|
|
return self.__class__(self.y*self.i, self.x*self.i) |
|
|
|
|
|
|
|
class RistrettoOption1Point(Ed25519Point): |
|
|
|
"""Like current decaf but tweaked for simplicity""" |
|
|
|
dMont = Ed25519Point.F(-121665) |
|
|
|
encLen = 32 |
|
|
|
|
|
|
|
def __eq__(self,other): |
|
|
|
x,y = self |
|
|
|
X,Y = other |
|
|
|
return x*Y == X*y or x*X == y*Y |
|
|
|
|
|
|
|
def encode(self): |
|
|
|
x,y = self |
|
|
|
a,d = self.a,self.d |
|
|
|
|
|
|
|
if x*y == 0: |
|
|
|
# This happens anyway with straightforward impl |
|
|
|
return enc_le(0,self.encLen) |
|
|
|
|
|
|
|
if not is_square((1-y)/(1+y)): |
|
|
|
raise Exception("Unimplemented: odd point in RistrettoPoint.encode") |
|
|
|
|
|
|
|
# Choose representative in 4-torsion group |
|
|
|
if lobit(x*y): (x,y) = (self.i*y,self.i*x) |
|
|
|
if lobit(x): x,y = -x,-y |
|
|
|
|
|
|
|
s = sqrt((1-y)/(1+y)) |
|
|
|
if lobit(s): s = -s |
|
|
|
return enc_le(s,self.encLen) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def decode(cls,s): |
|
|
|
if len(s) != cls.encLen: |
|
|
|
raise InvalidEncodingException("wrong length %d" % len(s)) |
|
|
|
s = dec_le(s) |
|
|
|
if s == 0: return cls(0,1) |
|
|
|
if s < 0 or s >= cls.F.modulus() or lobit(s): |
|
|
|
raise InvalidEncodingException("%d out of range!" % s) |
|
|
|
s = cls.F(s) |
|
|
|
|
|
|
|
magic = 4*cls.dMont-4 |
|
|
|
if not is_square(magic*s^2 / ((s^2-1)^2 - s^2 * magic)): |
|
|
|
raise InvalidEncodingException("Not on curve") |
|
|
|
|
|
|
|
x = sqrt(magic*s^2 / ((s^2-1)^2 - magic * s^2)) |
|
|
|
if lobit(x): x=-x |
|
|
|
y = (1-s^2)/(1+s^2) |
|
|
|
|
|
|
|
if lobit(x*y): |
|
|
|
raise InvalidEncodingException("x*y has high bit") |
|
|
|
|
|
|
|
return cls(x,y) |
|
|
|
|
|
|
|
class RistrettoOption2Point(Ed25519Point): |
|
|
|
"""Works like current decaf""" |
|
|
|
dMont = Ed25519Point.F(-121665) |
|
|
|
magic = sqrt(dMont-1) |
|
|
|
encLen = 32 |
|
|
|
|
|
|
|
def __eq__(self,other): |
|
|
|
x,y = self |
|
|
|
X,Y = other |
|
|
|
return x*Y == X*y or x*X == y*Y |
|
|
|
|
|
|
|
def encode(self): |
|
|
|
x,y = self |
|
|
|
a,d = self.a,self.d |
|
|
|
|
|
|
|
if x*y == 0: |
|
|
|
# This will happen anyway with straightforward square root trick |
|
|
|
return enc_le(0,self.encLen) |
|
|
|
|
|
|
|
if not is_square((1-y)/(1+y)): |
|
|
|
raise Exception("Unimplemented: odd point in RistrettoPoint.encode") |
|
|
|
|
|
|
|
# Choose representative in 4-torsion group |
|
|
|
if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x) |
|
|
|
if hibit(2*self.magic/x): x,y = -x,-y |
|
|
|
|
|
|
|
s = sqrt((1-y)/(1+y)) |
|
|
|
if hibit(s): s = -s |
|
|
|
return enc_le(s,self.encLen) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def decode(cls,s): |
|
|
|
if len(s) != cls.encLen: |
|
|
|
raise InvalidEncodingException("wrong length %d" % len(s)) |
|
|
|
s = dec_le(s) |
|
|
|
if s == 0: return cls(0,1) |
|
|
|
if s < 0 or s >= (cls.F.modulus()+1)/2: |
|
|
|
raise InvalidEncodingException("%d out of range!" % s) |
|
|
|
s = cls.F(s) |
|
|
|
|
|
|
|
if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1): |
|
|
|
raise InvalidEncodingException("Not on curve") |
|
|
|
|
|
|
|
t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s |
|
|
|
if hibit(t): t = -t |
|
|
|
|
|
|
|
y = (1-s^2)/(1+s^2) |
|
|
|
x = 2*cls.magic/t |
|
|
|
|
|
|
|
if y == 0 or lobit(t/y): |
|
|
|
raise InvalidEncodingException("t/y has high bit") |
|
|
|
|
|
|
|
return cls(x,y) |
|
|
|
|
|
|
|
class TestFailedException(Exception): pass |
|
|
|
def test(cls,n): |
|
|
|
# TODO: test corner cases like 0,1,i |
|
|
|
P = cls.base() |
|
|
|
Q = cls() |
|
|
|
for i in xrange(n): |
|
|
|
QQ = cls.decode(Q.encode()) |
|
|
|
if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q))) |
|
|
|
if Q.encode() != Q.torque().encode(): |
|
|
|
raise TestFailedException("Can't torque %s" % str(Q)) |
|
|
|
|
|
|
|
Q0 = Q + P |
|
|
|
if Q0 == Q: raise TestFailedException("Addition doesn't work") |
|
|
|
if Q0-P != Q: raise TestFailedException("Subtraction doesn't work") |
|
|
|
|
|
|
|
r = randint(1,1000) |
|
|
|
Q1 = Q0*r |
|
|
|
Q2 = Q0*(r+1) |
|
|
|
if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") |
|
|
|
Q = Q1 |
|
|
|
|
|
|
|
|