You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

613 lines
19 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def negative(x): return lobit(x)
  8. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  9. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  10. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  11. def optimized_version_of(spec):
  12. """Decorator: This function is an optimized version of some specification"""
  13. def decorator(f):
  14. def wrapper(self,*args,**kwargs):
  15. def pr(x):
  16. if isinstance(x,bytearray): return binascii.hexlify(x)
  17. else: return str(x)
  18. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  19. except Exception as e: spec_ans = None,e
  20. try: opt_ans = f(self,*args,**kwargs),None
  21. except Exception as e: opt_ans = None,e
  22. if spec_ans[1] is None and opt_ans[1] is not None:
  23. raise
  24. #raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  25. # % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  26. if spec_ans[1] is not None and opt_ans[1] is None:
  27. raise
  28. #raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  29. # % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  30. if spec_ans[0] != opt_ans[0]:
  31. raise SpecException("Mismatch in %s: %s != %s"
  32. % (f.__name__,pr(spec_ans[0]),pr(opt_ans[0])))
  33. if opt_ans[1] is not None: raise
  34. else: return opt_ans[0]
  35. wrapper.__name__ = f.__name__
  36. return wrapper
  37. return decorator
  38. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  39. """Return sqrt(x)"""
  40. if not is_square(x): raise exn
  41. s = sqrt(x)
  42. if negative(s): s=-s
  43. return s
  44. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  45. """Return 1/sqrt(x)"""
  46. if x==0: return 0
  47. if not is_square(x): raise exn
  48. return 1/sqrt(x)
  49. def isqrt_i(x):
  50. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  51. if x==0: return True,0
  52. gen = x.parent(-1)
  53. while is_square(gen): gen = sqrt(gen)
  54. if is_square(x): return True,1/sqrt(x)
  55. else: return False,1/sqrt(x*gen)
  56. class QuotientEdwardsPoint(object):
  57. """Abstract class for point an a quotiented Edwards curve; needs F,a,d,cofactor to work"""
  58. def __init__(self,x=0,y=1):
  59. x = self.x = self.F(x)
  60. y = self.y = self.F(y)
  61. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  62. raise NotOnCurveException(str(self))
  63. def __repr__(self):
  64. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  65. def __iter__(self):
  66. yield self.x
  67. yield self.y
  68. def __add__(self,other):
  69. x,y = self
  70. X,Y = other
  71. a,d = self.a,self.d
  72. return self.__class__(
  73. (x*Y+y*X)/(1+d*x*y*X*Y),
  74. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  75. )
  76. def __neg__(self): return self.__class__(-self.x,self.y)
  77. def __sub__(self,other): return self + (-other)
  78. def __rmul__(self,other): return self*other
  79. def __eq__(self,other):
  80. """NB: this is the only method that is different from the usual one"""
  81. x,y = self
  82. X,Y = other
  83. return x*Y == X*y or (self.cofactor==8 and -self.a*x*X == y*Y)
  84. def __ne__(self,other): return not (self==other)
  85. def __mul__(self,exp):
  86. exp = int(exp)
  87. if exp < 0: exp,self = -exp,-self
  88. total = self.__class__()
  89. work = self
  90. while exp != 0:
  91. if exp & 1: total += work
  92. work += work
  93. exp >>= 1
  94. return total
  95. def xyzt(self):
  96. x,y = self
  97. z = self.F.random_element()
  98. return x*z,y*z,z,x*y*z
  99. def torque(self):
  100. """Apply cofactor group, except keeping the point even"""
  101. if self.cofactor == 8:
  102. if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i)
  103. if self.a == 1: return self.__class__(-self.y, self.x)
  104. else:
  105. return self.__class__(-self.x, -self.y)
  106. # Utility functions
  107. @classmethod
  108. def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
  109. """Convert little-endian bytes to field element, sanity check length"""
  110. if len(bytes) != cls.encLen:
  111. raise InvalidEncodingException("wrong length %d" % len(bytes))
  112. s = dec_le(bytes)
  113. if mustBeProper and s >= cls.F.modulus():
  114. raise InvalidEncodingException("%d out of range!" % s)
  115. s = cls.F(s)
  116. if mustBePositive and negative(s):
  117. raise InvalidEncodingException("%d is negative!" % s)
  118. return s
  119. @classmethod
  120. def gfToBytes(cls,x,mustBePositive=False):
  121. """Convert little-endian bytes to field element, sanity check length"""
  122. if negative(x) and mustBePositive: x = -x
  123. return enc_le(x,cls.encLen)
