You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

880 lines
29 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def negative(x): return lobit(x)
  8. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in range(n)])
  9. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  10. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  11. def optimized_version_of(spec):
  12. """Decorator: This function is an optimized version of some specification"""
  13. def decorator(f):
  14. def wrapper(self,*args,**kwargs):
  15. def pr(x):
  16. if isinstance(x,bytearray): return binascii.hexlify(x)
  17. else: return str(x)
  18. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  19. except Exception as e: spec_ans = None,e
  20. try: opt_ans = f(self,*args,**kwargs),None
  21. except Exception as e: opt_ans = None,e
  22. if spec_ans[1] is None and opt_ans[1] is not None:
  23. raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  24. % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  25. if spec_ans[1] is not None and opt_ans[1] is None:
  26. raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  27. % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  28. if spec_ans[0] != opt_ans[0]:
  29. raise SpecException("Mismatch in %s: %s != %s"
  30. % (f.__name__,pr(spec_ans[0]),pr(opt_ans[0])))
  31. if opt_ans[1] is not None: raise opt_ans[1]
  32. else: return opt_ans[0]
  33. wrapper.__name__ = f.__name__
  34. return wrapper
  35. return decorator
  36. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  37. """Return sqrt(x)"""
  38. if not is_square(x): raise exn
  39. s = sqrt(x)
  40. if negative(s): s=-s
  41. return s
  42. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  43. """Return 1/sqrt(x)"""
  44. if x==0: return 0
  45. if not is_square(x): raise exn
  46. s = sqrt(x)
  47. #if negative(s): s=-s
  48. return 1/s
  49. def inv0(x): return 1/x if x != 0 else 0
  50. def isqrt_i(x):
  51. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  52. if x==0: return True,0
  53. gen = x.parent(-1)
  54. while is_square(gen): gen = sqrt(gen)
  55. if is_square(x): return True,1/sqrt(x)
  56. else: return False,1/sqrt(x*gen)
  57. class QuotientEdwardsPoint(object):
  58. """Abstract class for point an a quotiented Edwards curve; needs F,a,d,cofactor to work"""
  59. def __init__(self,x=0,y=1):
  60. x = self.x = self.F(x)
  61. y = self.y = self.F(y)
  62. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  63. raise NotOnCurveException(str(self))
  64. def __repr__(self):
  65. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  66. def __iter__(self):
  67. yield self.x
  68. yield self.y
  69. def __add__(self,other):
  70. x,y = self
  71. X,Y = other
  72. a,d = self.a,self.d
  73. return self.__class__(
  74. (x*Y+y*X)/(1+d*x*y*X*Y),
  75. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  76. )
  77. def __neg__(self): return self.__class__(-self.x,self.y)
  78. def __sub__(self,other): return self + (-other)
  79. def __rmul__(self,other): return self*other
  80. def __eq__(self,other):
  81. """NB: this is the only method that is different from the usual one"""
  82. x,y = self
  83. X,Y = other
  84. return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
  85. def __ne__(self,other): return not (self==other)
  86. def __mul__(self,exp):
  87. exp = int(exp)
  88. if exp < 0: exp,self = -exp,-self
  89. total = self.__class__()
  90. work = self
  91. while exp != 0:
  92. if exp & 1: total += work
  93. work += work
  94. exp >>= 1
  95. return total
  96. def xyzt(self):
  97. x,y = self
  98. z = self.F.random_element()
  99. return x*z,y*z,z,x*y*z
  100. def torque(self):
  101. """Apply cofactor group, except keeping the point even"""
  102. if self.cofactor == 8:
  103. if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
  104. if self.a == 1: return self.__class__(-self.y, self.x)
  105. else:
  106. return self.__class__(-self.x, -self.y)
  107. def doubleAndEncodeSpec(self):
  108. return (self+self).encode()
  109. # Utility functions
  110. @classmethod
  111. def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False,maskHiBits=False):
  112. """Convert little-endian bytes to field element, sanity check length"""
  113. if len(bytes) != cls.encLen and mustBeProper:
  114. raise InvalidEncodingException("wrong length %d" % len(bytes))
  115. s = dec_le(bytes)
  116. if mustBeProper and s >= cls.F.order():
  117. raise InvalidEncodingException("%d out of range!" % s)
  118. bitlen = int(ceil(N(log(cls.F.order(),2.))))
  119. if maskHiBits: s &= 2^bitlen-1
  120. s = cls.F(s)
  121. if mustBePositive and negative(s):
  122. raise InvalidEncodingException("%d is negative!" % s)
  123. return s
  124. @classmethod
  125. def gfToBytes(cls,x,mustBePositive=False):
  126. """Convert little-endian bytes to field element, sanity check length"""
  127. if negative(x) and mustBePositive: x = -x
  128. return enc_le(x,cls.encLen)
  129. class RistrettoPoint(QuotientEdwardsPoint):
  130. """The new Ristretto group"""
  131. def encodeSpec(self):
  132. """Unoptimized specification for encoding"""
  133. x,y = self
  134. if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
  135. if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
  136. if negative(x): x,y = -x,-y
  137. s = xsqrt(self.mneg*(1-y)/(1+y),exn=Exception("Unimplemented: point is odd: " + str(self)))
  138. return self.gfToBytes(s)
  139. @classmethod
  140. def decodeSpec(cls,s):
  141. """Unoptimized specification for decoding"""
  142. s = cls.bytesToGf(s,mustBePositive=True)
  143. a,d = cls.a,cls.d
  144. x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
  145. y = (1+a*s^2) / (1-a*s^2)
  146. if cls.cofactor==8 and (negative(x*y) or y==0):
  147. raise InvalidEncodingException("x*y has high bit")
  148. return cls(x,y)
  149. @optimized_version_of("encodeSpec")
  150. def encode(self):
  151. """Encode, optimized version"""
  152. a,d,mneg = self.a,self.d,self.mneg
  153. x,y,z,t = self.xyzt()
  154. if self.cofactor==8:
  155. u1 = mneg*(z+y)*(z-y)
  156. u2 = x*y # = t*z
  157. isr = isqrt(u1*u2^2)
  158. i1 = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
  159. i2 = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
  160. z_inv = i1*i2*t # 1/z
  161. if negative(t*z_inv):
  162. if a==-1:
  163. x,y = y*self.i,x*self.i
  164. den_inv = self.magic * i1
  165. else:
  166. x,y = -y,x
  167. den_inv = self.i * self.magic * i1
  168. else:
  169. den_inv = i2
  170. if negative(x*z_inv): y = -y
  171. s = (z-y) * den_inv
  172. else:
  173. num = mneg*(z+y)*(z-y)
  174. isr = isqrt(num*y^2)
  175. if negative(isr^2*num*y*t): y = -y
  176. s = isr*y*(z-y)
  177. return self.gfToBytes(s,mustBePositive=True)
  178. @optimized_version_of("doubleAndEncodeSpec")
  179. def doubleAndEncode(self):
  180. X,Y,Z,T = self.xyzt()
  181. a,d,mneg = self.a,self.d,self.mneg
  182. if self.cofactor==8:
  183. e = 2*X*Y
  184. f = Z^2+d*T^2
  185. g = Y^2-a*X^2
  186. h = Z^2-d*T^2
  187. inv1 = inv0(e*f*g*h)
  188. z_inv = inv1*e*g # 1 / (f*h)
  189. t_inv = inv1*f*h
  190. if negative(e*g*z_inv):
  191. if a==-1: sqrta = self.i
  192. else: sqrta = -1
  193. e,f,g,h = g,h,-e,f*sqrta
  194. factor = self.i
  195. else:
  196. factor = self.magic
  197. if negative(h*e*z_inv): g=-g
  198. s = (h-g)*factor*g*t_inv
  199. else:
  200. foo = Y^2+a*X^2
  201. bar = X*Y
  202. den = inv0(foo*bar)
  203. if negative(2*bar^2*den): tmp = a*X^2
  204. else: tmp = Y^2
  205. s = self.magic*(Z^2-tmp)*foo*den
  206. return self.gfToBytes(s,mustBePositive=True)
  207. @classmethod
  208. @optimized_version_of("decodeSpec")
  209. def decode(cls,s):
  210. """Decode, optimized version"""
  211. s = cls.bytesToGf(s,mustBePositive=True)
  212. a,d = cls.a,cls.d
  213. yden = 1-a*s^2
  214. ynum = 1+a*s^2
  215. yden_sqr = yden^2
  216. xden_sqr = a*d*ynum^2 - yden_sqr
  217. isr = isqrt(xden_sqr * yden_sqr)
  218. xden_inv = isr * yden
  219. yden_inv = xden_inv * isr * xden_sqr
  220. x = 2*s*xden_inv
  221. if negative(x): x = -x
  222. y = ynum * yden_inv
  223. if cls.cofactor==8 and (negative(x*y) or y==0):
  224. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  225. return cls(x,y)
  226. @classmethod
  227. def fromJacobiQuartic(cls,s,t,sgn=1):
  228. """Convert point from its Jacobi Quartic representation"""
  229. a,d = cls.a,cls.d
  230. assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
  231. x = 2*s*cls.magic / t
  232. y = (1+a*s^2) / (1-a*s^2)
  233. return cls(sgn*x,y)
  234. @classmethod
  235. def elligatorSpec(cls,r0):
  236. a,d = cls.a,cls.d
  237. r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
  238. den = (d*r-a)*(a*r-d)
  239. if den == 0: return cls()
  240. n1 = cls.a*(r+1)*(a+d)*(d-a)/den
  241. n2 = r*n1
  242. if is_square(n1):
  243. sgn,s,t = 1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
  244. else:
  245. sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
  246. return cls.fromJacobiQuartic(s,t)
  247. @classmethod
  248. @optimized_version_of("elligatorSpec")
  249. def elligator(cls,r0):
  250. a,d = cls.a,cls.d
  251. r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
  252. r = cls.qnr * r0^2
  253. den = (d*r-a)*(a*r-d)
  254. num = cls.a*(r+1)*(a+d)*(d-a)
  255. iss,isri = isqrt_i(num*den)
  256. if iss: sgn,twiddle = 1,1
  257. else: sgn,twiddle = -1,r0*cls.qnr
  258. isri *= twiddle
  259. s = isri*num
  260. t = -sgn*isri*s*(r-1)*(d+a)^2 - 1
  261. if negative(s) == iss: s = -s
  262. return cls.fromJacobiQuartic(s,t)
  263. class Decaf_1_1_Point(QuotientEdwardsPoint):
  264. """Like current decaf but tweaked for compatibility with Ristretto"""
  265. def encodeSpec(self):
  266. """Unoptimized specification for encoding"""
  267. a,d = self.a,self.d
  268. x,y = self
  269. if x==0 or y==0: return(self.gfToBytes(0))
  270. if self.cofactor==8 and negative(x*y*self.isoMagic):
  271. x,y = self.torque()
  272. sr = xsqrt(1-a*x^2)
  273. altx = x*y*self.isoMagic / sr
  274. if negative(altx): s = (1+sr)/x
  275. else: s = (1-sr)/x
  276. return self.gfToBytes(s,mustBePositive=True)
  277. @classmethod
  278. def decodeSpec(cls,s):
  279. """Unoptimized specification for decoding"""
  280. a,d = cls.a,cls.d
  281. s = cls.bytesToGf(s,mustBePositive=True)
  282. if s==0: return cls()
  283. t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
  284. altx = 2*s*cls.isoMagic/t
  285. if negative(altx): t = -t
  286. x = 2*s / (1+a*s^2)
  287. y = (1-a*s^2) / t
  288. if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
  289. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  290. return cls(x,y)
  291. def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
  292. "Return s,t on jacobi curve"
  293. a,d = self.a,self.d
  294. x,y,z,t = self.xyzt()
  295. if self.cofactor == 8:
  296. # Cofactor 8 version
  297. # Simulate IMAGINE_TWIST because that's how libdecaf does it
  298. x = self.i*x
  299. t = self.i*t
  300. a = -a
  301. d = -d
  302. # OK, the actual libdecaf code should be here
  303. num = (z+y)*(z-y)
  304. den = x*y
  305. isr = isqrt(num*(a-d)*den^2)
  306. iden = isr * den * self.isoMagic # 1/sqrt((z+y)(z-y)) = 1/sqrt(1-Y^2) / z
  307. inum = isr * num # sqrt(1-Y^2) * z / xysqrt(a-d) ~ 1/sqrt(1-ax^2)/z
  308. if negative(iden*inum*self.i*t^2*(d-a)) != toggle_rotation:
  309. iden,inum = inum,iden
  310. fac = x*sqrt(a)
  311. toggle=(a==-1)
  312. else:
  313. fac = y
  314. toggle=False
  315. imi = self.isoMagic * self.i
  316. altx = inum*t*imi
  317. neg_altx = negative(altx) != toggle_altx
  318. if neg_altx != toggle: inum =- inum
  319. tmp = fac*(inum*z + 1)
  320. s = iden*tmp*imi
  321. negm1 = (negative(s) != toggle_s) != neg_altx
  322. if negm1: m1 = a*fac + z
  323. else: m1 = a*fac - z
  324. swap = toggle_s
  325. else:
  326. # Much simpler cofactor 4 version
  327. num = (x+t)*(x-t)
  328. isr = isqrt(num*(a-d)*x^2)
  329. ratio = isr*num
  330. altx = ratio*self.isoMagic
  331. neg_altx = negative(altx) != toggle_altx
  332. if neg_altx: ratio =- ratio
  333. tmp = ratio*z - t
  334. s = (a-d)*isr*x*tmp
  335. negx = (negative(s) != toggle_s) != neg_altx
  336. if negx: m1 = -a*t + x
  337. else: m1 = -a*t - x
  338. swap = toggle_s
  339. if negative(s): s = -s
  340. return s,m1,a*tmp,swap
  341. def invertElligator(self,toggle_r=False,*args,**kwargs):
  342. "Produce preimage of self under elligator, or None"
  343. a,d = self.a,self.d
  344. rets = []
  345. tr = [False,True] if self.cofactor == 8 else [False]
  346. for toggle_rotation in tr:
  347. for toggle_altx in [False,True]:
  348. for toggle_s in [False,True]:
  349. for toggle_r in [False,True]:
  350. s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)
  351. #print
  352. #print toggle_rotation,toggle_altx,toggle_s
  353. #print m1
  354. #print m12
  355. if self == self.__class__():
  356. if self.cofactor == 4:
  357. # Hacks for identity!
  358. if toggle_altx: m12 = 1
  359. elif toggle_s: m1 = 1
  360. elif toggle_r: continue
  361. ## BOTH???
  362. else:
  363. m12 = 1
  364. imi = self.isoMagic * self.i
  365. if toggle_rotation:
  366. if toggle_altx: m1 = -imi
  367. else: m1 = +imi
  368. else:
  369. if toggle_altx: m1 = 0
  370. else: m1 = a-d
  371. rnum = (d*a*m12-m1)
  372. rden = ((d*a-1)*m12+m1)
  373. if swap: rnum,rden = rden,rnum
  374. ok,sr = isqrt_i(rnum*rden*self.qnr)
  375. if not ok: continue
  376. sr *= rnum
  377. #print "Works! %d %x" % (swap,sr)
  378. if negative(sr) != toggle_r: sr = -sr
  379. ret = self.gfToBytes(sr)
  380. if self.elligator(ret) != self and self.elligator(ret) != -self:
  381. print ("WRONG!",[toggle_rotation,toggle_altx,toggle_s])
  382. if self.elligator(ret) == -self and self != -self: print ("Negated!",[toggle_rotation,toggle_altx,toggle_s])
  383. rets.append(bytes(ret))
  384. return rets
  385. @optimized_version_of("encodeSpec")
  386. def encode(self):
  387. """Encode, optimized version"""
  388. return self.gfToBytes(self.toJacobiQuartic()[0])
  389. @classmethod
  390. @optimized_version_of("decodeSpec")
  391. def decode(cls,s):
  392. """Decode, optimized version"""
  393. a,d = cls.a,cls.d
  394. s = cls.bytesToGf(s,mustBePositive=True)
  395. #if s==0: return cls()
  396. s2 = s^2
  397. den = 1+a*s2
  398. num = den^2 - 4*d*s2
  399. isr = isqrt(num*den^2)
  400. altx = 2*s*isr*den*cls.isoMagic
  401. if negative(altx): isr = -isr
  402. x = 2*s *isr^2*den*num
  403. y = (1-a*s2) * isr*den
  404. if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
  405. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  406. return cls(x,y)
  407. @classmethod
  408. def fromJacobiQuartic(cls,s,t,sgn=1):
  409. """Convert point from its Jacobi Quartic representation"""
  410. a,d = cls.a,cls.d
  411. if s==0: return cls()
  412. x = 2*s / (1+a*s^2)
  413. y = (1-a*s^2) / t
  414. return cls(x,sgn*y)
  415. @optimized_version_of("doubleAndEncodeSpec")
  416. def doubleAndEncode(self):
  417. X,Y,Z,T = self.xyzt()
  418. a,d = self.a,self.d
  419. if self.cofactor == 8:
  420. # Cofactor 8 version
  421. # Simulate IMAGINE_TWIST because that's how libdecaf does it
  422. X = self.i*X
  423. T = self.i*T
  424. a = -a
  425. d = -d
  426. # TODO: This is only being called for a=-1, so could
  427. # be wrong for a=1
  428. e = 2*X*Y
  429. f = Y^2+a*X^2
  430. g = Y^2-a*X^2
  431. h = Z^2-d*T^2
  432. eim = e*self.isoMagic
  433. inv = inv0(eim*g*f*h)
  434. fh_inv = eim*g*inv*self.i
  435. if negative(eim*g*fh_inv):
  436. idf = g*self.isoMagic*self.i
  437. bar = f
  438. foo = g
  439. test = eim*f
  440. else:
  441. idf = eim
  442. bar = h
  443. foo = -eim
  444. test = g*h
  445. if negative(test*fh_inv): bar =- bar
  446. s = idf*(foo+bar)*inv*f*h
  447. else:
  448. xy = X*Y
  449. h = Z^2-d*T^2
  450. inv = inv0(xy*h)
  451. if negative(inv*2*xy^2*self.isoMagic): tmp = Y
  452. else: tmp = X
  453. s = tmp^2*h*inv # = X/Y or Y/X, interestingly
  454. return self.gfToBytes(s,mustBePositive=True)
  455. @classmethod
  456. def elligatorSpec(cls,r0,fromR=False):
  457. a,d = cls.a,cls.d
  458. if fromR: r = r0
  459. else: r = cls.qnr * cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)^2
  460. den = (d*r-(d-a))*((d-a)*r-d)
  461. if den == 0: return cls()
  462. n1 = (r+1)*(a-2*d)/den
  463. n2 = r*n1
  464. if is_square(n1):
  465. sgn,s,t = 1, xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1
  466. else:
  467. sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1
  468. return cls.fromJacobiQuartic(s,t)
  469. @classmethod
  470. @optimized_version_of("elligatorSpec")
  471. def elligator(cls,r0):
  472. a,d = cls.a,cls.d
  473. r0 = cls.bytesToGf(r0,mustBeProper=False,maskHiBits=True)
  474. r = cls.qnr * r0^2
  475. den = (d*r-(d-a))*((d-a)*r-d)
  476. num = (r+1)*(a-2*d)
  477. iss,isri = isqrt_i(num*den)
  478. if iss: sgn,twiddle = 1,1
  479. else: sgn,twiddle = -1,r0*cls.qnr
  480. isri *= twiddle
  481. s = isri*num
  482. t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1
  483. if negative(s) == iss: s = -s
  484. return cls.fromJacobiQuartic(s,t)
  485. def elligatorInverseBruteForce(self):
  486. """Invert Elligator using SAGE's polynomial solver"""
  487. a,d = self.a,self.d
  488. R.<r0> = self.F[]
  489. r = self.qnr * r0^2
  490. den = (d*r-(d-a))*((d-a)*r-d)
  491. n1 = (r+1)*(a-2*d)/den
  492. n2 = r*n1
  493. ret = set()
  494. for s2,t in [(n1, -(r-1)*(a-2*d)^2 / den - 1),
  495. (n2,r*(r-1)*(a-2*d)^2 / den - 1)]:
  496. x2 = 4*s2/(1+a*s2)^2
  497. y = (1-a*s2) / t
  498. selfT = self
  499. for i in range(self.cofactor/2):
  500. xT,yT = selfT
  501. polyX = xT^2-x2
  502. polyY = yT-y
  503. sx = set(r for r,_ in polyX.numerator().roots())
  504. sy = set(r for r,_ in polyY.numerator().roots())
  505. ret = ret.union(sx.intersection(sy))
  506. selfT = selfT.torque()
  507. ret = [self.gfToBytes(r) for r in ret]
  508. for r in ret:
  509. assert self.elligator(r) in [self,-self]
  510. ret = [r for r in ret if self.elligator(r) == self]
  511. return ret
  512. class Ed25519Point(RistrettoPoint):
  513. F = GF(2^255-19)
  514. d = F(-121665/121666)
  515. a = F(-1)
  516. i = sqrt(F(-1))
  517. mneg = F(1)
  518. qnr = i
  519. magic = isqrt(a*d-1)
  520. cofactor = 8
  521. encLen = 32
  522. @classmethod
  523. def base(cls):
  524. return cls( 15112221349535400772501151409588531511454012693041857206046113283949847762202, 46316835694926478169428394003475163141307993866256225615783033603165251855960
  525. )
  526. class NegEd25519Point(RistrettoPoint):
  527. F = GF(2^255-19)
  528. d = F(121665/121666)
  529. a = F(1)
  530. i = sqrt(F(-1))
  531. mneg = F(-1) # TODO checkme vs 1-ad or whatever
  532. qnr = i
  533. magic = isqrt(a*d-1)
  534. cofactor = 8
  535. encLen = 32
  536. @classmethod
  537. def base(cls):
  538. y = cls.F(4/5)
  539. x = sqrt((y^2-1)/(cls.d*y^2-cls.a))
  540. if negative(x): x = -x
  541. return cls(x,y)
  542. class IsoEd448Point(RistrettoPoint):
  543. F = GF(2^448-2^224-1)
  544. d = F(39082/39081)
  545. a = F(1)
  546. mneg = F(-1)
  547. qnr = -1
  548. magic = isqrt(a*d-1)
  549. cofactor = 4
  550. encLen = 56
  551. @classmethod
  552. def base(cls):
  553. return cls( # RFC has it wrong
  554. 345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
  555. -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
  556. )
  557. class Ed448RistrettoPoint(RistrettoPoint):
  558. F = GF(2^448-2^224-1)
  559. d = F(-39081)
  560. a = F(1)
  561. mneg = F(-1)
  562. qnr = -1
  563. magic = isqrt(a*d-1)
  564. cofactor = 4
  565. encLen = 56
  566. @classmethod
  567. def base(cls):
  568. return cls(
  569. 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
  570. )
  571. class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
  572. F = GF(2^448-2^224-1)
  573. d = F(-39082)
  574. a = F(-1)
  575. qnr = -1
  576. cofactor = 4
  577. encLen = 56
  578. isoMagic = IsoEd448Point.magic
  579. @classmethod
  580. def base(cls):
  581. return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())
  582. class Ed448GoldilocksPoint(Decaf_1_1_Point):
  583. F = GF(2^448-2^224-1)
  584. d = F(-39081)
  585. a = F(1)
  586. qnr = -1
  587. cofactor = 4
  588. encLen = 56
  589. isoMagic = IsoEd448Point.magic
  590. @classmethod
  591. def base(cls):
  592. return 2*cls(
  593. 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
  594. )
  595. class IsoEd25519Point(Decaf_1_1_Point):
  596. # TODO: twisted iso too!
  597. # TODO: twisted iso might have to IMAGINE_TWIST or whatever
  598. F = GF(2^255-19)
  599. d = F(-121665)
  600. a = F(1)
  601. i = sqrt(F(-1))
  602. qnr = i
  603. magic = isqrt(a*d-1)
  604. cofactor = 8
  605. encLen = 32
  606. isoMagic = Ed25519Point.magic
  607. isoA = Ed25519Point.a
  608. @classmethod
  609. def base(cls):
  610. return cls.decodeSpec(Ed25519Point.base().encode())
  611. class TestFailedException(Exception): pass
  612. def test(cls,n):
  613. print ("Testing curve %s" % cls.__name__)
  614. specials = [1]
  615. ii = cls.F(-1)
  616. while is_square(ii):
  617. specials.append(ii)
  618. ii = sqrt(ii)
  619. specials.append(ii)
  620. for i in specials:
  621. if negative(cls.F(i)): i = -i
  622. i = enc_le(i,cls.encLen)
  623. try:
  624. Q = cls.decode(i)
  625. QE = Q.encode()
  626. if QE != i:
  627. raise TestFailedException("Round trip special %s != %s" %
  628. (binascii.hexlify(QE),binascii.hexlify(i)))
  629. except NotOnCurveException: pass
  630. except InvalidEncodingException: pass
  631. P = cls.base()
  632. Q = cls()
  633. for i in range(n):
  634. #print binascii.hexlify(Q.encode())
  635. QE = Q.encode()
  636. QQ = cls.decode(QE)
  637. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  638. # Testing s -> 1/s: encodes -point on cofactor
  639. s = cls.bytesToGf(QE)
  640. if s != 0:
  641. ss = cls.gfToBytes(1/s,mustBePositive=True)
  642. try:
  643. QN = cls.decode(ss)
  644. if cls.cofactor == 8:
  645. raise TestFailedException("1/s shouldnt work for cofactor 8")
  646. if QN != -Q:
  647. raise TestFailedException("s -> 1/s should negate point for cofactor 4")
  648. except InvalidEncodingException as e:
  649. # Should be raised iff cofactor==8
  650. if cls.cofactor == 4:
  651. raise TestFailedException("s -> 1/s should work for cofactor 4")
  652. QT = Q
  653. for h in range(cls.cofactor):
  654. QT = QT.torque()
  655. if QT.encode() != QE:
  656. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  657. Q0 = Q + P
  658. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  659. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  660. r = randint(1,1000)
  661. Q1 = Q0*r
  662. Q2 = Q0*(r+1)
  663. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  664. Q = Q1
  665. def testElligator(cls,n):
  666. print ("Testing elligator on %s" % cls.__name__)
  667. for i in range(n):
  668. r = randombytes(cls.encLen)
  669. P = cls.elligator(r)
  670. if hasattr(P,"invertElligator"):
  671. iv = P.invertElligator()
  672. modr = bytes(cls.gfToBytes(cls.bytesToGf(r,mustBeProper=False,maskHiBits=True)))
  673. iv2 = P.torque().invertElligator()
  674. if modr not in iv: print ("Failed to invert Elligator!")
  675. if len(iv) != len(set(iv)):
  676. print ("Elligator inverses not unique!", len(set(iv)), len(iv))
  677. if iv != iv2:
  678. print ("Elligator is untorqueable!")
  679. #print ([binascii.hexlify(j) for j in iv])
  680. #print ([binascii.hexlify(j) for j in iv2])
  681. #break
  682. else:
  683. pass # TODO
  684. def gangtest(classes,n):
  685. print ("Gang test",[cls.__name__ for cls in classes])
  686. specials = [1]
  687. ii = classes[0].F(-1)
  688. while is_square(ii):
  689. specials.append(ii)
  690. ii = sqrt(ii)
  691. specials.append(ii)
  692. for i in range(n):
  693. rets = [bytes((cls.base()*i).encode()) for cls in classes]
  694. if len(set(rets)) != 1:
  695. print ("Divergence in encode at %d" % i)
  696. for c,ret in zip(classes,rets):
  697. print (c,binascii.hexlify(ret))
  698. print
  699. if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen)
  700. else: r0 = randombytes(classes[0].encLen)
  701. rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
  702. if len(set(rets)) != 1:
  703. print ("Divergence in elligator at %d" % i)
  704. for c,ret in zip(classes,rets):
  705. print (c,binascii.hexlify(ret))
  706. print
  707. def testDoubleAndEncode(cls,n):
  708. print( "Testing doubleAndEncode on %s" % cls.__name__)
  709. P = cls()
  710. for i in range(cls.cofactor):
  711. Q = P.torque()
  712. assert P.doubleAndEncode() == Q.doubleAndEncode()
  713. P = Q
  714. for i in range(n):
  715. r1 = randombytes(cls.encLen)
  716. r2 = randombytes(cls.encLen)
  717. u = cls.elligator(r1) + cls.elligator(r2)
  718. assert u.doubleAndEncode() == u.torque().doubleAndEncode()
  719. testDoubleAndEncode(Ed25519Point,100)
  720. testDoubleAndEncode(NegEd25519Point,100)
  721. testDoubleAndEncode(IsoEd25519Point,100)
  722. testDoubleAndEncode(IsoEd448Point,100)
  723. testDoubleAndEncode(Ed448RistrettoPoint,100)
  724. testDoubleAndEncode(TwistedEd448GoldilocksPoint,100)
  725. #test(Ed25519Point,100)
  726. #test(NegEd25519Point,100)
  727. #test(IsoEd25519Point,100)
  728. #test(IsoEd448Point,100)
  729. #test(TwistedEd448GoldilocksPoint,100)
  730. #test(Ed448GoldilocksPoint,100)
  731. #testElligator(Ed25519Point,100)
  732. #testElligator(NegEd25519Point,100)
  733. #testElligator(IsoEd25519Point,100)
  734. #testElligator(IsoEd448Point,100)
  735. #testElligator(Ed448GoldilocksPoint,100)
  736. #testElligator(TwistedEd448GoldilocksPoint,100)
  737. gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
  738. gangtest([Ed25519Point,IsoEd25519Point],100)