|
- import binascii
class InvalidEncodingException(Exception):
    """Raised when bytes or a field value are not a valid encoding."""


class NotOnCurveException(Exception):
    """Raised when coordinates fail the curve equation check."""


class SpecException(Exception):
    """Raised when an optimized routine disagrees with its specification."""
-
def lobit(x):
    """Return the least-significant bit of ``int(x)`` (1 if odd, else 0)."""
    return int(x) & 1


def hibit(x):
    """Return ``lobit(2*x)``.

    NOTE(review): for an element of an odd prime field this is presumably 1
    exactly when the canonical representative exceeds p/2 -- confirm.
    """
    return lobit(2 * x)


def enc_le(x, n):
    """Encode ``int(x)`` as ``n`` little-endian bytes.

    Bits above ``8*n`` are silently dropped.  ``range`` is used instead of
    ``xrange`` so the helper runs on both Python 2 and Python 3 based Sage.
    """
    value = int(x)  # convert once instead of once per output byte
    return bytearray((value >> (8 * i)) & 0xFF for i in range(n))


def dec_le(x):
    """Decode an iterable of byte values as a little-endian integer."""
    return sum(b << (8 * i) for i, b in enumerate(x))


def randombytes(n):
    """Return a bytearray of ``n`` values drawn with ``randint(0, 255)``.

    ``randint`` is not imported by this file; it is expected to be supplied
    by the surrounding Sage environment.
    """
    return bytearray([randint(0, 255) for _ in range(n)])
-
def optimized_version_of(spec):
    """Decorator: This function is an optimized version of some specification.

    ``spec`` is the name of the specification method, looked up on the first
    argument (``self`` or ``cls``) of the decorated function.  Every call runs
    both the specification and the optimized function with the same
    arguments and compares the outcomes.

    Returns the optimized function's result when both agree.
    Raises SpecException when exactly one of the two raised, or when both
    returned but the values differ.  When both raised, the optimized
    function's exception is re-raised.
    """
    import functools  # local import keeps this block self-contained

    def decorator(f):
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            # Each outcome is stored as a (result, exception) pair.
            try:
                spec_ans = getattr(self, spec, spec)(*args, **kwargs), None
            except Exception as e:
                spec_ans = None, e
            try:
                opt_ans = f(self, *args, **kwargs), None
            except Exception as e:
                opt_ans = None, e

            if spec_ans[1] is None and opt_ans[1] is not None:
                raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
                    % (f.__name__, str(spec_ans[0]), str(opt_ans[1])))
            if spec_ans[1] is not None and opt_ans[1] is None:
                raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
                    % (f.__name__, str(spec_ans[1]), str(opt_ans[0])))
            if spec_ans[0] != opt_ans[0]:
                raise SpecException("Mismatch in %s: %s != %s"
                    % (f.__name__, str(spec_ans[0]), str(opt_ans[0])))
            if opt_ans[1] is not None:
                # Both sides raised.  A bare ``raise`` here is outside any
                # ``except`` block and fails on Python 3, so re-raise the
                # captured exception explicitly.
                raise opt_ans[1]
            return opt_ans[0]
        return wrapper
    return decorator
-
def xsqrt(x, exn=None):
    """Return the square root of x whose low bit is clear.

    ``exn`` is the exception to raise when x is not a square.  When omitted,
    a fresh ``InvalidEncodingException("Not on curve")`` is built per call,
    so no single exception instance (and its traceback) is shared between
    unrelated failures.
    """
    if exn is None:
        exn = InvalidEncodingException("Not on curve")
    if not is_square(x):
        raise exn
    s = sqrt(x)
    if lobit(s):
        s = -s  # pick the root with low bit 0
    return s
-
def isqrt(x, exn=None):
    """Return 1/sqrt(x), or 0 when x is 0.

    ``exn`` is the exception to raise when x is a nonzero non-square.  When
    omitted, a fresh ``InvalidEncodingException("Not on curve")`` is built
    per call rather than sharing one instance created at definition time.
    """
    if x == 0:
        return 0
    if exn is None:
        exn = InvalidEncodingException("Not on curve")
    if not is_square(x):
        raise exn
    return 1 / sqrt(x)
