You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

748 lines
24 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def negative(x): return lobit(x)
  8. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  9. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  10. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  11. def optimized_version_of(spec):
  12. """Decorator: This function is an optimized version of some specification"""
  13. def decorator(f):
  14. def wrapper(self,*args,**kwargs):
  15. def pr(x):
  16. if isinstance(x,bytearray): return binascii.hexlify(x)
  17. else: return str(x)
  18. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  19. except Exception as e: spec_ans = None,e
  20. try: opt_ans = f(self,*args,**kwargs),None
  21. except Exception as e: opt_ans = None,e
  22. if spec_ans[1] is None and opt_ans[1] is not None:
  23. raise
  24. #raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  25. # % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  26. if spec_ans[1] is not None and opt_ans[1] is None:
  27. raise
  28. #raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  29. # % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  30. if spec_ans[0] != opt_ans[0]:
  31. raise SpecException("Mismatch in %s: %s != %s"
  32. % (f.__name__,pr(spec_ans[0]),pr(opt_ans[0])))
  33. if opt_ans[1] is not None: raise
  34. else: return opt_ans[0]
  35. wrapper.__name__ = f.__name__
  36. return wrapper
  37. return decorator
  38. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  39. """Return sqrt(x)"""
  40. if not is_square(x): raise exn
  41. s = sqrt(x)
  42. if negative(s): s=-s
  43. return s
  44. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  45. """Return 1/sqrt(x)"""
  46. if x==0: return 0
  47. if not is_square(x): raise exn
  48. s = sqrt(x)
  49. if negative(s): s=-s
  50. return 1/s
  51. def isqrt_i(x):
  52. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  53. if x==0: return True,0
  54. gen = x.parent(-1)
  55. while is_square(gen): gen = sqrt(gen)
  56. if is_square(x): return True,1/sqrt(x)
  57. else: return False,1/sqrt(x*gen)
  58. class QuotientEdwardsPoint(object):
  59. """Abstract class for point an a quotiented Edwards curve; needs F,a,d,cofactor to work"""
  60. def __init__(self,x=0,y=1):
  61. x = self.x = self.F(x)
  62. y = self.y = self.F(y)
  63. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  64. raise NotOnCurveException(str(self))
  65. def __repr__(self):
  66. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  67. def __iter__(self):
  68. yield self.x
  69. yield self.y
  70. def __add__(self,other):
  71. x,y = self
  72. X,Y = other
  73. a,d = self.a,self.d
  74. return self.__class__(
  75. (x*Y+y*X)/(1+d*x*y*X*Y),
  76. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  77. )
  78. def __neg__(self): return self.__class__(-self.x,self.y)
  79. def __sub__(self,other): return self + (-other)
  80. def __rmul__(self,other): return self*other
  81. def __eq__(self,other):
  82. """NB: this is the only method that is different from the usual one"""
  83. x,y = self
  84. X,Y = other
  85. return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
  86. def __ne__(self,other): return not (self==other)
  87. def __mul__(self,exp):
  88. exp = int(exp)
  89. if exp < 0: exp,self = -exp,-self
  90. total = self.__class__()
  91. work = self
  92. while exp != 0:
  93. if exp & 1: total += work
  94. work += work
  95. exp >>= 1
  96. return total
  97. def xyzt(self):
  98. x,y = self
  99. z = self.F.random_element()
  100. return x*z,y*z,z,x*y*z
  101. def torque(self):
  102. """Apply cofactor group, except keeping the point even"""
  103. if self.cofactor == 8:
  104. if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
  105. if self.a == 1: return self.__class__(-self.y, self.x)
  106. else:
  107. return self.__class__(-self.x, -self.y)
  108. # Utility functions
  109. @classmethod
  110. def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
  111. """Convert little-endian bytes to field element, sanity check length"""
  112. if len(bytes) != cls.encLen:
  113. raise InvalidEncodingException("wrong length %d" % len(bytes))
  114. s = dec_le(bytes)
  115. if mustBeProper and s >= cls.F.modulus():
  116. raise InvalidEncodingException("%d out of range!" % s)
  117. s = cls.F(s)
  118. if mustBePositive and negative(s):
  119. raise InvalidEncodingException("%d is negative!" % s)
  120. return s
  121. @classmethod
  122. def gfToBytes(cls,x,mustBePositive=False):
  123. """Convert little-endian bytes to field element, sanity check length"""
  124. if negative(x) and mustBePositive: x = -x
  125. return enc_le(x,cls.encLen)
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group

    Besides the attributes QuotientEdwardsPoint needs, the methods here read
    the class attributes mneg, magic, qnr and (for cofactor 8) i.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding

        Returns the encoded bytes (via gfToBytes).  If the square root does
        not exist, a plain Exception ("Unimplemented: point is odd: ...") is
        raised through xsqrt's exn argument.
        """
        x,y = self
        # Cofactor 8: switch to the torqued representative when x*y is
        # negative or y is zero.
        if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
        if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
        if negative(x): x,y = -x,-y
        s = xsqrt(self.mneg*(1-y)/(1+y),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        Raises InvalidEncodingException when s is rejected by bytesToGf
        (wrong length, out of range, negative), when the x-coordinate square
        root does not exist, or, for cofactor 8, when x*y is negative or
        y == 0.
        """
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
        y = (1+a*s^2) / (1-a*s^2)
        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y has high bit")
        return cls(x,y)
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version

        Works on the randomized (x,y,z,t) coordinates from xyzt() and uses a
        single isqrt call.  The decorator compares every result against
        encodeSpec.
        """
        a,d,mneg = self.a,self.d,self.mneg
        x,y,z,t = self.xyzt()
        if self.cofactor==8:
            u1 = mneg*(z+y)*(z-y)
            u2 = x*y # = t*z
            isr = isqrt(u1*u2^2)
            i1 = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
            i2 = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
            z_inv = i1*i2*t # 1/z
            if negative(t*z_inv):
                # Same coordinate swap as torque(), applied to the
                # projective x,y; den_inv is adjusted to match.
                if a==-1:
                    x,y = y*self.i,x*self.i
                    den_inv = self.magic * i1
                else:
                    x,y = -y,x
                    den_inv = self.i * self.magic * i1
            else:
                den_inv = i2
            if negative(x*z_inv): y = -y
            s = (z-y) * den_inv
        else:
            num = mneg*(z+y)*(z-y)
            isr = isqrt(num*y^2)
            if negative(isr^2*num*y*t): y = -y
            s = isr*y*(z-y)
        return self.gfToBytes(s,mustBePositive=True)
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version

        Uses one isqrt call to obtain both denominators' inverses.  The
        decorator compares every result (or exception) against decodeSpec.
        """
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        yden = 1-a*s^2
        ynum = 1+a*s^2
        yden_sqr = yden^2
        xden_sqr = a*d*ynum^2 - yden_sqr
        isr = isqrt(xden_sqr * yden_sqr)
        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr
        x = 2*s*xden_inv
        if negative(x): x = -x
        y = ynum * yden_inv
        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation

        Asserts that (s,t) satisfies the quartic equation, then returns
        cls(sgn*x, y).
        """
        a,d = cls.a,cls.d
        assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
        x = 2*s*cls.magic / t
        y = (1+a*s^2) / (1-a*s^2)
        return cls(sgn*x,y)
    @classmethod
    def elligatorSpec(cls,r0):
        """Unoptimized specification of the Elligator map from bytes r0.

        r0 is parsed with bytesToGf's default mustBeProper=True, so inputs
        that are not below the field modulus raise InvalidEncodingException.
        """
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0)^2
        den = (d*r-a)*(a*r-d)
        if den == 0: return cls()
        n1 = cls.a*(r+1)*(a+d)*(d-a)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
        else:
            sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
        # NOTE(review): sgn is computed but not passed on, so
        # fromJacobiQuartic runs with its default sgn=1.
        return cls.fromJacobiQuartic(s,t)
    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator map, optimized version (checked against elligatorSpec).

        Like the spec, r0 must be a proper (below-modulus) encoding.
        """
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0^2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+1)*(a+d)*(d-a)
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(d+a)^2 - 1
        if negative(s) == iss: s = -s
        # sgn only enters through t above; fromJacobiQuartic gets its default.
        return cls.fromJacobiQuartic(s,t)
