Browse Source

update ristretto.sage for python3. Also add Ed448RistrettoPoint for reference

master
Mike Hamburg 4 years ago
parent
commit
e8d69e9978
1 changed files with 46 additions and 31 deletions
  1. +46
    -31
      _aux/ristretto/ristretto.sage

+ 46
- 31
_aux/ristretto/ristretto.sage View File

@@ -6,7 +6,7 @@ class SpecException(Exception): pass
def lobit(x): return int(x) & 1
def hibit(x): return lobit(2*x)
def negative(x): return lobit(x)
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in range(n)])
def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])

@@ -22,17 +22,15 @@ def optimized_version_of(spec):
try: opt_ans = f(self,*args,**kwargs),None
except Exception as e: opt_ans = None,e
if spec_ans[1] is None and opt_ans[1] is not None:
raise
#raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
# % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
% (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
if spec_ans[1] is not None and opt_ans[1] is None:
raise
#raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
# % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
% (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
if spec_ans[0] != opt_ans[0]:
raise SpecException("Mismatch in %s: %s != %s"
% (f.__name__,pr(spec_ans[0]),pr(opt_ans[0])))
if opt_ans[1] is not None: raise
if opt_ans[1] is not None: raise opt_ans[1]
else: return opt_ans[0]
wrapper.__name__ = f.__name__
return wrapper
@@ -133,7 +131,7 @@ class QuotientEdwardsPoint(object):
s = dec_le(bytes)
if mustBeProper and s >= cls.F.order():
raise InvalidEncodingException("%d out of range!" % s)
bitlen = int(ceil(log(cls.F.order())/log(2)))
bitlen = int(ceil(N(log(cls.F.order(),2.))))
if maskHiBits: s &= 2^bitlen-1
s = cls.F(s)
if mustBePositive and negative(s):
@@ -463,8 +461,8 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
if negative(sr) != toggle_r: sr = -sr
ret = self.gfToBytes(sr)
if self.elligator(ret) != self and self.elligator(ret) != -self:
print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
print ("WRONG!",[toggle_rotation,toggle_altx,toggle_s])
if self.elligator(ret) == -self and self != -self: print ("Negated!",[toggle_rotation,toggle_altx,toggle_s])
rets.append(bytes(ret))
return rets

@@ -602,7 +600,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
y = (1-a*s2) / t

selfT = self
for i in xrange(self.cofactor/2):
for i in range(self.cofactor/2):
xT,yT = selfT
polyX = xT^2-x2
polyY = yT-y
@@ -671,6 +669,22 @@ class IsoEd448Point(RistrettoPoint):
345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
-363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
)

class Ed448RistrettoPoint(RistrettoPoint):
F = GF(2^448-2^224-1)
d = F(-39081)
a = F(1)
mneg = F(-1)
qnr = -1
magic = isqrt(a*d-1)
cofactor = 4
encLen = 56
@classmethod
def base(cls):
return 2*cls(
224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
)
class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
F = GF(2^448-2^224-1)
@@ -721,7 +735,7 @@ class IsoEd25519Point(Decaf_1_1_Point):
class TestFailedException(Exception): pass

def test(cls,n):
print "Testing curve %s" % cls.__name__
print ("Testing curve %s" % cls.__name__)
specials = [1]
ii = cls.F(-1)
@@ -744,7 +758,7 @@ def test(cls,n):
P = cls.base()
Q = cls()
for i in xrange(n):
for i in range(n):
#print binascii.hexlify(Q.encode())
QE = Q.encode()
QQ = cls.decode(QE)
@@ -766,7 +780,7 @@ def test(cls,n):
raise TestFailedException("s -> 1/s should work for cofactor 4")
QT = Q
for h in xrange(cls.cofactor):
for h in range(cls.cofactor):
QT = QT.torque()
if QT.encode() != QE:
raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
@@ -782,27 +796,27 @@ def test(cls,n):
Q = Q1
def testElligator(cls,n):
print "Testing elligator on %s" % cls.__name__
for i in xrange(n):
print ("Testing elligator on %s" % cls.__name__)
for i in range(n):
r = randombytes(cls.encLen)
P = cls.elligator(r)
if hasattr(P,"invertElligator"):
iv = P.invertElligator()
modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True)))
iv2 = P.torque().invertElligator()
if modr not in iv: print "Failed to invert Elligator!"
if modr not in iv: print ("Failed to invert Elligator!")
if len(iv) != len(set(iv)):
print "Elligator inverses not unique!", len(set(iv)), len(iv)
print ("Elligator inverses not unique!", len(set(iv)), len(iv))
if iv != iv2:
print "Elligator is untorqueable!"
#print [binascii.hexlify(j) for j in iv]
#print [binascii.hexlify(j) for j in iv2]
print ("Elligator is untorqueable!")
#print ([binascii.hexlify(j) for j in iv])
#print ([binascii.hexlify(j) for j in iv2])
#break
else:
pass # TODO

def gangtest(classes,n):
print "Gang test",[cls.__name__ for cls in classes]
print ("Gang test",[cls.__name__ for cls in classes])
specials = [1]
ii = classes[0].F(-1)
while is_square(ii):
@@ -810,12 +824,12 @@ def gangtest(classes,n):
ii = sqrt(ii)
specials.append(ii)
for i in xrange(n):
for i in range(n):
rets = [bytes((cls.base()*i).encode()) for cls in classes]
if len(set(rets)) != 1:
print "Divergence in encode at %d" % i
print ("Divergence in encode at %d" % i)
for c,ret in zip(classes,rets):
print c,binascii.hexlify(ret)
print (c,binascii.hexlify(ret))
print
if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen)
@@ -823,21 +837,21 @@ def gangtest(classes,n):
rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
if len(set(rets)) != 1:
print "Divergence in elligator at %d" % i
print ("Divergence in elligator at %d" % i)
for c,ret in zip(classes,rets):
print c,binascii.hexlify(ret)
print (c,binascii.hexlify(ret))
print

def testDoubleAndEncode(cls,n):
print "Testing doubleAndEncode on %s" % cls.__name__
print( "Testing doubleAndEncode on %s" % cls.__name__)
P = cls()
for i in xrange(cls.cofactor):
for i in range(cls.cofactor):
Q = P.torque()
assert P.doubleAndEncode() == Q.doubleAndEncode()
P = Q
for i in xrange(n):
for i in range(n):
r1 = randombytes(cls.encLen)
r2 = randombytes(cls.encLen)
u = cls.elligator(r1) + cls.elligator(r2)
@@ -847,6 +861,7 @@ testDoubleAndEncode(Ed25519Point,100)
testDoubleAndEncode(NegEd25519Point,100)
testDoubleAndEncode(IsoEd25519Point,100)
testDoubleAndEncode(IsoEd448Point,100)
testDoubleAndEncode(Ed448RistrettoPoint,100)
testDoubleAndEncode(TwistedEd448GoldilocksPoint,100)
#test(Ed25519Point,100)
#test(NegEd25519Point,100)


Loading…
Cancel
Save