Browse Source

ristretto work

master
Michael Hamburg 7 years ago
parent
commit
dd193a3ec5
1 changed files with 193 additions and 0 deletions
  1. +193
    -0
      aux/ristretto.sage

+ 193
- 0
aux/ristretto.sage View File

@@ -0,0 +1,193 @@

class InvalidEncodingException(Exception): pass
class NotOnCurveException(Exception): pass

def lobit(x): return int(x) & 1
def hibit(x): return lobit(2*x)
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))

class EdwardsPoint(object):
"""Abstract class for point an an Edwards curve; needs F,a,d to work"""
def __init__(self,x=0,y=1):
x = self.x = self.F(x)
y = self.y = self.F(y)
if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
raise NotOnCurveException()

def __repr__(self):
return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y)

def __iter__(self):
yield self.x
yield self.y

def __add__(self,other):
x,y = self
X,Y = other
a,d = self.a,self.d
return self.__class__(
(x*Y+y*X)/(1+d*x*y*X*Y),
(y*Y-a*x*X)/(1-d*x*y*X*Y)
)
def __neg__(self): return self.__class__(-self.x,self.y)
def __sub__(self,other): return self + (-other)
def __rmul__(self,other): return self*other
def __eq__(self,other): return tuple(self) == tuple(other)
def __ne__(self,other): return not (self==other)
def __mul__(self,exp):
exp = int(exp)
total = self.__class__()
work = self
while exp != 0:
if exp & 1: total += work
work += work
exp >>= 1
return total

class Ed25519Point(EdwardsPoint):
F = GF(2^255-19)
d = F(-121665/121666)
a = F(-1)
i = sqrt(Ed25519Point.F(-1))
@classmethod
def base(cls):
y = cls.F(4/5)
x = sqrt((y^2-1)/(cls.d*y^2+1))
if lobit(x): x = -x
return cls(x,y)
def torque(self):
return self.__class__(self.y*self.i, self.x*self.i)

class RistrettoOption1Point(Ed25519Point):
"""Like current decaf but tweaked for simplicity"""
dMont = Ed25519Point.F(-121665)
encLen = 32
def __eq__(self,other):
x,y = self
X,Y = other
return x*Y == X*y or x*X == y*Y
def encode(self):
x,y = self
a,d = self.a,self.d
if x*y == 0:
# This happens anyway with straightforward impl
return enc_le(0,self.encLen)

if not is_square((1-y)/(1+y)):
raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
# Choose representative in 4-torsion group
if lobit(x*y): (x,y) = (self.i*y,self.i*x)
if lobit(x): x,y = -x,-y
s = sqrt((1-y)/(1+y))
if lobit(s): s = -s
return enc_le(s,self.encLen)
@classmethod
def decode(cls,s):
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
if s == 0: return cls(0,1)
if s < 0 or s >= cls.F.modulus() or lobit(s):
raise InvalidEncodingException("%d out of range!" % s)
s = cls.F(s)
magic = 4*cls.dMont-4
if not is_square(magic*s^2 / ((s^2-1)^2 - s^2 * magic)):
raise InvalidEncodingException("Not on curve")
x = sqrt(magic*s^2 / ((s^2-1)^2 - magic * s^2))
if lobit(x): x=-x
y = (1-s^2)/(1+s^2)
if lobit(x*y):
raise InvalidEncodingException("x*y has high bit")
return cls(x,y)

class RistrettoOption2Point(Ed25519Point):
"""Works like current decaf"""
dMont = Ed25519Point.F(-121665)
magic = sqrt(dMont-1)
encLen = 32
def __eq__(self,other):
x,y = self
X,Y = other
return x*Y == X*y or x*X == y*Y
def encode(self):
x,y = self
a,d = self.a,self.d
if x*y == 0:
# This will happen anyway with straightforward square root trick
return enc_le(0,self.encLen)

if not is_square((1-y)/(1+y)):
raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
# Choose representative in 4-torsion group
if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x)
if hibit(2*self.magic/x): x,y = -x,-y
s = sqrt((1-y)/(1+y))
if hibit(s): s = -s
return enc_le(s,self.encLen)
@classmethod
def decode(cls,s):
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
if s == 0: return cls(0,1)
if s < 0 or s >= (cls.F.modulus()+1)/2:
raise InvalidEncodingException("%d out of range!" % s)
s = cls.F(s)
if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1):
raise InvalidEncodingException("Not on curve")
t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s
if hibit(t): t = -t
y = (1-s^2)/(1+s^2)
x = 2*cls.magic/t
if y == 0 or lobit(t/y):
raise InvalidEncodingException("t/y has high bit")
return cls(x,y)

class TestFailedException(Exception): pass
def test(cls,n):
# TODO: test corner cases like 0,1,i
P = cls.base()
Q = cls()
for i in xrange(n):
QQ = cls.decode(Q.encode())
if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
if Q.encode() != Q.torque().encode():
raise TestFailedException("Can't torque %s" % str(Q))
Q0 = Q + P
if Q0 == Q: raise TestFailedException("Addition doesn't work")
if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
r = randint(1,1000)
Q1 = Q0*r
Q2 = Q0*(r+1)
if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
Q = Q1

Loading…
Cancel
Save