class RistrettoPoint(QuotientEdwardsPoint):
    """The new Ristretto group

    On top of the attributes the base class needs, the methods here read
    mneg, magic and qnr from the concrete subclass, and i in the
    cofactor-8 branch of encode().
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding

        Returns the encLen-byte little-endian encoding produced by gfToBytes.
        """
        x,y = self
        # Cofactor 8 only: switch to the torqued representative when x*y is
        # negative or y is zero.
        if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque()
        if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
        if negative(x): x,y = -x,-y
        # xsqrt returns the non-negative root and raises the supplied
        # exception when its argument is not a square.
        s = xsqrt(self.mneg*(1-y)/(1+y),exn=Exception("Unimplemented: point is odd: " + str(self)))
        return self.gfToBytes(s)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        Raises InvalidEncodingException for a rejected encoding; the point
        constructor raises NotOnCurveException if (x,y) is off the curve.
        """
        # bytesToGf rejects wrong-length, out-of-range and negative inputs.
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
        y = (1+a*s^2) / (1-a*s^2)
        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y has high bit")
        return cls(x,y)
    # Every call to encode() also runs encodeSpec() and compares the results.
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        a,d,mneg = self.a,self.d,self.mneg
        # xyzt() returns (x*z, y*z, z, x*y*z) for a random z, so x,y,z,t
        # below are the scaled coordinates.
        x,y,z,t = self.xyzt()
        if self.cofactor==8:
            u1 = mneg*(z+y)*(z-y)
            u2 = x*y # = t*z
            isr = isqrt(u1*u2^2)
            i1 = isr*u1 # sqrt(mneg*(z+y)*(z-y))/(x*y)
            i2 = isr*u2 # 1/sqrt(a*(y+z)*(y-z))
            z_inv = i1*i2*t # 1/z
            if negative(t*z_inv):
                # Same condition and coordinate swap as the torque step in
                # encodeSpec (t*z_inv is the affine x*y).
                if a==-1:
                    x,y = y*self.i,x*self.i
                    den_inv = self.magic * i1
                else:
                    x,y = -y,x
                    den_inv = self.i * self.magic * i1
            else:
                den_inv = i2
            if negative(x*z_inv): y = -y
            s = (z-y) * den_inv
        else:
            num = mneg*(z+y)*(z-y)
            isr = isqrt(num*y^2)
            if negative(isr^2*num*y*t): y = -y
            s = isr*y*(z-y)
        return self.gfToBytes(s,mustBePositive=True)
    # Every call to decode() also runs decodeSpec() and compares the results.
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        s = cls.bytesToGf(s,mustBePositive=True)
        a,d = cls.a,cls.d
        yden = 1-a*s^2
        ynum = 1+a*s^2
        yden_sqr = yden^2
        xden_sqr = a*d*ynum^2 - yden_sqr
        # isqrt raises InvalidEncodingException for a nonzero non-square
        # and returns 0 for 0.
        isr = isqrt(xden_sqr * yden_sqr)
        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr
        x = 2*s*xden_inv
        if negative(x): x = -x
        y = ynum * yden_inv
        if cls.cofactor==8 and (negative(x*y) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        # Check that (s,t) satisfies the quartic equation before converting.
        assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
        x = 2*s*cls.magic / t
        y = (1+a*s^2) / (1-a*s^2)
        return cls(sgn*x,y)
    @classmethod
    def elligatorSpec(cls,r0):
        """Specification that elligator() is checked against: map bytes r0 to a point"""
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0)^2
        den = (d*r-a)*(a*r-d)
        n1 = cls.a*(r+1)*(a+d)*(d-a)/den
        n2 = r*n1
        # NOTE: sgn is assigned in both branches but never passed on, so
        # fromJacobiQuartic runs with its default sgn=1.
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
        else:
            sgn,s,t = -1,-xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
        return cls.fromJacobiQuartic(s,t)
    # Every call to elligator() also runs elligatorSpec() and compares the results.
    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Optimized version of elligatorSpec"""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0^2
        den = (d*r-a)*(a*r-d)
        num = cls.a*(r+1)*(a+d)*(d-a)
        # iss tells whether num*den is a square; isqrt_i returns (True, 0)
        # when num*den == 0.
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(d+a)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)
class Decaf_1_1_Point(QuotientEdwardsPoint):
    """Like current decaf but tweaked for simplicity

    On top of the attributes the base class needs, the methods here read
    isoMagic and qnr from the concrete subclass.
    """
    def encodeSpec(self):
        """Unoptimized specification for encoding"""
        a,d = self.a,self.d
        x,y = self
        # Points with x == 0 or y == 0 all encode as zero.
        if x==0 or y==0: return(self.gfToBytes(0))
        if self.cofactor==8 and negative(x*y*self.isoMagic):
            x,y = self.torque()
        # xsqrt returns the non-negative root; with no exn argument it
        # raises InvalidEncodingException for a non-square.
        sr = xsqrt(1-a*x^2)
        altx = x*y*self.isoMagic / sr
        if negative(altx): s = (1+sr)/x
        else: s = (1-sr)/x
        return self.gfToBytes(s,mustBePositive=True)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding

        Raises InvalidEncodingException for a rejected encoding; the point
        constructor raises NotOnCurveException if (x,y) is off the curve.
        """
        a,d = cls.a,cls.d
        # bytesToGf rejects wrong-length, out-of-range and negative inputs.
        s = cls.bytesToGf(s,mustBePositive=True)
        # Zero decodes to the default point cls().
        if s==0: return cls()
        t = xsqrt(s^4 + 2*(a-2*d)*s^2 + 1)
        altx = 2*s*cls.isoMagic/t
        if negative(altx): t = -t
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    # Every call to encode() also runs encodeSpec() and compares the results.
    @optimized_version_of("encodeSpec")
    def encode(self):
        """Encode, optimized version"""
        a,d = self.a,self.d
        # xyzt() returns (x*z, y*z, z, x*y*z) for a random z, so x,y,z,t
        # below are the scaled coordinates.
        x,y,z,t = self.xyzt()
        if self.cofactor == 8:
            # Cofactor 8 version
            num = (z+y)*(z-y)
            den = x*y
            tmp = isqrt(num*(a-d)*den^2)
            if negative(tmp^2*den*num*(a-d)*t^2*self.isoMagic):
                den,num = num,den
                tmp *= sqrt(a-d) # witness that cofactor is 8
                yisr = x*sqrt(a)
                toggle = (a==1)
            else:
                yisr = y*(a*d-1)
                toggle = False
            tiisr = tmp*num
            altx = tiisr*t*self.isoMagic
            if negative(altx) != toggle: tiisr =- tiisr
            s = tmp*den*yisr*(tiisr*z - 1)
        else:
            # Much simpler cofactor 4 version
            num = (x+t)*(x-t)
            isr = isqrt(num*(a-d)*x^2)
            ratio = isr*num
            if negative(ratio*self.isoMagic): ratio=-ratio
            s = (a-d)*isr*x*(ratio*z - t)
        return self.gfToBytes(s,mustBePositive=True)
    # Every call to decode() also runs decodeSpec() and compares the results.
    @classmethod
    @optimized_version_of("decodeSpec")
    def decode(cls,s):
        """Decode, optimized version"""
        a,d = cls.a,cls.d
        s = cls.bytesToGf(s,mustBePositive=True)
        if s==0: return cls()
        s2 = s^2
        den = 1+a*s2
        num = den^2 - 4*d*s2
        # isqrt raises InvalidEncodingException for a nonzero non-square
        # and returns 0 for 0.
        isr = isqrt(num*den^2)
        altx = 2*s*isr*den*cls.isoMagic
        if negative(altx): isr = -isr
        x = 2*s *isr^2*den*num
        y = (1-a*s^2) * isr*den
        if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0):
            raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
        return cls(x,y)
    @classmethod
    def fromJacobiQuartic(cls,s,t,sgn=1):
        """Convert point from its Jacobi Quartic representation"""
        a,d = cls.a,cls.d
        if s==0: return cls()
        x = 2*s / (1+a*s^2)
        y = (1-a*s^2) / t
        return cls(x,sgn*y)
    @classmethod
    def elligatorSpec(cls,r0):
        """Specification that elligator() is checked against: map bytes r0 to a point"""
        a,d = cls.a,cls.d
        r = cls.qnr * cls.bytesToGf(r0)^2
        den = (d*r-(d-a))*((d-a)*r-d)
        n1 = (r+1)*(a-2*d)/den
        n2 = r*n1
        # NOTE: sgn is assigned in both branches but never passed on, so
        # fromJacobiQuartic runs with its default sgn=1.
        if is_square(n1):
            sgn,s,t = 1, xsqrt(n1), -(r-1)*(a-2*d)^2 / den - 1
        else:
            sgn,s,t = -1, -xsqrt(n2), r*(r-1)*(a-2*d)^2 / den - 1
        return cls.fromJacobiQuartic(s,t)
    # Every call to elligator() also runs elligatorSpec() and compares the results.
    @classmethod
    @optimized_version_of("elligatorSpec")
    def elligator(cls,r0):
        """Optimized version of elligatorSpec"""
        a,d = cls.a,cls.d
        r0 = cls.bytesToGf(r0)
        r = cls.qnr * r0^2
        den = (d*r-(d-a))*((d-a)*r-d)
        num = (r+1)*(a-2*d)
        # iss tells whether num*den is a square; isqrt_i returns (True, 0)
        # when num*den == 0.
        iss,isri = isqrt_i(num*den)
        if iss: sgn,twiddle = 1,1
        else: sgn,twiddle = -1,r0*cls.qnr
        isri *= twiddle
        s = isri*num
        t = -sgn*isri*s*(r-1)*(a-2*d)^2 - 1
        if negative(s) == iss: s = -s
        return cls.fromJacobiQuartic(s,t)
  340. class Ed25519Point(RistrettoPoint):
  341. F = GF(2^255-19)
  342. d = F(-121665/121666)
  343. a = F(-1)
  344. i = sqrt(F(-1))
  345. mneg = F(1)
  346. qnr = i
  347. magic = isqrt(a*d-1)
  348. cofactor = 8
  349. encLen = 32
  350. @classmethod
  351. def base(cls):
  352. return cls( 15112221349535400772501151409588531511454012693041857206046113283949847762202, 46316835694926478169428394003475163141307993866256225615783033603165251855960
  353. )
  354. class NegEd25519Point(RistrettoPoint):
  355. F = GF(2^255-19)
  356. d = F(121665/121666)
  357. a = F(1)
  358. i = sqrt(F(-1))
  359. mneg = F(-1) # TODO checkme vs 1-ad or whatever
  360. qnr = i
  361. magic = isqrt(a*d-1)
  362. cofactor = 8
  363. encLen = 32
  364. @classmethod
  365. def base(cls):
  366. y = cls.F(4/5)
  367. x = sqrt((y^2-1)/(cls.d*y^2-cls.a))
  368. if negative(x): x = -x
  369. return cls(x,y)
  370. class IsoEd448Point(RistrettoPoint):
  371. F = GF(2^448-2^224-1)
  372. d = F(39082/39081)
  373. a = F(1)
  374. mneg = F(-1)
  375. qnr = -1
  376. magic = isqrt(a*d-1)
  377. cofactor = 4
  378. encLen = 56
  379. @classmethod
  380. def base(cls):
  381. return cls( # RFC has it wrong
  382. 345397493039729516374008604150537410266655260075183290216406970281645695073672344430481787759340633221708391583424041788924124567700732,
  383. -363419362147803445274661903944002267176820680343659030140745099590306164083365386343198191849338272965044442230921818680526749009182718
  384. )
  385. class TwistedEd448GoldilocksPoint(Decaf_1_1_Point):
  386. F = GF(2^448-2^224-1)
  387. d = F(-39082)
  388. a = F(-1)
  389. qnr = -1
  390. cofactor = 4
  391. encLen = 56
  392. isoMagic = IsoEd448Point.magic
  393. @classmethod
  394. def base(cls):
  395. return cls.decodeSpec(Ed448GoldilocksPoint.base().encodeSpec())
  396. class Ed448GoldilocksPoint(Decaf_1_1_Point):
  397. F = GF(2^448-2^224-1)
  398. d = F(-39081)
  399. a = F(1)
  400. qnr = -1
  401. cofactor = 4
  402. encLen = 56
  403. isoMagic = IsoEd448Point.magic
  404. @classmethod
  405. def base(cls):
  406. return 2*cls(
  407. 224580040295924300187604334099896036246789641632564134246125461686950415467406032909029192869357953282578032075146446173674602635247710, 298819210078481492676017930443930673437544040154080242095928241372331506189835876003536878655418784733982303233503462500531545062832660
  408. )
  409. class IsoEd25519Point(Decaf_1_1_Point):
  410. # TODO: twisted iso too!
  411. # TODO: twisted iso might have to IMAGINE_TWIST or whatever
  412. F = GF(2^255-19)
  413. d = F(-121665)
  414. a = F(1)
  415. i = sqrt(F(-1))
  416. qnr = i
  417. magic = isqrt(a*d-1)
  418. cofactor = 8
  419. encLen = 32
  420. isoMagic = Ed25519Point.magic
  421. isoA = Ed25519Point.a
  422. @classmethod
  423. def base(cls):
  424. return cls.decodeSpec(Ed25519Point.base().encode())
  425. class TestFailedException(Exception): pass
  426. def test(cls,n):
  427. print "Testing curve %s" % cls.__name__
  428. specials = [1]
  429. ii = cls.F(-1)
  430. while is_square(ii):
  431. specials.append(ii)
  432. ii = sqrt(ii)
  433. specials.append(ii)
  434. for i in specials:
  435. if negative(cls.F(i)): i = -i
  436. i = enc_le(i,cls.encLen)
  437. try:
  438. Q = cls.decode(i)
  439. QE = Q.encode()
  440. if QE != i:
  441. raise TestFailedException("Round trip special %s != %s" %
  442. (binascii.hexlify(QE),binascii.hexlify(i)))
  443. except NotOnCurveException: pass
  444. except InvalidEncodingException: pass
  445. P = cls.base()
  446. Q = cls()
  447. for i in xrange(n):
  448. #print binascii.hexlify(Q.encode())
  449. QE = Q.encode()
  450. QQ = cls.decode(QE)
  451. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  452. # Testing s -> 1/s: encodes -point on cofactor
  453. s = cls.bytesToGf(QE)
  454. if s != 0:
  455. ss = cls.gfToBytes(1/s,mustBePositive=True)
  456. try:
  457. QN = cls.decode(ss)
  458. if cls.cofactor == 8:
  459. raise TestFailedException("1/s shouldnt work for cofactor 8")
  460. if QN != -Q:
  461. raise TestFailedException("s -> 1/s should negate point for cofactor 4")
  462. except InvalidEncodingException as e:
  463. # Should be raised iff cofactor==8
  464. if cls.cofactor == 4:
  465. raise TestFailedException("s -> 1/s should work for cofactor 4")
  466. QT = Q
  467. for h in xrange(cls.cofactor):
  468. QT = QT.torque()
  469. if QT.encode() != QE:
  470. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  471. Q0 = Q + P
  472. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  473. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  474. r = randint(1,1000)
  475. Q1 = Q0*r
  476. Q2 = Q0*(r+1)
  477. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  478. Q = Q1
  479. test(Ed25519Point,100)
  480. test(NegEd25519Point,100)
  481. test(IsoEd25519Point,100)
  482. test(IsoEd448Point,100)
  483. test(TwistedEd448GoldilocksPoint,100)
  484. test(Ed448GoldilocksPoint,100)
  485. def testElligator(cls,n):
  486. print "Testing elligator on %s" % cls.__name__
  487. for i in xrange(n):
  488. cls.elligator(randombytes(cls.encLen))
  489. testElligator(Ed25519Point,100)
  490. testElligator(NegEd25519Point,100)
  491. testElligator(IsoEd448Point,100)
  492. testElligator(Ed448GoldilocksPoint,100)
  493. testElligator(TwistedEd448GoldilocksPoint,100)
  494. def gangtest(classes,n):
  495. print "Gang test",[cls.__name__ for cls in classes]
  496. specials = [1]
  497. ii = classes[0].F(-1)
  498. while is_square(ii):
  499. specials.append(ii)
  500. ii = sqrt(ii)
  501. specials.append(ii)
  502. for i in xrange(n):
  503. rets = [bytes((cls.base()*i).encode()) for cls in classes]
  504. if len(set(rets)) != 1:
  505. print "Divergence in encode at %d" % i
  506. for c,ret in zip(classes,rets):
  507. print c,binascii.hexlify(ret)
  508. print
  509. if i < len(specials): r0 = enc_le(specials[i],classes[0].encLen)
  510. else: r0 = randombytes(classes[0].encLen)
  511. rets = [bytes((cls.elligator(r0)*i).encode()) for cls in classes]
  512. if len(set(rets)) != 1:
  513. print "Divergence in elligator at %d" % i
  514. for c,ret in zip(classes,rets):
  515. print c,binascii.hexlify(ret)
  516. print
  517. gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100)
  518. gangtest([Ed25519Point,IsoEd25519Point],100)