Browse Source

working on ristretto sage files

master
Michael Hamburg 7 years ago
parent
commit
4a5aad2241
1 changed files with 190 additions and 62 deletions
  1. +190
    -62
      aux/ristretto.sage

+ 190
- 62
aux/ristretto.sage View File

@@ -1,11 +1,42 @@
import binascii
class InvalidEncodingException(Exception): pass
class NotOnCurveException(Exception): pass
class SpecException(Exception): pass

def lobit(x): return int(x) & 1
def hibit(x): return lobit(2*x)
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])

def optimized_version_of(spec):
def decorator(f):
def wrapper(self,*args,**kwargs):
try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
except Exception as e: spec_ans = None,e
try: opt_ans = f(self,*args,**kwargs),None
except Exception as e: opt_ans = None,e
if spec_ans[1] is None and opt_ans[1] is not None:
raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
% (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
if spec_ans[1] is not None and opt_ans[1] is None:
raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
% (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
if spec_ans[0] != opt_ans[0]:
raise SpecException("Mismatch in %s: %s != %s"
% (f.__name__,str(spec_ans[0]),str(opt_ans[0])))
if opt_ans[1] is not None: raise opt_ans[1]
else: return opt_ans[0]
wrapper.__name__ = f.__name__
return wrapper
return decorator
def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
"""Return sqrt(x)"""
if not is_square(x): raise exn
s = sqrt(x)
if lobit(s): s=-s
return s

def isqrt(x,exn=InvalidEncodingException("Not on curve")):
"""Return 1/sqrt(x)"""
@@ -13,13 +44,21 @@ def isqrt(x,exn=InvalidEncodingException("Not on curve")):
if not is_square(x): raise exn
return 1/sqrt(x)

def isqrt_i(x):
"""Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
if x==0: return 0
gen = x.parent(-1)
while is_square(gen): gen = sqrt(gen)
if is_square(x): return True,1/sqrt(x)
else: return False,1/sqrt(x*gen)

class EdwardsPoint(object):
"""Abstract class for point an an Edwards curve; needs F,a,d to work"""
def __init__(self,x=0,y=1):
x = self.x = self.F(x)
y = self.y = self.F(y)
if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
raise NotOnCurveException()
raise NotOnCurveException(str(self))

def __repr__(self):
return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
@@ -57,96 +96,97 @@ class EdwardsPoint(object):
x,y = self
z = self.F.random_element()
return x*z,y*z,z,x*y*z
def torque(self):
"""Apply cofactor group, except keeping the point even"""
if self.cofactor == 8:
return self.__class__(self.y*self.i, self.x*self.i)
else:
return self.__class__(-self.x, -self.y)

class RistrettoPoint(EdwardsPoint):
"""Like current decaf but tweaked for simplicity"""
def __eq__(self,other):
x,y = self
X,Y = other
return x*Y == X*y or x*X == y*Y
@staticmethod
def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")):
if not is_square(x): raise exn
s = sqrt(x)
if negative(s): s=-s
return s
@classmethod
def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
"""Convert little-endian bytes to field element, sanity check length"""
if len(bytes) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(bytes))
s = dec_le(bytes)
if mustBeProper and s >= cls.F.modulus():
raise InvalidEncodingException("%d out of range!" % s)
if mustBePositive and lobit(s):
raise InvalidEncodingException("%d is negative!" % s)
return cls.F(s)
def encodeSpec(self):
"""Unoptimized specification for encoding"""
x,y = self
if self.cofactor==8 and (lobit(x*y) or x==0):
if self.cofactor==8 and (lobit(x*y) or y==0):
(x,y) = (self.i*y,self.i*x)
elif self.cofactor==4 and y==-1:
y = 1 # Doesn't affect impl
if lobit(x): y=-y
s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even"))
if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
if lobit(x): x,y = -x,-y
s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self)))
return enc_le(s,self.encLen)
@classmethod
def decodeSpec(cls,s):
"""Unoptimized specification for decoding"""
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
if s < 0 or s >= cls.F.modulus() or lobit(s):
raise InvalidEncodingException("%d out of range!" % s)
s = cls.F(s)
s = cls.bytesToGf(s,mustBePositive=True)
x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2))
y = (1-s^2) / (1+s^2)
a,d = cls.a,cls.d
x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
y = (1+a*s^2) / (1-a*s^2)
if cls.cofactor==8 and (lobit(x*y) or x==0):
if cls.cofactor==8 and (lobit(x*y) or y==0):
raise InvalidEncodingException("x*y has high bit")
return cls(x,y)

@optimized_version_of("encodeSpec")
def encode(self):
"""Encode, optimized version"""
a,d = self.a,self.d
x,y,z,t = self.xyzt()

u1 = (z+y)*(z-y)
u1 = a*(y+z)*(y-z)
u2 = x*y # = t*z
isr = isqrt(u1 * u2^2)
isr = isqrt(u1*u2^2)
i1 = isr*u1
i2 = isr*u2
z_inv = i1*i2*t

rotate = self.cofactor==8 and lobit(t*z_inv)
if rotate:
magic = isqrt(-self.d-1)
x,y = y*self.i,x*self.i
den_inv = magic * i1
den_inv = self.magic * i1
else:
den_inv = i2

if lobit(x*z_inv): y = -y
s = (z-y) * den_inv
if self.cofactor==8 and s==0: s += 1
if lobit(s): s=-s
ret = enc_le(s,self.encLen)
assert ret == self.encodeSpec()
return ret
return enc_le(s,self.encLen)
@classmethod
@optimized_version_of("decodeSpec")
def decode(cls,s):
right_answer = cls.decodeSpec(s)
"""Decode, optimized version"""
s = cls.bytesToGf(s,mustBePositive=True)
# Sanity check s
if len(s) != cls.encLen:
raise InvalidEncodingException("wrong length %d" % len(s))
s = dec_le(s)
if s < 0 or s >= cls.F.modulus() or lobit(s):
raise InvalidEncodingException("%d out of range!" % s)
s = cls.F(s)
yden = 1+s^2
ynum = 1-s^2
a,d = cls.a,cls.d
yden = 1-a*s^2
ynum = 1+a*s^2
yden_sqr = yden^2
xden_sqr = -cls.d*ynum^2 - yden_sqr
xden_sqr = a*d*ynum^2 - yden_sqr
isr = isqrt(xden_sqr * yden_sqr)
@@ -157,25 +197,60 @@ class RistrettoPoint(EdwardsPoint):
if lobit(x): x = -x
y = ynum * yden_inv
if cls.cofactor==8 and (lobit(x*y) or x==0):
raise InvalidEncodingException("x*y has high bit")
if cls.cofactor==8 and (lobit(x*y) or y==0):
raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
ret = cls(x,y)
assert ret == right_answer
return ret
def torque(self):
if self.cofactor == 8:
return self.__class__(self.y*self.i, self.x*self.i)
return cls(x,y)
@classmethod
def fromJacobiQuartic(cls,s,t,sgn=1):
"""Convert point from its Jacobi Quartic representation"""
a,d = cls.a,cls.d
assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
x = 2*s*cls.magic / t
if lobit(x): x = -x # TODO: doesn't work without resolving x
y = (1+a*s^2) / (1-a*s^2)
return cls(sgn*x,y)
@classmethod
def elligatorSpec(cls,r0):
a,d = cls.a,cls.d
r = cls.qnr * cls.bytesToGf(r0)^2
den = (d*r-a)*(a*r-d)
n1 = cls.a*(r+1)*(a+d)*(d-a)/den
n2 = r*n1
if is_square(n1):
sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
else:
return self.__class__(-self.x, -self.y)
sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
ret = cls.fromJacobiQuartic(s,t,sgn)
return ret

@classmethod
@optimized_version_of("elligatorSpec")
def elligator(cls,r0):
a,d = cls.a,cls.d
r0 = cls.bytesToGf(r0)
r = cls.qnr * r0^2
den = (d*r-a)*(a*r-d)
num = cls.a*(r+1)*(a+d)*(d-a)
iss,isri = isqrt_i(num*den)
if iss: sgn,twiddle = 1,1
else: sgn,twiddle = -1,r0*cls.qnr
isri *= twiddle
s = isri*num
t = isri*s*(r-1)*(d+a)^2 + sgn
return cls.fromJacobiQuartic(s,t,sgn)
class Ed25519Point(RistrettoPoint):
F = GF(2^255-19)
d = F(-121665/121666)
a = F(-1)
i = sqrt(F(-1))
qnr = i
magic = isqrt(a*d-1)
cofactor = 8
encLen = 32
@@ -186,30 +261,72 @@ class Ed25519Point(RistrettoPoint):
if lobit(x): x = -x
return cls(x,y)

class Ed448Point(RistrettoPoint):
class TwistedEd448GoldilocksPoint(RistrettoPoint):
F = GF(2^448-2^224-1)
d = F(-39082)
a = F(-1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
@classmethod
def base(cls):
y = cls.F(6) # FIXME: no it isn't
y = cls.F(6) # TODO: no it isn't
x = sqrt((y^2-1)/(cls.d*y^2+1))
if lobit(x): x = -x
return cls(x,y)

class Ed448GoldilocksPoint(RistrettoPoint):
# TODO: decaf vs ristretto
F = GF(2^448-2^224-1)
d = F(-39081)
a = F(1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
@classmethod
def base(cls):
return cls(
0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555,
0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed
)

class IsoEd448Point(RistrettoPoint):
F = GF(2^448-2^224-1)
d = F(1/39081+1)
a = F(1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
@classmethod
def base(cls):
# = ..., -3/2
return cls.decodeSpec(bytearray(binascii.unhexlify(
"00000000000000000000000000000000000000000000000000000000"+
"fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))

class TestFailedException(Exception): pass

def test(cls,n):
# TODO: test corner cases like 0,1,i
P = cls.base()
Q = cls()
for i in xrange(n):
#print binascii.hexlify(Q.encode())
QQ = cls.decode(Q.encode())
if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
if Q.encode() != Q.torque().encode():
raise TestFailedException("Can't torque %s" % str(Q))
QT = Q
QE = Q.encode()
for h in xrange(cls.cofactor):
QT = QT.torque()
if QT.encode() != QE:
raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
Q0 = Q + P
if Q0 == Q: raise TestFailedException("Addition doesn't work")
@@ -220,5 +337,16 @@ def test(cls,n):
Q2 = Q0*(r+1)
if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
Q = Q1
test(Ed25519Point,100)
test(TwistedEd448GoldilocksPoint,100)
test(Ed448GoldilocksPoint,100)
test(IsoEd448Point,100)

def testElligator(cls,n):
for i in xrange(n):
cls.elligator(randombytes(cls.encLen))
testElligator(Ed25519Point,100)
testElligator(Ed448GoldilocksPoint,100)
testElligator(TwistedEd448GoldilocksPoint,100)
testElligator(IsoEd448Point,100)

Loading…
Cancel
Save