From e8d69e9978bf2c57db3d45f5f1658386b7703e91 Mon Sep 17 00:00:00 2001 From: Mike Hamburg Date: Mon, 4 May 2020 09:22:14 -0700 Subject: [PATCH] update ristretto.sage for python3. Also add Ed448RistrettoPoint for reference --- _aux/ristretto/ristretto.sage | 77 +++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/_aux/ristretto/ristretto.sage b/_aux/ristretto/ristretto.sage index 41b0d22..b85f5d3 100644 --- a/_aux/ristretto/ristretto.sage +++ b/_aux/ristretto/ristretto.sage @@ -6,7 +6,7 @@ class SpecException(Exception): pass def lobit(x): return int(x) & 1 def hibit(x): return lobit(2*x) def negative(x): return lobit(x) -def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) +def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in range(n)]) def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) @@ -22,17 +22,15 @@ def optimized_version_of(spec): try: opt_ans = f(self,*args,**kwargs),None except Exception as e: opt_ans = None,e if spec_ans[1] is None and opt_ans[1] is not None: - raise - #raise SpecException("Mismatch in %s: spec returned %s but opt threw %s" - # % (f.__name__,str(spec_ans[0]),str(opt_ans[1]))) + raise SpecException("Mismatch in %s: spec returned %s but opt threw %s" + % (f.__name__,str(spec_ans[0]),str(opt_ans[1]))) if spec_ans[1] is not None and opt_ans[1] is None: - raise - #raise SpecException("Mismatch in %s: spec threw %s but opt returned %s" - # % (f.__name__,str(spec_ans[1]),str(opt_ans[0]))) + raise SpecException("Mismatch in %s: spec threw %s but opt returned %s" + % (f.__name__,str(spec_ans[1]),str(opt_ans[0]))) if spec_ans[0] != opt_ans[0]: raise SpecException("Mismatch in %s: %s != %s" % (f.__name__,pr(spec_ans[0]),pr(opt_ans[0]))) - if opt_ans[1] is not None: raise + if opt_ans[1] is not None: raise opt_ans[1] else: return opt_ans[0] wrapper.__name__ = f.__name__ return wrapper @@ -133,7 +131,7 @@ class QuotientEdwardsPoint(object): s = dec_le(bytes) if mustBeProper and s >= cls.F.order(): raise InvalidEncodingException("%d out of range!" % s) - bitlen = int(ceil(log(cls.F.order())/log(2))) + bitlen = int(ceil(N(log(cls.F.order(),2.)))) if maskHiBits: s &= 2^bitlen-1 s = cls.F(s) if mustBePositive and negative(s): @@ -463,8 +461,8 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): if negative(sr) != toggle_r: sr = -sr ret = self.gfToBytes(sr) if self.elligator(ret) != self and self.elligator(ret) != -self: - print "WRONG!",[toggle_rotation,toggle_altx,toggle_s] - if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s] + print ("WRONG!",[toggle_rotation,toggle_altx,toggle_s]) + if self.elligator(ret) == -self and self != -self: print ("Negated!",[toggle_rotation,toggle_altx,toggle_s]) rets.append(bytes(ret)) return rets @@ -602,7 +600,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint): y = (1-a*s2) / t selfT = self - for i in xrange(self.cofactor/2): + for i in range(self.cofactor/2): xT,yT = selfT polyX = xT^2-x2 polyY = yT-y @@ -671,6 +669,22 @@ class IsoEd448Point(RistrettoPoint): 345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732, -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718 ) + +class Ed448RistrettoPoint(RistrettoPoint): + F = GF(2^448-2^224-1) + d = F(-39081) + a = F(1) + mneg = F(-1) + qnr = -1 + magic = isqrt(a*d-1) + cofactor = 4 + encLen = 56 + + @classmethod + def base(cls): + return 2*cls( + 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660 + ) class TwistedEd448GoldilocksPoint(Decaf_1_1_Point): F = GF(2^448-2^224-1) @@ -721,7 +735,7 @@ class IsoEd25519Point(Decaf_1_1_Point): class TestFailedException(Exception): pass def test(cls,n): - print "Testing curve %s" % cls.__name__ + print ("Testing curve %s" % cls.__name__) specials = [1] ii = cls.F(-1) @@ -744,7 +758,7 @@ def test(cls,n): P = cls.base() Q = cls() - for i in xrange(n): + for i in range(n): #print binascii.hexlify(Q.encode()) QE = Q.encode() QQ = cls.decode(QE) @@ -766,7 +780,7 @@ def test(cls,n): raise TestFailedException("s -> 1/s should work for cofactor 4") QT = Q - for h in xrange(cls.cofactor): + for h in range(cls.cofactor): QT = QT.torque() if QT.encode() != QE: raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1)) @@ -782,27 +796,27 @@ def test(cls,n): Q = Q1 def testElligator(cls,n): - print "Testing elligator on %s" % cls.__name__ - for i in xrange(n): + print ("Testing elligator on %s" % cls.__name__) + for i in range(n): r = randombytes(cls.encLen) P = cls.elligator(r) if hasattr(P,"invertElligator"): iv = P.invertElligator() modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True))) iv2 = P.torque().invertElligator() - if modr not in iv: print "Failed to invert Elligator!" + if modr not in iv: print ("Failed to invert Elligator!") if len(iv) != len(set(iv)): - print "Elligator inverses not unique!", len(set(iv)), len(iv) + print ("Elligator inverses not unique!", len(set(iv)), len(iv)) if iv != iv2: - print "Elligator is untorqueable!" - #print [binascii.hexlify(j) for j in iv] - #print [binascii.hexlify(j) for j in iv2] + print ("Elligator is untorqueable!") + #print ([binascii.hexlify(j) for j in iv]) + #print ([binascii.hexlify(j) for j in iv2]) #break else: pass # TODO def gangtest(classes,n): - print "Gang test",[cls.__name__ for cls in classes] + print ("Gang test",[cls.__name__ for cls in classes]) specials = [1] ii = classes[0].F(-1) while is_square(ii): @@ -810,12 +824,12 @@ def gangtest(classes,n): ii = sqrt(ii) specials.append(ii) - for i in xrange(n): + for i in range(n): rets = [bytes((cls.base()*i).encode()) for cls in classes] if len(set(rets)) != 1: - print "Divergence in encode at %d" % i + print ("Divergence in encode at %d" % i) for c,ret in zip(classes,rets): - print c,binascii.hexlify(ret) + print (c,binascii.hexlify(ret)) print if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen) @@ -823,21 +837,21 @@ def gangtest(classes,n): rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes] if len(set(rets)) != 1: - print "Divergence in elligator at %d" % i + print ("Divergence in elligator at %d" % i) for c,ret in zip(classes,rets): - print c,binascii.hexlify(ret) + print (c,binascii.hexlify(ret)) print def testDoubleAndEncode(cls,n): - print "Testing doubleAndEncode on %s" % cls.__name__ + print( "Testing doubleAndEncode on %s" % cls.__name__) P = cls() - for i in xrange(cls.cofactor): + for i in range(cls.cofactor): Q = P.torque() assert P.doubleAndEncode() == Q.doubleAndEncode() P = Q - for i in xrange(n): + for i in range(n): r1 = randombytes(cls.encLen) r2 = randombytes(cls.encLen) u = cls.elligator(r1) + cls.elligator(r2) @@ -847,6 +861,7 @@ testDoubleAndEncode(Ed25519Point,100) testDoubleAndEncode(NegEd25519Point,100) testDoubleAndEncode(IsoEd25519Point,100) testDoubleAndEncode(IsoEd448Point,100) +testDoubleAndEncode(Ed448RistrettoPoint,100) testDoubleAndEncode(TwistedEd448GoldilocksPoint,100) #test(Ed25519Point,100) #test(NegEd25519Point,100)