You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

353 lines
11 KiB

  1. import binascii
  2. class InvalidEncodingException(Exception): pass
  3. class NotOnCurveException(Exception): pass
  4. class SpecException(Exception): pass
  5. def lobit(x): return int(x) & 1
  6. def hibit(x): return lobit(2*x)
  7. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  8. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  9. def randombytes(n): return bytearray([randint(0,255) for _ in range(n)])
  10. def optimized_version_of(spec):
  11. def decorator(f):
  12. def wrapper(self,*args,**kwargs):
  13. try: spec_ans = getattr(self,spec,spec)(*args,**kwargs),None
  14. except Exception as e: spec_ans = None,e
  15. try: opt_ans = f(self,*args,**kwargs),None
  16. except Exception as e: opt_ans = None,e
  17. if spec_ans[1] is None and opt_ans[1] is not None:
  18. raise SpecException("Mismatch in %s: spec returned %s but opt threw %s"
  19. % (f.__name__,str(spec_ans[0]),str(opt_ans[1])))
  20. if spec_ans[1] is not None and opt_ans[1] is None:
  21. raise SpecException("Mismatch in %s: spec threw %s but opt returned %s"
  22. % (f.__name__,str(spec_ans[1]),str(opt_ans[0])))
  23. if spec_ans[0] != opt_ans[0]:
  24. raise SpecException("Mismatch in %s: %s != %s"
  25. % (f.__name__,str(spec_ans[0]),str(opt_ans[0])))
  26. if opt_ans[1] is not None: raise opt_ans[1]
  27. else: return opt_ans[0]
  28. wrapper.__name__ = f.__name__
  29. return wrapper
  30. return decorator
  31. def xsqrt(x,exn=InvalidEncodingException("Not on curve")):
  32. """Return sqrt(x)"""
  33. if not is_square(x): raise exn
  34. s = sqrt(x)
  35. if lobit(s): s=-s
  36. return s
  37. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  38. """Return 1/sqrt(x)"""
  39. if x==0: return 0
  40. if not is_square(x): raise exn
  41. return 1/sqrt(x)
  42. def isqrt_i(x):
  43. """Return 1/sqrt(x) or 1/sqrt(zeta * x)"""
  44. if x==0: return 0
  45. gen = x.parent(-1)
  46. while is_square(gen): gen = sqrt(gen)
  47. if is_square(x): return True,1/sqrt(x)
  48. else: return False,1/sqrt(x*gen)
  49. class EdwardsPoint(object):
  50. """Abstract class for point an an Edwards curve; needs F,a,d to work"""
  51. def __init__(self,x=0,y=1):
  52. x = self.x = self.F(x)
  53. y = self.y = self.F(y)
  54. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  55. raise NotOnCurveException(str(self))
  56. def __repr__(self):
  57. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  58. def __iter__(self):
  59. yield self.x
  60. yield self.y
  61. def __add__(self,other):
  62. x,y = self
  63. X,Y = other
  64. a,d = self.a,self.d
  65. return self.__class__(
  66. (x*Y+y*X)/(1+d*x*y*X*Y),
  67. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  68. )
  69. def __neg__(self): return self.__class__(-self.x,self.y)
  70. def __sub__(self,other): return self + (-other)
  71. def __rmul__(self,other): return self*other
  72. def __eq__(self,other): return tuple(self) == tuple(other)
  73. def __ne__(self,other): return not (self==other)
  74. def __mul__(self,exp):
  75. exp = int(exp)
  76. total = self.__class__()
  77. work = self
  78. while exp != 0:
  79. if exp & 1: total += work
  80. work += work
  81. exp >>= 1
  82. return total
  83. def xyzt(self):
  84. x,y = self
  85. z = self.F.random_element()
  86. return x*z,y*z,z,x*y*z
  87. def torque(self):
  88. """Apply cofactor group, except keeping the point even"""
  89. if self.cofactor == 8:
  90. return self.__class__(self.y*self.i, self.x*self.i)
  91. else:
  92. return self.__class__(-self.x, -self.y)
  93. class RistrettoPoint(EdwardsPoint):
  94. """Like current decaf but tweaked for simplicity"""
  95. def __eq__(self,other):
  96. x,y = self
  97. X,Y = other
  98. return x*Y == X*y or x*X == y*Y
  99. @classmethod
  100. def bytesToGf(cls,bytes,mustBeProper=True,mustBePositive=False):
  101. """Convert little-endian bytes to field element, sanity check length"""
  102. if len(bytes) != cls.encLen:
  103. raise InvalidEncodingException("wrong length %d" % len(bytes))
  104. s = dec_le(bytes)
  105. if mustBeProper and s >= cls.F.modulus():
  106. raise InvalidEncodingException("%d out of range!" % s)
  107. if mustBePositive and lobit(s):
  108. raise InvalidEncodingException("%d is negative!" % s)
  109. return cls.F(s)
  110. def encodeSpec(self):
  111. """Unoptimized specification for encoding"""
  112. x,y = self
  113. if self.cofactor==8 and (lobit(x*y) or y==0):
  114. (x,y) = (self.i*y,self.i*x)
  115. if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl
  116. if lobit(x): x,y = -x,-y
  117. s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self)))
  118. return enc_le(s,self.encLen)
  119. @classmethod
  120. def decodeSpec(cls,s):
  121. """Unoptimized specification for decoding"""
  122. s = cls.bytesToGf(s,mustBePositive=True)
  123. a,d = cls.a,cls.d
  124. x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2))
  125. y = (1+a*s^2) / (1-a*s^2)
  126. if cls.cofactor==8 and (lobit(x*y) or y==0):
  127. raise InvalidEncodingException("x*y has high bit")
  128. return cls(x,y)
  129. @optimized_version_of("encodeSpec")
  130. def encode(self):
  131. """Encode, optimized version"""
  132. a,d = self.a,self.d
  133. x,y,z,t = self.xyzt()
  134. u1 = a*(y+z)*(y-z)
  135. u2 = x*y # = t*z
  136. isr = isqrt(u1*u2^2)
  137. i1 = isr*u1
  138. i2 = isr*u2
  139. z_inv = i1*i2*t
  140. rotate = self.cofactor==8 and lobit(t*z_inv)
  141. if rotate:
  142. x,y = y*self.i,x*self.i
  143. den_inv = self.magic * i1
  144. else:
  145. den_inv = i2
  146. if lobit(x*z_inv): y = -y
  147. s = (z-y) * den_inv
  148. if lobit(s): s=-s
  149. return enc_le(s,self.encLen)
  150. @classmethod
  151. @optimized_version_of("decodeSpec")
  152. def decode(cls,s):
  153. """Decode, optimized version"""
  154. s = cls.bytesToGf(s,mustBePositive=True)
  155. a,d = cls.a,cls.d
  156. yden = 1-a*s^2
  157. ynum = 1+a*s^2
  158. yden_sqr = yden^2
  159. xden_sqr = a*d*ynum^2 - yden_sqr
  160. isr = isqrt(xden_sqr * yden_sqr)
  161. xden_inv = isr * yden
  162. yden_inv = xden_inv * isr * xden_sqr
  163. x = 2*s*xden_inv
  164. if lobit(x): x = -x
  165. y = ynum * yden_inv
  166. if cls.cofactor==8 and (lobit(x*y) or y==0):
  167. raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y))
  168. return cls(x,y)
  169. @classmethod
  170. def fromJacobiQuartic(cls,s,t,sgn=1):
  171. """Convert point from its Jacobi Quartic representation"""
  172. a,d = cls.a,cls.d
  173. assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2
  174. x = 2*s*cls.magic / t
  175. if lobit(x): x = -x # TODO: doesn't work without resolving x
  176. y = (1+a*s^2) / (1-a*s^2)
  177. return cls(sgn*x,y)
  178. @classmethod
  179. def elligatorSpec(cls,r0):
  180. a,d = cls.a,cls.d
  181. r = cls.qnr * cls.bytesToGf(r0)^2
  182. den = (d*r-a)*(a*r-d)
  183. n1 = cls.a*(r+1)*(a+d)*(d-a)/den
  184. n2 = r*n1
  185. if is_square(n1):
  186. sgn,s,t = 1,xsqrt(n1), -(r-1)*(a+d)^2 / den - 1
  187. else:
  188. sgn,s,t = -1,xsqrt(n2), r*(r-1)*(a+d)^2 / den - 1
  189. ret = cls.fromJacobiQuartic(s,t,sgn)
  190. return ret
  191. @classmethod
  192. @optimized_version_of("elligatorSpec")
  193. def elligator(cls,r0):
  194. a,d = cls.a,cls.d
  195. r0 = cls.bytesToGf(r0)
  196. r = cls.qnr * r0^2
  197. den = (d*r-a)*(a*r-d)
  198. num = cls.a*(r+1)*(a+d)*(d-a)
  199. iss,isri = isqrt_i(num*den)
  200. if iss: sgn,twiddle = 1,1
  201. else: sgn,twiddle = -1,r0*cls.qnr
  202. isri *= twiddle
  203. s = isri*num
  204. t = isri*s*(r-1)*(d+a)^2 + sgn
  205. return cls.fromJacobiQuartic(s,t,sgn)
  206. class Ed25519Point(RistrettoPoint):
  207. F = GF(2^255-19)
  208. d = F(-121665/121666)
  209. a = F(-1)
  210. i = sqrt(F(-1))
  211. qnr = i
  212. magic = isqrt(a*d-1)
  213. cofactor = 8
  214. encLen = 32
  215. @classmethod
  216. def base(cls):
  217. y = cls.F(4/5)
  218. x = sqrt((y^2-1)/(cls.d*y^2+1))
  219. if lobit(x): x = -x
  220. return cls(x,y)
  221. class TwistedEd448GoldilocksPoint(RistrettoPoint):
  222. F = GF(2^448-2^224-1)
  223. d = F(-39082)
  224. a = F(-1)
  225. qnr = -1
  226. magic = isqrt(a*d-1)
  227. cofactor = 4
  228. encLen = 56
  229. @classmethod
  230. def base(cls):
  231. y = cls.F(6) # TODO: no it isn't
  232. x = sqrt((y^2-1)/(cls.d*y^2+1))
  233. if lobit(x): x = -x
  234. return cls(x,y)
  235. class Ed448GoldilocksPoint(RistrettoPoint):
  236. # TODO: decaf vs ristretto
  237. F = GF(2^448-2^224-1)
  238. d = F(-39081)
  239. a = F(1)
  240. qnr = -1
  241. magic = isqrt(a*d-1)
  242. cofactor = 4
  243. encLen = 56
  244. @classmethod
  245. def base(cls):
  246. return cls(
  247. 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555,
  248. 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed
  249. )
  250. class IsoEd448Point(RistrettoPoint):
  251. F = GF(2^448-2^224-1)
  252. d = F(1/39081+1)
  253. a = F(1)
  254. qnr = -1
  255. magic = isqrt(a*d-1)
  256. cofactor = 4
  257. encLen = 56
  258. @classmethod
  259. def base(cls):
  260. # = ..., -3/2
  261. return cls.decodeSpec(bytearray(binascii.unhexlify(
  262. "00000000000000000000000000000000000000000000000000000000"+
  263. "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff")))
  264. class TestFailedException(Exception): pass
  265. def test(cls,n):
  266. # TODO: test corner cases like 0,1,i
  267. P = cls.base()
  268. Q = cls()
  269. for i in xrange(n):
  270. #print binascii.hexlify(Q.encode())
  271. QQ = cls.decode(Q.encode())
  272. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  273. QT = Q
  274. QE = Q.encode()
  275. for h in xrange(cls.cofactor):
  276. QT = QT.torque()
  277. if QT.encode() != QE:
  278. raise TestFailedException("Can't torque %s,%d" % (str(Q),h+1))
  279. Q0 = Q + P
  280. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  281. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  282. r = randint(1,1000)
  283. Q1 = Q0*r
  284. Q2 = Q0*(r+1)
  285. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  286. Q = Q1
  287. test(Ed25519Point,100)
  288. test(TwistedEd448GoldilocksPoint,100)
  289. test(Ed448GoldilocksPoint,100)
  290. test(IsoEd448Point,100)
  291. def testElligator(cls,n):
  292. for i in xrange(n):
  293. cls.elligator(randombytes(cls.encLen))
  294. testElligator(Ed25519Point,100)
  295. testElligator(Ed448GoldilocksPoint,100)
  296. testElligator(TwistedEd448GoldilocksPoint,100)
  297. testElligator(IsoEd448Point,100)