Browse Source

passes gang tests. ship it?

master
Michael Hamburg 7 years ago
parent
commit
b9b855f172
1 changed files with 72 additions and 13 deletions
  1. +72
    -13
      aux/ristretto.sage

+ 72
- 13
aux/ristretto.sage View File

@@ -308,6 +308,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
x,y,z,t = self.xyzt() x,y,z,t = self.xyzt()
if self.cofactor == 8: if self.cofactor == 8:
# Cofactor 8 version
num = (z+y)*(z-y) num = (z+y)*(z-y)
den = x*y den = x*y
tmp = isqrt(num*(a-d)*den^2) tmp = isqrt(num*(a-d)*den^2)
@@ -327,6 +328,7 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
s = tmp*den*yisr*(tiisr*z - 1) s = tmp*den*yisr*(tiisr*z - 1)
else: else:
# Much simpler cofactor 4 version
num = (x+t)*(x-t) num = (x+t)*(x-t)
isr = isqrt(num*(a-d)*x^2) isr = isqrt(num*(a-d)*x^2)
ratio = isr*num ratio = isr*num
@@ -339,8 +341,55 @@ class Decaf_1_1_Point(QuotientEdwardsPoint):
@optimized_version_of("decodeSpec") @optimized_version_of("decodeSpec")
def decode(cls,s): def decode(cls,s):
"""Decode, optimized version""" """Decode, optimized version"""
return cls.decodeSpec(s) # TODO
a,d = cls.a,cls.d
s = cls.bytesToGf(s,mustBePositive=True)
if s==0: return cls()
s2 = s^2
den = 1+a*s2
num = den^2 - 4*d*s2
isr = isqrt(num*den^2)
altx = 2*s*isr*den*cls.isoMagic
if negative(altx): isr = -isr
x = 2*s *isr^2*den*num
y = (1-a*s^2) * isr*den
if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
return cls(x,y)


@classmethod
def fromJacobiQuartic(cls,s,t,sgn=1):
"""Convert point from its Jacobi Quartic representation"""
a,d = cls.a,cls.d
if s==0: return cls()
altx = 2*s*cls.isoMagic / t
if negative(altx): t = -t
x = 2*s / (1+a*s^2)
y = (1-a*s^2) / t
return cls(x,sgn*y)
@classmethod
def elligatorSpec(cls,r0):
a,d = cls.a,cls.d
r = cls.qnr * cls.bytesToGf(r0)^2
den = (d*r-(d-a))*((d-a)*r-d)
n1 = (r+1)*(a-2*d)/den
n2 = r*n1
if is_square(n1):
sgn,s,t = 1,xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1
else:
sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1
return cls.fromJacobiQuartic(s,t,sgn)
@classmethod
@optimized_version_of("elligatorSpec")
def elligator(cls,r0):
return cls.elligatorSpec(r0)
class Ed25519Point(RistrettoPoint): class Ed25519Point(RistrettoPoint):
F = GF(2^255-19) F = GF(2^255-19)
d = F(-121665/121666) d = F(-121665/121666)
@@ -427,6 +476,7 @@ class IsoEd25519Point(Decaf_1_1_Point):
class TestFailedException(Exception): pass class TestFailedException(Exception): pass


def test(cls,n): def test(cls,n):
print "Testing curve %s" % cls.__name__
# TODO: test corner cases like 0,1,i # TODO: test corner cases like 0,1,i
P = cls.base() P = cls.base()
Q = cls() Q = cls()
@@ -451,30 +501,39 @@ def test(cls,n):
Q2 = Q0*(r+1) Q2 = Q0*(r+1)
if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
Q = Q1 Q = Q1

test(Ed25519Point,100) test(Ed25519Point,100)
test(IsoEd25519Point,100) test(IsoEd25519Point,100)
test(IsoEd448Point,100) test(IsoEd448Point,100)
test(TwistedEd448GoldilocksPoint,100) test(TwistedEd448GoldilocksPoint,100)
test(Ed448GoldilocksPoint,100) test(Ed448GoldilocksPoint,100)
def testElligator(cls,n):
print "Testing elligator on %s" % cls.__name__
for i in xrange(n):
cls.elligator(randombytes(cls.encLen))
testElligator(Ed25519Point,100)
testElligator(IsoEd448Point,100)
testElligator(Ed448GoldilocksPoint,100)
testElligator(TwistedEd448GoldilocksPoint,100)


def gangtest(classes,n): def gangtest(classes,n):
print "Gang test",[cls.__name__ for cls in classes]
for i in xrange(n): for i in xrange(n):
rets = [bytes((cls.base()*i).encode()) for cls in classes] rets = [bytes((cls.base()*i).encode()) for cls in classes]
if len(set(rets)) != 1: if len(set(rets)) != 1:
print "Divergence at %d" % i
print "Divergence in encode at %d" % i
for c,ret in zip(classes,rets):
print c,binascii.hexlify(ret)
print
r0 = randombytes(classes[0].encLen)
rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
if len(set(rets)) != 1:
print "Divergence in elligator at %d" % i
for c,ret in zip(classes,rets): for c,ret in zip(classes,rets):
print c,binascii.hexlify(ret) print c,binascii.hexlify(ret)
print print
gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100) gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
gangtest([Ed25519Point,IsoEd25519Point],100) gangtest([Ed25519Point,IsoEd25519Point],100)

def testElligator(cls,n):
for i in xrange(n):
cls.elligator(randombytes(cls.encLen))
testElligator(Ed25519Point,100)
testElligator(IsoEd448Point,100)
# testElligator(Ed448GoldilocksPoint,100)
# testElligator(TwistedEd448GoldilocksPoint,100)

Loading…
Cancel
Save