-
def isqrt_i(x):
    """Return 1/sqrt(x) or 1/sqrt(zeta * x), tagged with which one it is.

    Returns a pair ``(was_square, value)``:
      * ``(True, 1/sqrt(x))`` when x is a square,
      * ``(False, 1/sqrt(zeta*x))`` otherwise,
      * ``(True, 0)`` when x is 0.

    zeta is found by starting from -1 in x's parent and taking square roots
    until the result is no longer a square.
    """
    if x == 0:
        # Callers unpack two values, so the zero case must be a pair too
        # (the original returned a bare 0 here).
        return True, 0
    gen = x.parent(-1)
    while is_square(gen):
        gen = sqrt(gen)
    if is_square(x):
        return True, 1 / sqrt(x)
    return False, 1 / sqrt(x * gen)
-
class QuotientEdwardsPoint(object):
    """Abstract class for a point on a quotiented Edwards curve.

    Needs F, a, d, cofactor to work.  The byte helpers additionally read
    ``encLen``, and ``torque`` reads ``i`` when the cofactor is 8.
    """

    def __init__(self, x=0, y=1):
        """Build the point (x, y); the default (0, 1) is the neutral element.

        Raises NotOnCurveException unless y^2 + a*x^2 == 1 + d*x^2*y^2.
        """
        x = self.x = self.F(x)
        y = self.y = self.F(y)
        if y**2 + self.a*x**2 != 1 + self.d*x**2*y**2:
            raise NotOnCurveException(str(self))

    def __repr__(self):
        return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)

    def __iter__(self):
        """Yield x then y, so that ``x, y = point`` works."""
        yield self.x
        yield self.y

    def __add__(self, other):
        """Add two points with the Edwards addition law."""
        x, y = self
        X, Y = other
        a, d = self.a, self.d
        return self.__class__(
            (x*Y + y*X) / (1 + d*x*y*X*Y),
            (y*Y - a*x*X) / (1 - d*x*y*X*Y)
        )

    def __neg__(self):
        return self.__class__(-self.x, self.y)

    def __sub__(self, other):
        return self + (-other)

    def __rmul__(self, other):
        return self * other

    def __eq__(self, other):
        """NB: this is the only method that is different from the usual one.

        Points compare equal when x*Y == X*y, or (cofactor 8 only) when
        -a*x*X == y*Y.
        """
        x, y = self
        X, Y = other
        return x*Y == X*y or (self.cofactor == 8 and -self.a*x*X == y*Y)

    def __ne__(self, other):
        return not (self == other)

    def __mul__(self, exp):
        """Multiply by the integer ``exp`` using double-and-add."""
        exp = int(exp)
        if exp < 0:
            exp, self = -exp, -self
        total = self.__class__()
        work = self
        while exp != 0:
            if exp & 1:
                total += work
            work += work
            exp >>= 1
        return total

    def xyzt(self):
        """Return (x*z, y*z, z, x*y*z) for a random nonzero z in F."""
        x, y = self
        z = self.F.random_element()
        # z == 0 would collapse the tuple to (0, 0, 0, 0); redraw until nonzero.
        while z == 0:
            z = self.F.random_element()
        return x*z, y*z, z, x*y*z

    def torque(self):
        """Apply cofactor group, except keeping the point even.

        Cofactor 8 returns (y*i, x*i); otherwise returns (-x, -y).
        """
        if self.cofactor == 8:
            return self.__class__(self.y*self.i, self.x*self.i)
        else:
            return self.__class__(-self.x, -self.y)

    # Utility functions
    @classmethod
    def bytesToGf(cls, bytes, mustBeProper=True, mustBePositive=False):
        """Convert little-endian bytes to field element, sanity check length.

        Raises InvalidEncodingException when the length is not ``encLen``,
        when ``mustBeProper`` and the value is >= the field modulus, or when
        ``mustBePositive`` and the value's low bit is set.
        """
        if len(bytes) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(bytes))
        s = dec_le(bytes)
        if mustBeProper and s >= cls.F.modulus():
            raise InvalidEncodingException("%d out of range!" % s)
        if mustBePositive and lobit(s):
            raise InvalidEncodingException("%d is negative!" % s)
        return cls.F(s)

    @classmethod
    def gfToBytes(cls, x, mustBePositive=False):
        """Convert field element to ``encLen`` little-endian bytes.

        With ``mustBePositive`` set, x is negated first when its low bit is
        set.
        """
        if lobit(x) and mustBePositive:
            x = -x
        return enc_le(x, cls.encLen)
