Browse Source

Ristretto for Ed448

master
Michael Hamburg 7 years ago
parent
commit
b55ac5ebd1
1 changed files with 53 additions and 91 deletions
  1. +53
    -91
      aux/ristretto.sage

+ 53
- 91
aux/ristretto.sage View File

@@ -22,7 +22,7 @@ class EdwardsPoint(object):
raise NotOnCurveException()

def __repr__(self):
return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y)
return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)

def __iter__(self):
yield self.x
@@ -58,25 +58,8 @@ class EdwardsPoint(object):
z = self.F.random_element()
return x*z,y*z,z,x*y*z

class Ed25519Point(EdwardsPoint):
F = GF(2^255-19)
d = F(-121665/121666)
a = F(-1)
i = sqrt(F(-1))
@classmethod
def base(cls):
y = cls.F(4/5)
x = sqrt((y^2-1)/(cls.d*y^2+1))
if lobit(x): x = -x
return cls(x,y)
def torque(self):
return self.__class__(self.y*self.i, self.x*self.i)

class RistrettoPoint(Ed25519Point):
class RistrettoPoint(EdwardsPoint):
"""Like current decaf but tweaked for simplicity"""
encLen = 32
def __eq__(self,other):
x,y = self
@@ -90,15 +73,22 @@ class RistrettoPoint(Ed25519Point):
if negative(s): s=-s
return s
def encode(self):
def encodeSpec(self):
"""Unoptimized specification for encoding"""
x,y = self
if lobit(x*y) or x==0: (x,y) = (self.i*y,self.i*x)
if lobit(x): x,y = -x,-y
if self.cofactor==8 and (lobit(x*y) or x==0):
(x,y) = (self.i*y,self.i*x)
elif self.cofactor==4 and y==-1:
y = 1 # Doesn't affect impl
if lobit(x): y=-y
s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even"))
return enc_le(s,self.encLen)
@classmethod
def decode(cls,s):
def decodeSpec(cls,s):
"""Unoptimized specification for decoding"""
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
@@ -109,24 +99,13 @@ class RistrettoPoint(Ed25519Point):
x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2))
y = (1-s^2) / (1+s^2)
if lobit(x*y) or x==0:
if cls.cofactor==8 and (lobit(x*y) or x==0):
raise InvalidEncodingException("x*y has high bit")
return cls(x,y)
class OptimizedRistrettoPoint(RistrettoPoint):
magic = isqrt(RistrettoPoint.d+1)
"""Like Ristretto but uses isqrt instead"""
@classmethod
def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs):
s = isqrt(isqrt*inv^2)
return s*inv, s^2*isqrt*inv
def encode(self):
right_answer = super(OptimizedRistrettoPoint,self).encode()
x,y,z,t = self.xyzt()
x *= self.i

u1 = (z+y)*(z-y)
u2 = x*y # = t*z
@@ -135,25 +114,26 @@ class OptimizedRistrettoPoint(RistrettoPoint):
i2 = isr*u2
z_inv = i1*i2*t

rotate = lobit(t*self.i*z_inv)
rotate = self.cofactor==8 and lobit(t*z_inv)
if rotate:
x,y = y,x
den_inv = self.magic * i1
magic = isqrt(-self.d-1)
x,y = y*self.i,x*self.i
den_inv = magic * i1
else:
den_inv = i2

if rotate ^^ lobit(x*z_inv): y = -y
if lobit(x*z_inv): y = -y
s = (z-y) * den_inv
if s==0: s = F(1)
if self.cofactor==8 and s==0: s += 1
if lobit(s): s=-s
ret = enc_le(s,self.encLen)
assert ret == right_answer
assert ret == self.encodeSpec()
return ret
@classmethod
def decode(cls,s):
right_answer = super(cls,OptimizedRistrettoPoint).decode(s)
right_answer = cls.decodeSpec(s)
# Sanity check s
if len(s) != cls.encLen:
@@ -177,65 +157,47 @@ class OptimizedRistrettoPoint(RistrettoPoint):
if lobit(x): x = -x
y = ynum * yden_inv
if lobit(x*y) or x==0:
if cls.cofactor==8 and (lobit(x*y) or x==0):
raise InvalidEncodingException("x*y has high bit")
ret = cls(x,y)
assert ret == right_answer
return ret
def torque(self):
if self.cofactor == 8:
return self.__class__(self.y*self.i, self.x*self.i)
else:
return self.__class__(-self.x, -self.y)

class DecafPoint(Ed25519Point):
"""Works like current decaf"""
dMont = Ed25519Point.F(-121665)
magic = sqrt(dMont-1)
class Ed25519Point(RistrettoPoint):
F = GF(2^255-19)
d = F(-121665/121666)
a = F(-1)
i = sqrt(F(-1))
cofactor = 8
encLen = 32
def __eq__(self,other):
x,y = self
X,Y = other
return x*Y == X*y or x*X == y*Y
def encode(self):
x,y = self
a,d = self.a,self.d
if x*y == 0:
# This will happen anyway with straightforward square root trick
return enc_le(0,self.encLen)

if not is_square((1-y)/(1+y)):
raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
# Choose representative in 4-torsion group
if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x)
if hibit(2*self.magic/x): x,y = -x,-y
s = sqrt((1-y)/(1+y))
if hibit(s): s = -s
return enc_le(s,self.encLen)
@classmethod
def decode(cls,s):
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
if s == 0: return cls(0,1)
if s < 0 or s >= (cls.F.modulus()+1)/2:
raise InvalidEncodingException("%d out of range!" % s)
s = cls.F(s)
if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1):
raise InvalidEncodingException("Not on curve")
t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s
if hibit(t): t = -t
y = (1-s^2)/(1+s^2)
x = 2*cls.magic/t
def base(cls):
y = cls.F(4/5)
x = sqrt((y^2-1)/(cls.d*y^2+1))
if lobit(x): x = -x
return cls(x,y)

class Ed448Point(RistrettoPoint):
F = GF(2^448-2^224-1)
d = F(-39082)
a = F(-1)
cofactor = 4
encLen = 56
if y == 0 or lobit(t/y):
raise InvalidEncodingException("t/y has high bit")
@classmethod
def base(cls):
y = cls.F(6) # FIXME: no it isn't
x = sqrt((y^2-1)/(cls.d*y^2+1))
if lobit(x): x = -x
return cls(x,y)

class TestFailedException(Exception): pass


Loading…
Cancel
Save