You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

412 lines
13 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  8. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  9. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  10. def optimized_version_of(spec):
  11. """Decorator: This function is an optimized version of some specification"""
  12. def decorator(f):
  13. def wrapper(self,*args,**kwargs):
  14. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  15. except Exception as e: spec_ans = None,e
  16. try: opt_ans = f(self,*args,**kwargs),None
  17. except Exception as e: opt_ans = None,e
  18. if spec_ans[1] is None and opt_ans[1] is not None:
  19. raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  20. % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  21. if spec_ans[1] is not None and opt_ans[1] is None:
  22. raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  23. % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  24. if spec_ans[0] != opt_ans[0]:
  25. raise SpecException("Mismatch in %s: %s != %s"
  26. % (f.__name__,str(spec_ans[0]),str(opt_ans[0])))
  27. if opt_ans[1] is not None: raise
  28. else: return opt_ans[0]
  29. wrapper.__name__ = f.__name__
  30. return wrapper
  31. return decorator
  32. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  33. """Return sqrt(x)"""
  34. if not is_square(x): raise exn
  35. s = sqrt(x)
  36. if lobit(s): s=-s
  37. return s
  38. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  39. """Return 1/sqrt(x)"""
  40. if x==0: return 0
  41. if not is_square(x): raise exn
  42. return 1/sqrt(x)
  43. def isqrt_i(x):
  44. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  45. if x==0: return 0
  46. gen = x.parent(-1)
  47. while is_square(gen): gen = sqrt(gen)
  48. if is_square(x): return True,1/sqrt(x)
  49. else: return False,1/sqrt(x*gen)