-
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group"""

    def encodeSpec(self):
        """Unoptimized specification for encoding.

        Encodes the even square root of a*(y-1)/(y+1) after adjusting (x, y).
        """
        x, y = self
        # Cofactor 8 only: move to (i*y, i*x) when x*y is odd or y is zero.
        if self.cofactor == 8 and (lobit(x * y) or y == 0):
            x, y = self.i * y, self.i * x

        if y == -1:
            y = 1  # Avoid divide by 0; doesn't affect impl

        if lobit(x):
            x, y = -x, -y
        ratio = self.a * (y - 1) / (y + 1)
        odd_point = Exception("Unimplemented: point is odd: " + str(self))
        return self.gfToBytes(xsqrt(ratio, exn=odd_point))

    @classmethod
    def decodeSpec(cls, s):
        """Unoptimized specification for decoding.

        Raises InvalidEncodingException (from bytesToGf, from xsqrt, or from
        the cofactor-8 check on x*y and y).
        """
        s = cls.bytesToGf(s, mustBePositive=True)

        a, d = cls.a, cls.d
        as2 = a * s**2
        x = xsqrt(4 * s**2 / (a * d * (1 + as2)**2 - (1 - as2)**2))
        y = (1 + as2) / (1 - as2)

        if cls.cofactor == 8 and (lobit(x * y) or y == 0):
            raise InvalidEncodingException("x*y has high bit")

        return cls(x, y)

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version (checked against encodeSpec on each call).

        Works on the randomized (x*z, y*z, z, x*y*z) form from xyzt() and
        uses a single isqrt call.
        """
        a, d = self.a, self.d
        x, y, z, t = self.xyzt()

        u1 = a * (y + z) * (y - z)
        u2 = x * y  # = t*z
        inv_root = isqrt(u1 * u2**2)
        i1 = inv_root * u1
        i2 = inv_root * u2
        z_inv = i1 * i2 * t

        if self.cofactor == 8 and lobit(t * z_inv):
            x, y = y * self.i, x * self.i
            den_inv = self.magic * i1
        else:
            den_inv = i2

        if lobit(x * z_inv):
            y = -y
        s = (z - y) * den_inv
        return self.gfToBytes(s, mustBePositive=True)

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls, s):
        """Decode, optimized version (checked against decodeSpec on each call)."""
        s = cls.bytesToGf(s, mustBePositive=True)

        a, d = cls.a, cls.d
        as2 = a * s**2
        yden = 1 - as2
        ynum = 1 + as2
        yden_sqr = yden**2
        xden_sqr = a * d * ynum**2 - yden_sqr

        inv_root = isqrt(xden_sqr * yden_sqr)

        xden_inv = inv_root * yden
        yden_inv = xden_inv * inv_root * xden_sqr

        x = 2 * s * xden_inv
        if lobit(x):
            x = -x
        y = ynum * yden_inv

        if cls.cofactor == 8 and (lobit(x * y) or y == 0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x, y))

        return cls(x, y)

    @classmethod
    def fromJacobiQuartic(cls, s, t, sgn=1):
        """Convert point from its Jacobi Quartic representation.

        Asserts that (s, t) satisfies the quartic equation, then returns the
        point (sgn*x, y) where x is normalized to have low bit 0.
        """
        a, d = cls.a, cls.d
        assert s**4 - 2*cls.a*(1 - 2*d/(d - a))*s**2 + 1 == t**2
        x = 2 * s * cls.magic / t
        if lobit(x):
            x = -x  # TODO: doesn't work without resolving x
        y = (1 + a * s**2) / (1 - a * s**2)
        return cls(sgn * x, y)

    @classmethod
    def elligatorSpec(cls, r0):
        """Unoptimized specification for the Elligator map from bytes r0."""
        a, d = cls.a, cls.d
        r = cls.qnr * cls.bytesToGf(r0)**2
        den = (d*r - a) * (a*r - d)
        n1 = cls.a * (r + 1) * (a + d) * (d - a) / den
        n2 = r * n1
        if is_square(n1):
            sgn, s, t = 1, xsqrt(n1), -(r - 1)*(a + d)**2 / den - 1
        else:
            sgn, s, t = -1, xsqrt(n2), r*(r - 1)*(a + d)**2 / den - 1

        return cls.fromJacobiQuartic(s, t, sgn)

    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls, r0):
        """Elligator map, optimized version (checked against elligatorSpec)."""
        a, d = cls.a, cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0**2
        den = (d*r - a) * (a*r - d)
        num = cls.a * (r + 1) * (a + d) * (d - a)

        was_square, isri = isqrt_i(num * den)
        if was_square:
            sgn, twiddle = 1, 1
        else:
            sgn, twiddle = -1, r0 * cls.qnr
        isri *= twiddle
        s = isri * num
        t = isri * s * (r - 1) * (d + a)**2 + sgn
        return cls.fromJacobiQuartic(s, t, sgn)
