You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

487 lines
15 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def negative(x): return lobit(x)
  8. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  9. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  10. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  11. def optimized_version_of(spec):
  12. """Decorator: This function is an optimized version of some specification"""
  13. def decorator(f):
  14. def wrapper(self,*args,**kwargs):
  15. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  16. except Exception as e: spec_ans = None,e
  17. try: opt_ans = f(self,*args,**kwargs),None
  18. except Exception as e: opt_ans = None,e
  19. if spec_ans[1] is None and opt_ans[1] is not None:
  20. raise
  21. #raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  22. # % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  23. if spec_ans[1] is not None and opt_ans[1] is None:
  24. raise
  25. #raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  26. # % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  27. if spec_ans[0] != opt_ans[0]:
  28. raise SpecException("Mismatch in %s: %s != %s"
  29. % (f.__name__,str(spec_ans[0]),str(opt_ans[0])))
  30. if opt_ans[1] is not None: raise
  31. else: return opt_ans[0]
  32. wrapper.__name__ = f.__name__
  33. return wrapper
  34. return decorator
  35. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  36. """Return sqrt(x)"""
  37. if not is_square(x): raise exn
  38. s = sqrt(x)
  39. if negative(s): s=-s
  40. return s
  41. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  42. """Return 1/sqrt(x)"""
  43. if x==0: return 0
  44. if not is_square(x): raise exn
  45. return 1/sqrt(x)
  46. def isqrt_i(x):
  47. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  48. if x==0: return 0
  49. gen = x.parent(-1)
  50. while is_square(gen): gen = sqrt(gen)
  51. if is_square(x): return True,1/sqrt(x)
  52. else: return False,1/sqrt(x*gen)
  53. class QuotientEdwardsPoint(object):
  54. """Abstract class for point an a quotiented Edwards curve; needs F,a,d,cofactor to work"""
  55. def __init__(self,x=0,y=1):
  56. x = self.x = self.F(x)
  57. y = self.y = self.F(y)
  58. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  59. raise NotOnCurveException(str(self))
  60. def __repr__(self):
  61. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  62. def __iter__(self):
  63. yield self.x
  64. yield self.y
  65. def __add__(self,other):
  66. x,y = self
  67. X,Y = other
  68. a,d = self.a,self.d
  69. return self.__class__(
  70. (x*Y+y*X)/(1+d*x*y*X*Y),
  71. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  72. )
  73. def __neg__(self): return self.__class__(-self.x,self.y)
  74. def __sub__(self,other): return self + (-other)
  75. def __rmul__(self,other): return self*other
  76. def __eq__(self,other):
  77. """NB: this is the only method that is different from the usual one"""
  78. x,y = self
  79. X,Y = other
  80. return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
  81. def __ne__(self,other): return not (self==other)
  82. def __mul__(self,exp):
  83. exp = int(exp)
  84. if exp < 0: exp,self = -exp,-self
  85. total = self.__class__()
  86. work = self
  87. while exp != 0:
  88. if exp & 1: total += work
  89. work += work
  90. exp >>= 1
  91. return total
  92. def xyzt(self):
  93. x,y = self
  94. z = self.F.random_element()
  95. return x*z,y*z,z,x*y*z
  96. def torque(self):
  97. """Apply cofactor group, except keeping the point even"""
  98. if self.cofactor == 8:
  99. if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
  100. if self.a == 1: return self.__class__(-self.y, self.x)
  101. else:
  102. return self.__class__(-self.x, -self.y)
  103. # Utility functions
  104. @classmethod
  105. def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
  106. """Convert little-endian bytes to field element, sanity check length"""
  107. if len(bytes) != cls.encLen:
  108. raise InvalidEncodingException("wrong length %d" % len(bytes))
  109. s = dec_le(bytes)
  110. if mustBeProper and s >= cls.F.modulus():
  111. raise InvalidEncodingException("%d out of range!" % s)
  112. s = cls.F(s)
  113. if mustBePositive and negative(s):
  114. raise InvalidEncodingException("%d is negative!" % s)
  115. return s
  116. @classmethod
  117. def gfToBytes(cls,x,mustBePositive=False):
  118. """Convert little-endian bytes to field element, sanity check length"""
  119. if negative(x) and mustBePositive: x = -x
  120. return enc_le(x,cls.encLen)
  121. class RistrettoPoint(QuotientEdwardsPoint):
  122. """The new Ristretto group"""
  123. def encodeSpec(self):
  124. """Unoptimized specification for encoding"""
  125. x,y = self
  126. if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
  127. if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
  128. if negative(x): x,y = -x,-y
  129. s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self)))
  130. return self.gfToBytes(s)
  131. @classmethod
  132. def decodeSpec(cls,s):
  133. """Unoptimized specification for decoding"""
  134. s = cls.bytesToGf(s,mustBePositive=True)
  135. a,d = cls.a,cls.d
  136. x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
  137. y = (1+a*s^2) / (1-a*s^2)
  138. if cls.cofactor==8 and (negative(x*y) or y==0):
  139. raise InvalidEncodingException("x*y has high bit")
  140. return cls(x,y)
  141. @optimized_version_of("encodeSpec")
  142. def encode(self):
  143. """Encode, optimized version"""
  144. a,d = self.a,self.d
  145. x,y,z,t = self.xyzt()
  146. if self.cofactor==8:
  147. u1 = a*(y+z)*(y-z)
  148. u2 = x*y # = t*z
  149. isr = isqrt(u1*u2^2)
  150. i1 = isr*u1
  151. i2 = isr*u2
  152. z_inv = i1*i2*t
  153. if self.cofactor==8 and negative(t*z_inv):
  154. if a==-1: x,y = y*self.i,x*self.i
  155. else: x,y = -y,x # TODO: test
  156. den_inv = self.magic * i1
  157. else:
  158. den_inv = i2
  159. if negative(x*z_inv): y = -y
  160. s = (z-y) * den_inv
  161. else:
  162. u1 = a*(y+z)*(y-z)
  163. u2 = x*y # = t*z
  164. isr = isqrt(u1*u2^2)
  165. i1 = isr*u1
  166. i2 = isr*u2
  167. z_inv = i1*i2*t
  168. den_inv = i2
  169. if negative(x*z_inv): y = -y
  170. s = (z-y) * den_inv
  171. return self.gfToBytes(s,mustBePositive=True)
  172. @classmethod
  173. @optimized_version_of("decodeSpec")
  174. def decode(cls,s):
  175. """Decode, optimized version"""
  176. s = cls.bytesToGf(s,mustBePositive=True)
  177. a,d = cls.a,cls.d
  178. yden = 1-a*s^2
  179. ynum = 1+a*s^2
  180. yden_sqr = yden^2
  181. xden_sqr = a*d*ynum^2 - yden_sqr
  182. isr = isqrt(xden_sqr * yden_sqr)
  183. xden_inv = isr * yden
  184. yden_inv = xden_inv * isr * xden_sqr
  185. x = 2*s*xden_inv
  186. if negative(x): x = -x
  187. y = ynum * yden_inv
  188. if cls.cofactor==8 and (negative(x*y) or y==0):
  189. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  190. return cls(x,y)
  191. @classmethod
  192. def fromJacobiQuartic(cls,s,t,sgn=1):
  193. """Convert point from its Jacobi Quartic representation"""
  194. a,d = cls.a,cls.d
  195. assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
  196. x = 2*s*cls.magic / t
  197. if negative(x): x = -x # TODO: doesn't work without resolving x
  198. y = (1+a*s^2) / (1-a*s^2)
  199. return cls(sgn*x,y)
  200. @classmethod
  201. def elligatorSpec(cls,r0):
  202. a,d = cls.a,cls.d
  203. r = cls.qnr * cls.bytesToGf(r0)^2
  204. den = (d*r-a)*(a*r-d)
  205. n1 = cls.a*(r+1)*(a+d)*(d-a)/den
  206. n2 = r*n1
  207. if is_square(n1):
  208. sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
  209. else:
  210. sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
  211. return cls.fromJacobiQuartic(s,t,sgn)
  212. @classmethod
  213. @optimized_version_of("elligatorSpec")
  214. def elligator(cls,r0):
  215. a,d = cls.a,cls.d
  216. r0 = cls.bytesToGf(r0)
  217. r = cls.qnr * r0^2
  218. den = (d*r-a)*(a*r-d)
  219. num = cls.a*(r+1)*(a+d)*(d-a)
  220. iss,isri = isqrt_i(num*den)
  221. if iss: sgn,twiddle = 1,1
  222. else: sgn,twiddle = -1,r0*cls.qnr
  223. isri *= twiddle
  224. s = isri*num
  225. t = isri*s*(r-1)*(d+a)^2 + sgn
  226. return cls.fromJacobiQuartic(s,t,sgn)
  227. class Decaf_1_1_Point(QuotientEdwardsPoint):
  228. """Like current decaf but tweaked for simplicity"""
  229. def encodeSpec(self):
  230. """Unoptimized specification for encoding"""
  231. a,d = self.a,self.d
  232. x,y = self
  233. if x==0 or y==0: return(self.gfToBytes(0))
  234. if self.cofactor==8 and negative(x*y*self.isoMagic):
  235. x,y = self.torque()
  236. isr2 = isqrt(a*(y^2-1)) * sqrt(a*d-1)
  237. sr = xsqrt(1-a*x^2)
  238. assert sr in [isr2*x*y,-isr2*x*y]
  239. altx = 1/isr2*self.isoMagic
  240. if negative(altx): s = (1+x*y*isr2)/(a*x)
  241. else: s = (1-x*y*isr2)/(a*x)
  242. return self.gfToBytes(s,mustBePositive=True)
  243. @classmethod
  244. def decodeSpec(cls,s):
  245. """Unoptimized specification for decoding"""
  246. a,d = cls.a,cls.d
  247. s = cls.bytesToGf(s,mustBePositive=True)
  248. if s==0: return cls()
  249. isr = isqrt(s^4 + 2*(a-2*d)*s^2 + 1)
  250. altx = 2*s*isr*cls.isoMagic
  251. if negative(altx): isr = -isr
  252. x = 2*s / (1+a*s^2)
  253. y = (1-a*s^2) * isr
  254. if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
  255. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  256. return cls(x,y)
  257. @optimized_version_of("encodeSpec")
  258. def encode(self):
  259. """Encode, optimized version"""
  260. a,d = self.a,self.d
  261. x,y,z,t = self.xyzt()
  262. if self.cofactor == 8:
  263. num = (z+y)*(z-y)
  264. den = x*y
  265. tmp = isqrt(num*(a-d)*den^2)
  266. if negative(tmp^2*den*num*(a-d)*t^2*self.isoMagic):
  267. den,num = num,den
  268. tmp *= sqrt(a-d) # witness that cofactor is 8
  269. yisr = x*sqrt(a)
  270. toggle = (a==1)
  271. else:
  272. yisr = y*(a*d-1)
  273. toggle = False
  274. tiisr = tmp*num
  275. altx = tiisr*t*self.isoMagic
  276. if negative(altx) != toggle: tiisr =- tiisr
  277. s = tmp*den*yisr*(tiisr*z - 1)
  278. else:
  279. num = (x+t)*(x-t)
  280. tmp = isqrt(num*(a-d)*x^2)
  281. ratio = tmp*num
  282. if negative(ratio*self.isoMagic): ratio=-ratio
  283. s = (a-d)*x*tmp*(z*ratio - t)
  284. return self.gfToBytes(s,mustBePositive=True)
  285. @classmethod
  286. @optimized_version_of("decodeSpec")
  287. def decode(cls,s):
  288. """Decode, optimized version"""
  289. return cls.decodeSpec(s) # TODO
  290. class Ed25519Point(RistrettoPoint):
  291. F = GF(2^255-19)
  292. d = F(-121665/121666)
  293. a = F(-1)
  294. i = sqrt(F(-1))
  295. qnr = i
  296. magic = isqrt(a*d-1)
  297. cofactor = 8
  298. encLen = 32
  299. @classmethod
  300. def base(cls):
  301. y = cls.F(4/5)
  302. x = sqrt((y^2-1)/(cls.d*y^2+1))
  303. if lobit(x): x = -x
  304. return cls(x,y)
  305. class IsoEd448Point(RistrettoPoint):
  306. F = GF(2^448-2^224-1)
  307. d = F(39082/39081)
  308. a = F(1)
  309. qnr = -1
  310. magic = isqrt(a*d-1)
  311. cofactor = 4
  312. encLen = 56
  313. @classmethod
  314. def base(cls):
  315. # = ..., -3/2
  316. return cls.decodeSpec(bytearray(binascii.unhexlify(
  317. "00000000000000000000000000000000000000000000000000000000"+
  318. "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
  319. class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
  320. F = GF(2^448-2^224-1)
  321. d = F(-39082)
  322. a = F(-1)
  323. qnr = -1
  324. magic = isqrt(a*d-1)
  325. cofactor = 4
  326. encLen = 56
  327. isoMagic = IsoEd448Point.magic
  328. @classmethod
  329. def base(cls):
  330. return cls.decodeSpec(bytearray(binascii.unhexlify(
  331. "00000000000000000000000000000000000000000000000000000000"+
  332. "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
  333. class Ed448GoldilocksPoint(Decaf_1_1_Point):
  334. F = GF(2^448-2^224-1)
  335. d = F(-39081)
  336. a = F(1)
  337. qnr = -1
  338. magic = isqrt(a*d-1)
  339. cofactor = 4
  340. encLen = 56
  341. isoMagic = IsoEd448Point.magic
  342. @classmethod
  343. def base(cls):
  344. return cls.decodeSpec(bytearray(binascii.unhexlify(
  345. "00000000000000000000000000000000000000000000000000000000"+
  346. "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
  347. class IsoEd25519Point(Decaf_1_1_Point):
  348. # TODO: twisted iso too!
  349. # TODO: twisted iso might have to IMAGINE_TWIST or whatever
  350. F = GF(2^255-19)
  351. d = F(-121665)
  352. a = F(1)
  353. i = sqrt(F(-1))
  354. qnr = i
  355. magic = isqrt(a*d-1)
  356. cofactor = 8
  357. encLen = 32
  358. isoMagic = Ed25519Point.magic
  359. isoA = Ed25519Point.a
  360. @classmethod
  361. def base(cls):
  362. return cls.decodeSpec(Ed25519Point.base().encode())
  363. class TestFailedException(Exception): pass
  364. def test(cls,n):
  365. # TODO: test corner cases like 0,1,i
  366. P = cls.base()
  367. Q = cls()
  368. for i in xrange(n):
  369. #print binascii.hexlify(Q.encode())
  370. QQ = cls.decode(Q.encode())
  371. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  372. QT = Q
  373. QE = Q.encode()
  374. for h in xrange(cls.cofactor):
  375. QT = QT.torque()
  376. if QT.encode() != QE:
  377. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  378. Q0 = Q + P
  379. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  380. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  381. r = randint(1,1000)
  382. Q1 = Q0*r
  383. Q2 = Q0*(r+1)
  384. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  385. Q = Q1
  386. test(Ed25519Point,100)
  387. test(IsoEd25519Point,100)
  388. test(IsoEd448Point,100)
  389. test(TwistedEd448GoldilocksPoint,100)
  390. test(Ed448GoldilocksPoint,100)
  391. def gangtest(classes,n):
  392. for i in xrange(n):
  393. rets = [bytes((cls.base()*i).encode()) for cls in classes]
  394. if len(set(rets)) != 1:
  395. print "Divergence at %d" % i
  396. for c,ret in zip(classes,rets):
  397. print c,binascii.hexlify(ret)
  398. print
  399. gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
  400. gangtest([Ed25519Point,IsoEd25519Point],100)
  401. def testElligator(cls,n):
  402. for i in xrange(n):
  403. cls.elligator(randombytes(cls.encLen))
  404. testElligator(Ed25519Point,100)
  405. testElligator(IsoEd448Point,100)
  406. # testElligator(Ed448GoldilocksPoint,100)
  407. # testElligator(TwistedEd448GoldilocksPoint,100)