class QuotientEdwardsPoint(object):
    """Abstract class for point on a quotiented Edwards curve; needs F,a,d,cofactor to work

    Subclasses supply the class attributes read below: F (the field), a and d
    (curve coefficients), cofactor, encLen (encoding length in bytes, read by
    bytesToGf/gfToBytes) and, for cofactor 8, i (read by torque).
    NOTE: this file is Sage code, where ``^`` is exponentiation.
    """
    def __init__(self,x=0,y=1):
        """Build the point (x,y), coercing both coordinates into F.

        The default (0,1) is the value __mul__ starts its accumulator from.
        Raises NotOnCurveException when the coordinates do not satisfy
        y^2 + a*x^2 == 1 + d*x^2*y^2.
        """
        x = self.x = self.F(x)
        y = self.y = self.F(y)
        if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
            raise NotOnCurveException(str(self))
    def __repr__(self):
        """Return "ClassName(0x<x>,0x<y>)" with both coordinates in hex."""
        return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
    def __iter__(self):
        """Yield x then y, so that ``x,y = point`` unpacks the coordinates."""
        yield self.x
        yield self.y
    def __add__(self,other):
        """Return self + other via the Edwards addition formulas for (a,d)."""
        x,y = self
        X,Y = other
        a,d = self.a,self.d
        return self.__class__(
            (x*Y+y*X)/(1+d*x*y*X*Y),
            (y*Y-a*x*X)/(1-d*x*y*X*Y)
        )
    # Negation flips the sign of x only.
    def __neg__(self): return self.__class__(-self.x,self.y)
    # Subtraction is addition of the negated point.
    def __sub__(self,other): return self + (-other)
    # scalar * point delegates to point * scalar.
    def __rmul__(self,other): return self*other
    def __eq__(self,other):
        """NB: this is the only method that is different from the usual one

        Two points compare equal when x*Y == X*y, or (cofactor 8 only) when
        -a*x*X == y*Y, rather than when their coordinates are identical.
        """
        x,y = self
        X,Y = other
        return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
    def __ne__(self,other): return not (self==other)
    def __mul__(self,exp):
        """Return int(exp) * self by binary double-and-add.

        A negative scalar is handled by negating both the scalar and the point.
        """
        exp = int(exp)
        if exp < 0: exp,self = -exp,-self
        total = self.__class__()
        work = self
        while exp != 0:
            if exp & 1: total += work
            work += work
            exp >>= 1
        return total
    def xyzt(self):
        """Return (x*z, y*z, z, x*y*z) for a random field element z."""
        x,y = self
        z = self.F.random_element()
        return x*z,y*z,z,x*y*z
    def torque(self):
        """Apply cofactor group, except keeping the point even"""
        if self.cofactor == 8:
            return self.__class__(self.y*self.i, self.x*self.i)
        else:
            return self.__class__(-self.x, -self.y)
    # Utility functions
    @classmethod
    def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
        """Convert little-endian bytes to field element, sanity check length

        Raises InvalidEncodingException when the input is not exactly encLen
        bytes long, when mustBeProper is set and the decoded integer is not
        below the field modulus, or when mustBePositive is set and the
        decoded integer has its low bit set.
        """
        if len(bytes) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(bytes))
        s = dec_le(bytes)
        if mustBeProper and s >= cls.F.modulus():
            raise InvalidEncodingException("%d out of range!" % s)
        if mustBePositive and lobit(s):
            raise InvalidEncodingException("%d is negative!" % s)
        return cls.F(s)
    @classmethod
    def gfToBytes(cls,x,mustBePositive=False):
        """Convert field element to encLen little-endian bytes

        When mustBePositive is set and x has its low bit set, -x is encoded
        instead of x.
        """
        if lobit(x) and mustBePositive: x = -x
        return enc_le(x,cls.encLen)
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group

    Besides the attributes the base class needs, the methods below read the
    class attributes magic and qnr, and i when cofactor is 8.
    NOTE: this is Sage code, where ``^`` is exponentiation.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        x,y = self
        # Cofactor 8: if x*y has its low bit set or y is 0, switch to
        # (i*y, i*x) -- the same map torque() applies.
        if self.cofactor==8 and (lobit(x*y) or y==0):
            (x,y) = (self.i*y,self.i*x)
        if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
        # Make x's low bit clear by negating both coordinates.
        if lobit(x): x,y = -x,-y
        s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        Raises InvalidEncodingException when s has the wrong length, is out
        of range or has its low bit set (via bytesToGf), when the x
        candidate is not a square (via xsqrt), or, for cofactor 8, when x*y
        has its low bit set or y is 0.
        """
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
        y = (1+a*s^2) / (1-a*s^2)
        if cls.cofactor==8 and (lobit(x*y) or y==0):
            raise InvalidEncodingException("x*y has high bit")
        return cls(x,y)
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version

        Works on the randomized (x*z, y*z, z, x*y*z) coordinates from xyzt()
        and calls isqrt once; the decorator cross-checks every result
        against encodeSpec.
        """
        a,d = self.a,self.d
        x,y,z,t = self.xyzt()
        u1 = a*(y+z)*(y-z)
        u2 = x*y # = t*z
        isr = isqrt(u1*u2^2)
        i1 = isr*u1
        i2 = isr*u2
        z_inv = i1*i2*t
        # Cofactor 8: same coordinate switch as in encodeSpec, decided from
        # the low bit of t*z_inv.
        if self.cofactor==8 and lobit(t*z_inv):
            x,y = y*self.i,x*self.i
            den_inv = self.magic * i1
        else:
            den_inv = i2
        if lobit(x*z_inv): y = -y
        s = (z-y) * den_inv
        return self.gfToBytes(s,mustBePositive=True)
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version

        Calls isqrt once for both denominators; the decorator cross-checks
        every result (or exception) against decodeSpec.
        """
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        yden = 1-a*s^2
        ynum = 1+a*s^2
        yden_sqr = yden^2
        xden_sqr = a*d*ynum^2 - yden_sqr
        isr = isqrt(xden_sqr * yden_sqr)
        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr
        x = 2*s*xden_inv
        # Choose the x with its low bit clear, as xsqrt does in decodeSpec.
        if lobit(x): x = -x
        y = ynum * yden_inv
        if cls.cofactor==8 and (lobit(x*y) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation

        (s,t) must satisfy the quartic equation asserted below; sgn
        multiplies the resulting x coordinate.
        """
        a,d = cls.a,cls.d
        assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
        x = 2*s*cls.magic / t
        if lobit(x): x = -x # TODO: doesn't work without resolving x
        y = (1+a*s^2) / (1-a*s^2)
        return cls(sgn*x,y)
    @classmethod
    def elligatorSpec(cls,r0):
        """Unoptimized specification: map the encLen-byte string r0 to a point.

        r0 is decoded with bytesToGf, squared and scaled by qnr; the result
        selects a Jacobi-quartic pair (s,t) that fromJacobiQuartic converts.
        """
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0)^2
        den = (d*r-a)*(a*r-d)
        n1 = cls.a*(r+1)*(a+d)*(d-a)/den
        n2 = r*n1
        # Use n1 when it is a square, otherwise n2 with sgn = -1.
        if is_square(n1):
            sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
        else:
            sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
        return cls.fromJacobiQuartic(s,t,sgn)
    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Optimized version of elligatorSpec, using one isqrt_i call.

        The decorator cross-checks every result against elligatorSpec.
        """
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0^2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+1)*(a+d)*(d-a)
        # iss reports whether num*den was a square; isri is the inverse root
        # isqrt_i returned for that case.
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = isri*s*(r-1)*(d+a)^2 + sgn
        return cls.fromJacobiQuartic(s,t,sgn)
class Decaf1Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity

    Besides the attributes the base class needs, the methods below read the
    class attributes magic and isoMagic. Cofactor 8 is not handled (see the
    TODOs). NOTE: this is Sage code, where ``^`` is exponentiation.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        a,d = self.a,self.d
        x,y = self
        # A point with x == 0 encodes as the field element 0.
        if x==0: return(self.gfToBytes(0))
        isr2 = isqrt(a*(y^2-1)) / self.magic
        altx = 1/isr2*self.isoMagic
        # The low bit of altx selects which of the two candidates is used.
        if lobit(altx): s = (1+x*y*isr2)/(a*x)
        else: s = (1-x*y*isr2)/(a*x)
        # TODO: cofactor 8
        return self.gfToBytes(s,mustBePositive=True)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        Raises InvalidEncodingException when s has the wrong length, is out
        of range or has its low bit set (via bytesToGf), or when the value
        passed to isqrt is a nonzero non-square.
        """
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)
        # s == 0 decodes to the default point cls().
        if s==0: return cls()
        isr = isqrt(s^4 + 2*(a-2*d)*s^2 + 1)
        altx = 2*s*isr*cls.isoMagic
        # Flip the sign of isr when altx has its low bit set.
        if lobit(altx): isr = -isr
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) * isr
        # TODO: cofactor 8
        return cls(x,y)
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version

        Currently just delegates to encodeSpec; the decorator still runs the
        specification a second time and compares.
        """
        return self.encodeSpec() # TODO
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version

        Currently just delegates to decodeSpec; the decorator still runs the
        specification a second time and compares.
        """
        return cls.decodeSpec(s) # TODO