-
-
class Decaf1Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity"""

    def encodeSpec(self):
        """Unoptimized specification for encoding.

        A point with x == 0 encodes as 0.
        """
        a, d = self.a, self.d
        x, y = self
        if x == 0:
            return self.gfToBytes(0)

        isr2 = isqrt(a * (y**2 - 1)) / self.magic
        altx = 1 / isr2 * self.isoMagic
        # The low bit of altx selects the sign in front of x*y*isr2.
        xy_isr = x * y * isr2
        numerator = 1 + xy_isr if lobit(altx) else 1 - xy_isr
        s = numerator / (a * x)

        # TODO: cofactor 8
        return self.gfToBytes(s, mustBePositive=True)

    @classmethod
    def decodeSpec(cls, s):
        """Unoptimized specification for decoding.

        The encoding 0 decodes to the default point ``cls()``.
        """
        a, d = cls.a, cls.d
        s = cls.bytesToGf(s, mustBePositive=True)

        if s == 0:
            return cls()
        inv_root = isqrt(s**4 + 2*(a - 2*d)*s**2 + 1)
        altx = 2 * s * inv_root * cls.isoMagic
        if lobit(altx):
            inv_root = -inv_root
        x = 2 * s / (1 + a * s**2)
        y = (1 - a * s**2) * inv_root

        # TODO: cofactor 8
        return cls(x, y)

    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version (currently just calls encodeSpec)."""
        return self.encodeSpec()  # TODO

    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls, s):
        """Decode, optimized version (currently just calls decodeSpec)."""
        return cls.decodeSpec(s)  # TODO
-
class Ed25519Point(RistrettoPoint):
    """RistrettoPoint over GF(2^255-19) with a = -1 and cofactor 8."""

    F = GF(2**255 - 19)
    d = F(-121665/121666)
    a = F(-1)
    i = sqrt(F(-1))
    qnr = i
    magic = isqrt(a*d - 1)
    cofactor = 8
    encLen = 32

    @classmethod
    def base(cls):
        """Return the point with y = 4/5 and x the root with low bit 0."""
        y = cls.F(4/5)
        x = sqrt((y**2 - 1) / (cls.d * y**2 + 1))
        if lobit(x):
            x = -x
        return cls(x, y)
-
class IsoEd448Point(RistrettoPoint):
    """RistrettoPoint over GF(2^448-2^224-1) with a = 1, d = 39082/39081."""

    F = GF(2**448 - 2**224 - 1)
    d = F(39082/39081)
    a = F(1)
    qnr = -1
    magic = isqrt(a*d - 1)
    cofactor = 4
    encLen = 56

    @classmethod
    def base(cls):
        """Return the point obtained by decodeSpec on a fixed 56-byte string."""
        # = ..., -3/2
        encoded = bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000" +
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))
        return cls.decodeSpec(encoded)