class Decaf_1_1_Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity

    Besides the attributes QuotientEdwardsPoint needs, the methods here read
    the class attributes isoMagic, qnr and (for cofactor 8) i.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding

        Points with x == 0 or y == 0 encode as the field element 0.
        """
        a,d = self.a,self.d
        x,y = self
        if x==0 or y==0: return(self.gfToBytes(0))
        if self.cofactor==8 and negative(x*y*self.isoMagic):
            x,y = self.torque()
        sr = xsqrt(1-a*x^2)
        altx = x*y*self.isoMagic / sr
        if negative(altx): s = (1+sr)/x
        else: s = (1-sr)/x
        return self.gfToBytes(s,mustBePositive=True)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        s == 0 decodes to the default point cls().  Raises
        InvalidEncodingException when s is rejected by bytesToGf, when the
        square root does not exist, or, for cofactor 8, when x*y*isoMagic
        is negative or y == 0.
        """
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)
        if s==0: return cls()
        t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
        altx = 2*s*cls.isoMagic/t
        if negative(altx): t = -t
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    def toJacobiQuartic(self,toggle_rotation=False,toggle_altx=False,toggle_s=False):
        """Return the tuple (s, m1, a*tmp, swap) for this point.

        s is made non-negative before returning and is what encode() uses.
        m1, a*tmp and swap are consumed by invertElligator (as m1, m12 and
        swap).  The three toggle_* flags flip the corresponding sign/branch
        decisions below; swap is simply toggle_s.
        """
        a,d = self.a,self.d
        x,y,z,t = self.xyzt()
        if self.cofactor == 8:
            # Cofactor 8 version
            # Simulate IMAGINE_TWIST because that's how libdecaf does it
            x = self.i*x
            t = self.i*t
            a = -a
            d = -d
            # OK, the actual libdecaf code should be here
            num = (z+y)*(z-y)
            den = x*y
            isr = isqrt(num*(a-d)*den^2)
            iden = isr * den * self.isoMagic
            inum = isr * num
            if negative(iden*inum*self.i*t^2*(d-a)) != toggle_rotation:
                iden,inum = inum,iden
                fac = x*sqrt(a)
                toggle=(a==-1)
            else:
                fac = y
                toggle=False
            imi = self.isoMagic * self.i
            altx = inum*t*imi
            neg_altx = negative(altx) != toggle_altx
            if neg_altx != toggle: inum =- inum
            tmp = fac*(inum*z + 1)
            s = iden*tmp*imi
            negm1 = (negative(s) != toggle_s) != neg_altx
            if negm1: m1 = a*fac + z
            else: m1 = a*fac - z
            swap = toggle_s
        else:
            # Much simpler cofactor 4 version
            num = (x+t)*(x-t)
            isr = isqrt(num*(a-d)*x^2)
            ratio = isr*num
            altx = ratio*self.isoMagic
            neg_altx = negative(altx) != toggle_altx
            if neg_altx: ratio =- ratio
            tmp = ratio*z - t
            s = (a-d)*isr*x*tmp
            negx = (negative(s) != toggle_s) != neg_altx
            if negx: m1 = -a*t + x
            else: m1 = -a*t - x
            swap = toggle_s
        if negative(s): s = -s
        return s,m1,a*tmp,swap
    def invertElligator(self,toggle_r=False,*args,**kwargs):
        """Return a list of candidate Elligator preimages of self (as bytes).

        Every combination of the toggle flags is tried; candidates for which
        isqrt_i reports a non-square are skipped, so the list may be empty.
        Diagnostic output is printed along the way.

        NOTE(review): the toggle_r parameter is overwritten by the innermost
        loop variable of the same name, and *args/**kwargs are not used.
        """
        a,d = self.a,self.d
        rets = []
        # The rotation toggle is only explored for cofactor 8.
        tr = [False,True] if self.cofactor == 8 else [False]
        for toggle_rotation in tr:
            for toggle_altx in [False,True]:
                for toggle_s in [False,True]:
                    for toggle_r in [False,True]:
                        s,m1,m12,swap = self.toJacobiQuartic(toggle_rotation,toggle_altx,toggle_s)
                        print
                        print toggle_rotation,toggle_altx,toggle_s
                        print m1
                        print m12
                        if self == self.__class__() and self.cofactor == 4:
                            # Hacks for identity!
                            if toggle_altx: m12 = 1
                            elif toggle_s: m1 = 1
                            elif toggle_r: continue
                            ## BOTH???
                        rnum = (d*a*m12-m1)
                        rden = ((d*a-1)*m12+m1)
                        if swap: rnum,rden = rden,rnum
                        ok,sr = isqrt_i(rnum*rden*self.qnr)
                        if not ok: continue
                        sr *= rnum
                        print "Works! %d %x" % (swap,sr)
                        if negative(sr) != toggle_r: sr = -sr
                        ret = self.gfToBytes(sr)
                        # Sanity checks only print; the candidate is
                        # appended regardless of their outcome.
                        if self.elligator(ret) != self and self.elligator(ret) != -self:
                            print "WRONG!",[toggle_rotation,toggle_altx,toggle_s]
                        if self.elligator(ret) == -self and self != -self: print "Negated!",[toggle_rotation,toggle_altx,toggle_s]
                        rets.append(bytes(ret))
        return rets
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version

        Encodes the s value from toJacobiQuartic(); the decorator compares
        every result against encodeSpec.
        """
        return self.gfToBytes(self.toJacobiQuartic()[0])
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version

        Uses a single isqrt call; the decorator compares every result (or
        exception) against decodeSpec.
        """
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)
        #if s==0: return cls()
        s2 = s^2
        den = 1+a*s2
        num = den^2 - 4*d*s2
        isr = isqrt(num*den^2)
        altx = 2*s*isr*den*cls.isoMagic
        if negative(altx): isr = -isr
        x = 2*s *isr^2*den*num
        y = (1-a*s^2) * isr*den
        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation

        s == 0 maps to the default point cls(); otherwise returns
        cls(x, sgn*y).
        """
        a,d = cls.a,cls.d
        if s==0: return cls()
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        return cls(x,sgn*y)
    @classmethod
    def elligatorSpec(cls,r0,fromR=False):
        """Unoptimized specification of the Elligator map.

        With fromR, r0 is used directly as r; otherwise r0 is bytes parsed
        with bytesToGf's default mustBeProper=True (inputs not below the
        field modulus raise InvalidEncodingException) and r = qnr * r0^2.
        """
        a,d = cls.a,cls.d
        if fromR: r = r0
        else: r = cls.qnr * cls.bytesToGf(r0)^2
        den = (d*r-(d-a))*((d-a)*r-d)
        if den == 0: return cls()
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1
        else:
            sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1
        # NOTE(review): sgn is computed but not passed on, so
        # fromJacobiQuartic runs with its default sgn=1.
        return cls.fromJacobiQuartic(s,t)
    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Elligator map, optimized version (checked against elligatorSpec).

        r0 must be a proper (below-modulus) encoding.
        """
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        num = (r+1)*(a-2*d)
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1
        if negative(s) == iss: s = -s
        # sgn only enters through t above; fromJacobiQuartic gets its default.
        return cls.fromJacobiQuartic(s,t)
    def elligatorInverseBruteForce(self):
        """Invert Elligator using SAGE's polynomial solver

        Returns the list of encodings r with self.elligator(r) == self.
        Asserts that every root found maps to self or -self.
        """
        a,d = self.a,self.d
        # Work symbolically in the polynomial ring F[r0].
        R.<r0> = self.F[]
        r = self.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        ret = set()
        # The two (s^2, t) pairs mirror the two branches of elligatorSpec.
        for s2,t in [(n1, -(r-1)*(a-2*d)^2 / den - 1),
                     (n2,r*(r-1)*(a-2*d)^2 / den - 1)]:
            x2 = 4*s2/(1+a*s2)^2
            y = (1-a*s2) / t
            selfT = self
            # NOTE(review): cofactor/2 relies on the quotient being usable
            # as an xrange bound.
            for i in xrange(self.cofactor/2):
                xT,yT = selfT
                polyX = xT^2-x2
                polyY = yT-y
                # Keep the r0 values that are roots of both conditions.
                sx = set(r for r,_ in polyX.numerator().roots())
                sy = set(r for r,_ in polyY.numerator().roots())
                ret = ret.union(sx.intersection(sy))
                selfT = selfT.torque()
        ret = [self.gfToBytes(r) for r in ret]
        for r in ret:
            assert self.elligator(r) in [self,-self]
        ret = [r for r in ret if self.elligator(r) == self]
        return ret
  430. class Ed25519Point(RistrettoPoint):
  431. F = GF(2^255-19)
  432. d = F(-121665/121666)
  433. a = F(-1)
  434. i = sqrt(F(-1))
  435. mneg = F(1)
  436. qnr = i
  437. magic = isqrt(a*d-1)
  438. cofactor = 8
  439. encLen = 32
  440. @classmethod
  441. def base(cls):
  442. return cls( 15112221349535400772501151409588531511454012693041857206046113283949847762202, 46316835694926478169428394003475163141307993866256225615783033603165251855960
  443. )
  444. class NegEd25519Point(RistrettoPoint):
  445. F = GF(2^255-19)
  446. d = F(121665/121666)
  447. a = F(1)
  448. i = sqrt(F(-1))
  449. mneg = F(-1) # TODO checkme vs 1-ad or whatever
  450. qnr = i
  451. magic = isqrt(a*d-1)
  452. cofactor = 8
  453. encLen = 32
  454. @classmethod
  455. def base(cls):
  456. y = cls.F(4/5)
  457. x = sqrt((y^2-1)/(cls.d*y^2-cls.a))
  458. if negative(x): x = -x
  459. return cls(x,y)
  460. class IsoEd448Point(RistrettoPoint):
  461. F = GF(2^448-2^224-1)
  462. d = F(39082/39081)
  463. a = F(1)
  464. mneg = F(-1)
  465. qnr = -1
  466. magic = isqrt(a*d-1)
  467. cofactor = 4
  468. encLen = 56
  469. @classmethod
  470. def base(cls):
  471. return cls( # RFC has it wrong
  472. 345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
  473. -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
  474. )
  475. class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
  476. F = GF(2^448-2^224-1)
  477. d = F(-39082)
  478. a = F(-1)
  479. qnr = -1
  480. cofactor = 4
  481. encLen = 56
  482. isoMagic = IsoEd448Point.magic
  483. @classmethod
  484. def base(cls):
  485. return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())
  486. class Ed448GoldilocksPoint(Decaf_1_1_Point):
  487. F = GF(2^448-2^224-1)
  488. d = F(-39081)
  489. a = F(1)
  490. qnr = -1
  491. cofactor = 4
  492. encLen = 56
  493. isoMagic = IsoEd448Point.magic
  494. @classmethod
  495. def base(cls):
  496. return 2*cls(
  497. 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
  498. )
  499. class IsoEd25519Point(Decaf_1_1_Point):
  500. # TODO: twisted iso too!
  501. # TODO: twisted iso might have to IMAGINE_TWIST or whatever
  502. F = GF(2^255-19)
  503. d = F(-121665)
  504. a = F(1)
  505. i = sqrt(F(-1))
  506. qnr = i
  507. magic = isqrt(a*d-1)
  508. cofactor = 8
  509. encLen = 32
  510. isoMagic = Ed25519Point.magic
  511. isoA = Ed25519Point.a
  512. @classmethod
  513. def base(cls):
  514. return cls.decodeSpec(Ed25519Point.base().encode())
  515. class TestFailedException(Exception): pass
  516. def test(cls,n):
  517. print "Testing curve %s" % cls.__name__
  518. specials = [1]
  519. ii = cls.F(-1)
  520. while is_square(ii):
  521. specials.append(ii)
  522. ii = sqrt(ii)
  523. specials.append(ii)
  524. for i in specials:
  525. if negative(cls.F(i)): i = -i
  526. i = enc_le(i,cls.encLen)
  527. try:
  528. Q = cls.decode(i)
  529. QE = Q.encode()
  530. if QE != i:
  531. raise TestFailedException("Round trip special %s != %s" %
  532. (binascii.hexlify(QE),binascii.hexlify(i)))
  533. except NotOnCurveException: pass
  534. except InvalidEncodingException: pass
  535. P = cls.base()
  536. Q = cls()
  537. for i in xrange(n):
  538. #print binascii.hexlify(Q.encode())
  539. QE = Q.encode()
  540. QQ = cls.decode(QE)
  541. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  542. # Testing s -> 1/s: encodes -point on cofactor
  543. s = cls.bytesToGf(QE)
  544. if s != 0:
  545. ss = cls.gfToBytes(1/s,mustBePositive=True)
  546. try:
  547. QN = cls.decode(ss)
  548. if cls.cofactor == 8:
  549. raise TestFailedException("1/s shouldnt work for cofactor 8")
  550. if QN != -Q:
  551. raise TestFailedException("s -> 1/s should negate point for cofactor 4")
  552. except InvalidEncodingException as e:
  553. # Should be raised iff cofactor==8
  554. if cls.cofactor == 4:
  555. raise TestFailedException("s -> 1/s should work for cofactor 4")
  556. QT = Q
  557. for h in xrange(cls.cofactor):
  558. QT = QT.torque()
  559. if QT.encode() != QE:
  560. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  561. Q0 = Q + P
  562. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  563. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  564. r = randint(1,1000)
  565. Q1 = Q0*r
  566. Q2 = Q0*(r+1)
  567. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  568. Q = Q1
  569. test(Ed25519Point,100)
  570. test(NegEd25519Point,100)
  571. test(IsoEd25519Point,100)
  572. test(IsoEd448Point,100)
  573. test(TwistedEd448GoldilocksPoint,100)
  574. test(Ed448GoldilocksPoint,100)
  575. def testElligator(cls,n):
  576. print "Testing elligator on %s" % cls.__name__
  577. for i in xrange(n):
  578. r = randombytes(cls.encLen)
  579. P = cls.elligator(r)
  580. if hasattr(P,"invertElligator"):
  581. iv = P.invertElligator()
  582. modr = bytes(cls.gfToBytes(cls.bytesToGf(r)))
  583. iv2 = P.torque().invertElligator()
  584. if modr not in iv: print "Failed to invert Elligator!"
  585. if len(iv) != len(set(iv)):
  586. print "Elligator inverses not unique!", len(set(iv)), len(iv)
  587. if iv != iv2:
  588. print "Elligator is untorqueable!"
  589. #print [binascii.hexlify(j) for j in iv]
  590. #print [binascii.hexlify(j) for j in iv2]
  591. #break
  592. else:
  593. pass # TODO
  594. testElligator(Ed25519Point,100)
  595. testElligator(NegEd25519Point,100)
  596. testElligator(IsoEd25519Point,100)
  597. testElligator(IsoEd448Point,100)
  598. testElligator(Ed448GoldilocksPoint,100)
  599. testElligator(TwistedEd448GoldilocksPoint,100)
  600. def gangtest(classes,n):
  601. print "Gang test",[cls.__name__ for cls in classes]
  602. specials = [1]
  603. ii = classes[0].F(-1)
  604. while is_square(ii):
  605. specials.append(ii)
  606. ii = sqrt(ii)
  607. specials.append(ii)
  608. for i in xrange(n):
  609. rets = [bytes((cls.base()*i).encode()) for cls in classes]
  610. if len(set(rets)) != 1:
  611. print "Divergence in encode at %d" % i
  612. for c,ret in zip(classes,rets):
  613. print c,binascii.hexlify(ret)
  614. print
  615. if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen)
  616. else: r0 = randombytes(classes[0].encLen)
  617. rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
  618. if len(set(rets)) != 1:
  619. print "Divergence in elligator at %d" % i
  620. for c,ret in zip(classes,rets):
  621. print c,binascii.hexlify(ret)
  622. print
  623. gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
  624. gangtest([Ed25519Point,IsoEd25519Point],100)