class Ed25519Point(RistrettoPoint):
    # Concrete parameters for RistrettoPoint over GF(2^255-19), cofactor 8.
    # (Sage code: ``^`` is exponentiation.)
    F = GF(2^255-19)
    d = F(-121665/121666)
    a = F(-1)
    # A square root of -1 in F; read by torque() and the cofactor-8 branches.
    i = sqrt(F(-1))
    # Multiplier applied to r0^2 in the elligator maps.
    qnr = i
    # 1/sqrt(a*d-1); read by encode() and fromJacobiQuartic().
    magic = isqrt(a*d-1)
    cofactor = 8
    # Length in bytes of an encoded field element.
    encLen = 32
    @classmethod
    def base(cls):
        """Return the point with y = 4/5 and the x whose low bit is clear."""
        y = cls.F(4/5)
        x = sqrt((y^2-1)/(cls.d*y^2+1))
        if lobit(x): x = -x
        return cls(x,y)
class IsoEd448Point(RistrettoPoint):
    # Concrete parameters for RistrettoPoint over GF(2^448-2^224-1), cofactor 4.
    # (Sage code: ``^`` is exponentiation.)
    F = GF(2^448-2^224-1)
    d = F(39082/39081)
    a = F(1)
    # Multiplier applied to r0^2 in the elligator maps.
    qnr = -1
    # 1/sqrt(a*d-1); also reused as isoMagic by the Decaf1-based 448 classes.
    magic = isqrt(a*d-1)
    cofactor = 4
    # Length in bytes of an encoded field element.
    encLen = 56
    @classmethod
    def base(cls):
        """Return the point obtained by decoding a fixed 56-byte string."""
        # = ..., -3/2
        return cls.decodeSpec(bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000"+
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
class TwistedEd448GoldilocksPoint(Decaf1Point):
    # Concrete parameters for Decaf1Point over GF(2^448-2^224-1) with a = -1,
    # cofactor 4. (Sage code: ``^`` is exponentiation.)
    F = GF(2^448-2^224-1)
    d = F(-39082)
    a = F(-1)
    qnr = -1
    # 1/sqrt(a*d-1); read by Decaf1Point.encodeSpec.
    magic = isqrt(a*d-1)
    cofactor = 4
    # Length in bytes of an encoded field element.
    encLen = 56
    # Constant taken from IsoEd448Point; used to pick signs in encode/decode.
    isoMagic = IsoEd448Point.magic
    @classmethod
    def base(cls):
        """Return the point obtained by decoding a fixed 56-byte string."""
        return cls.decodeSpec(bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000"+
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
class Ed448GoldilocksPoint(Decaf1Point):
    # Concrete parameters for Decaf1Point over GF(2^448-2^224-1) with a = 1,
    # cofactor 4. (Sage code: ``^`` is exponentiation.)
    F = GF(2^448-2^224-1)
    d = F(-39081)
    a = F(1)
    qnr = -1
    # 1/sqrt(a*d-1); read by Decaf1Point.encodeSpec.
    magic = isqrt(a*d-1)
    cofactor = 4
    # Length in bytes of an encoded field element.
    encLen = 56
    # Constant taken from IsoEd448Point; used to pick signs in encode/decode.
    isoMagic = IsoEd448Point.magic
    @classmethod
    def base(cls):
        """Return the point obtained by decoding a fixed 56-byte string."""
        return cls.decodeSpec(bytearray(binascii.unhexlify(
            "00000000000000000000000000000000000000000000000000000000"+
            "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
  303. class TestFailedException(Exception): pass
  304. def test(cls,n):
  305. # TODO: test corner cases like 0,1,i
  306. P = cls.base()
  307. Q = cls()
  308. for i in xrange(n):
  309. #print binascii.hexlify(Q.encode())
  310. QQ = cls.decode(Q.encode())
  311. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  312. QT = Q
  313. QE = Q.encode()
  314. for h in xrange(cls.cofactor):
  315. QT = QT.torque()
  316. if QT.encode() != QE:
  317. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  318. Q0 = Q + P
  319. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  320. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  321. r = randint(1,1000)
  322. Q1 = Q0*r
  323. Q2 = Q0*(r+1)
  324. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  325. Q = Q1
# Module-level self-test: these run as soon as the file is loaded, with 100
# rounds of test() per concrete point class.
test(Ed25519Point,100)
test(IsoEd448Point,100)
test(TwistedEd448GoldilocksPoint,100)
test(Ed448GoldilocksPoint,100)
  330. def gangtest(classes,n):
  331. for i in xrange(n):
  332. rets = [bytes((cls.base()*i).encode()) for cls in classes]
  333. if len(set(rets)) != 1:
  334. print "Divergence at %d" % i
  335. for c,ret in zip(classes,rets):
  336. print c,binascii.hexlify(ret)
  337. print
# Module-level check: the three 448-bit point classes must produce the same
# encodings for the first 100 multiples of their base points.
gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
  339. def testElligator(cls,n):
  340. for i in xrange(n):
  341. cls.elligator(randombytes(cls.encLen))
# Module-level Elligator self-test, run when the file is loaded. Only the
# RistrettoPoint subclasses are exercised; the Decaf1-based classes are
# disabled below (Decaf1Point defines no elligator method).
testElligator(Ed25519Point,100)
testElligator(IsoEd448Point,100)
# testElligator(Ed448GoldilocksPoint,100)
# testElligator(TwistedEd448GoldilocksPoint,100)