-
class TwistedEd448GoldilocksPoint(Decaf1Point):
    """Decaf1Point over GF(2^448-2^224-1) with a = -1, d = -39082."""

    F = GF(2**448 - 2**224 - 1)
    d = F(-39082)
    a = F(-1)
    qnr = -1
    magic = isqrt(a*d - 1)
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Return the point obtained by decodeSpec on a fixed 56-byte string."""
        encoded = bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000" +
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))
        return cls.decodeSpec(encoded)
-
class Ed448GoldilocksPoint(Decaf1Point):
    """Decaf1Point over GF(2^448-2^224-1) with a = 1, d = -39081."""

    F = GF(2**448 - 2**224 - 1)
    d = F(-39081)
    a = F(1)
    qnr = -1
    magic = isqrt(a*d - 1)
    cofactor = 4
    encLen = 56
    isoMagic = IsoEd448Point.magic

    @classmethod
    def base(cls):
        """Return the point obtained by decodeSpec on a fixed 56-byte string."""
        encoded = bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000" +
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))
        return cls.decodeSpec(encoded)
-
class TestFailedException(Exception):
    """Raised by the self-tests when a consistency check fails."""
-
def test(cls, n):
    """Run n rounds of consistency checks on the point class ``cls``.

    Each round checks, for the current point Q:
      * decode(encode(Q)) == Q,
      * the encoding is unchanged by each of ``cofactor`` successive torques,
      * adding and then subtracting the base point returns Q,
      * r*(Q+P) + (Q+P) == (r+1)*(Q+P) for a random r in [1, 1000].

    Raises TestFailedException on the first failed check.  ``range`` is used
    instead of ``xrange`` so the test runs on Python 2 and Python 3.
    """
    # TODO: test corner cases like 0,1,i
    P = cls.base()
    Q = cls()
    for _ in range(n):
        # print(binascii.hexlify(Q.encode()))
        QQ = cls.decode(Q.encode())
        if QQ != Q:
            raise TestFailedException("Round trip %s != %s" % (str(QQ), str(Q)))

        QT = Q
        QE = Q.encode()
        for h in range(cls.cofactor):
            QT = QT.torque()
            if QT.encode() != QE:
                raise TestFailedException("Can't torque %s,%d" % (str(Q), h + 1))

        Q0 = Q + P
        if Q0 == Q:
            raise TestFailedException("Addition doesn't work")
        if Q0 - P != Q:
            raise TestFailedException("Subtraction doesn't work")

        r = randint(1, 1000)
        Q1 = Q0 * r
        Q2 = Q0 * (r + 1)
        if Q1 + Q0 != Q2:
            raise TestFailedException("Scalarmul doesn't work")
        Q = Q1  # continue from a fresh point next round
# Run the self-test (100 rounds) on every concrete point class.
test(Ed25519Point, 100)
test(IsoEd448Point, 100)
test(TwistedEd448GoldilocksPoint, 100)
test(Ed448GoldilocksPoint, 100)
-
def gangtest(classes, n):
    """Report any i in [0, n) where the classes encode i*base differently.

    For each i, every class in ``classes`` encodes ``base() * i``; when the
    encodings are not all equal, the index and each class's hex encoding are
    printed.  Nothing is returned or raised on divergence.

    The ``print`` calls use the single-argument parenthesized form and the
    loop uses ``range``, so the function runs on Python 2 and Python 3.
    """
    for i in range(n):
        rets = [bytes((cls.base() * i).encode()) for cls in classes]
        if len(set(rets)) != 1:
            print("Divergence at %d" % i)
            for c, ret in zip(classes, rets):
                print("%s %s" % (c, binascii.hexlify(ret)))
            print("")
# Compare the encodings of i*base across the three 448-bit classes.
gangtest(
    [IsoEd448Point, TwistedEd448GoldilocksPoint, Ed448GoldilocksPoint],
    100
)
-
-
-
-
def testElligator(cls, n):
    """Call ``cls.elligator`` on n random ``encLen``-byte inputs.

    The return values are discarded; the call only has to complete without
    raising.  ``range`` is used instead of ``xrange`` so this runs on
    Python 2 and Python 3.
    """
    for _ in range(n):
        cls.elligator(randombytes(cls.encLen))
# Exercise the Elligator map on the two Ristretto-based classes.
testElligator(Ed25519Point, 100)
testElligator(IsoEd448Point, 100)
# Disabled: Decaf1Point defines no elligator method in this file.
# testElligator(Ed448GoldilocksPoint,100)
# testElligator(TwistedEd448GoldilocksPoint,100)
|