You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

262 lines
7.5 KiB

  class InvalidEncodingException(Exception):
      """Raised when data fails to decode as a valid point (wrong length,
      out-of-range value, or not on the curve)."""
  class NotOnCurveException(Exception):
      """Raised by EdwardsPoint when (x, y) does not satisfy the curve equation."""
  def lobit(x):
      """Return the least-significant bit (0 or 1) of int(x).

      The encoders in this file use it as the "sign" of a field element,
      negating values whose lobit is 1.
      """
      return int(x) % 2
  def hibit(x):
      """Return lobit(2*x).

      Assuming x is an element of GF(p) with p odd, this is 1 exactly when
      the representative of x lies above (p-1)/2 (doubling then wraps past p
      and becomes odd).
      """
      doubled = x + x
      return lobit(doubled)
  def enc_le(x,n):
      """Encode int(x) as n little-endian bytes.

      Only the low 8*n bits are kept; higher bits are silently dropped.
      Uses range (not xrange) so it runs on both Python 2 and Python 3 Sage.
      """
      value = int(x)
      return bytearray([(value >> (8*i)) & 0xFF for i in range(n)])
  def dec_le(x):
      """Decode a little-endian byte sequence into an integer.

      The first element is the least-significant byte; an empty sequence
      decodes to 0.
      """
      acc = 0
      for byte in reversed(list(x)):
          acc = (acc << 8) + byte
      return acc
  def isqrt(x,exn=None):
      """Return 1/sqrt(x), or 0 when x == 0.

      Which square root is inverted is whatever Sage's sqrt returns.

      Raises:
          exn if x is not a square; when exn is None (the default) a fresh
          InvalidEncodingException("Not on curve") is raised instead.
      """
      if x == 0:
          return 0
      if not is_square(x):
          if exn is None:
              # Build a new exception per call instead of re-raising one
              # shared default instance (whose traceback state would leak
              # between unrelated failures).
              exn = InvalidEncodingException("Not on curve")
          raise exn
      return 1/sqrt(x)
  class EdwardsPoint(object):
      """Abstract class for a point on an Edwards curve.

      The curve is y^2 + a*x^2 == 1 + d*x^2*y^2 over the field F; subclasses
      must define the class attributes F, a and d.
      """
      def __init__(self,x=0,y=1):
          # Defaults (0, 1) give the neutral element.
          x = self.x = self.F(x)
          y = self.y = self.F(y)
          if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
              raise NotOnCurveException()
      def __repr__(self):
          return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y)
      def __iter__(self):
          # Allows unpacking: x, y = point
          yield self.x
          yield self.y
      def __add__(self,other):
          """Affine Edwards addition law."""
          x,y = self
          X,Y = other
          a,d = self.a,self.d
          return self.__class__(
              (x*Y+y*X)/(1+d*x*y*X*Y),
              (y*Y-a*x*X)/(1-d*x*y*X*Y)
          )
      def __neg__(self): return self.__class__(-self.x,self.y)
      def __sub__(self,other): return self + (-other)
      def __rmul__(self,other): return self*other
      def __eq__(self,other): return tuple(self) == tuple(other)
      def __ne__(self,other): return not (self==other)
      def __mul__(self,exp):
          """Scalar multiplication by int(exp) (LSB-first double-and-add).

          A negative scalar multiplies the negated point by |exp|; previously
          it looped forever because -1 >> 1 == -1.
          """
          exp = int(exp)
          if exp < 0:
              work, exp = -self, -exp
          else:
              work = self
          total = self.__class__()
          while exp != 0:
              if exp & 1: total += work
              work += work
              exp >>= 1
          return total
      def xyzt(self):
          """Return a randomized projective representative (X, Y, Z, T).

          X = x*z, Y = y*z, T = x*y*z for a random nonzero z, so X*Y == Z*T.
          z == 0 would zero every coordinate, so it is re-drawn.
          """
          x,y = self
          z = self.F.random_element()
          while z == 0:
              z = self.F.random_element()
          return x*z,y*z,z,x*y*z
  class Ed25519Point(EdwardsPoint):
      """Point on the curve -x^2 + y^2 == 1 + d*x^2*y^2 over GF(2^255-19)."""
      F = GF(2^255-19)
      a = F(-1)
      d = F(-121665/121666)
      # A square root of -1 in F; torque() scales by it.
      i = sqrt(F(-1))
      @classmethod
      def base(cls):
          """Return the point with y = 4/5 whose x coordinate has lobit 0."""
          y = cls.F(4/5)
          x = sqrt((y^2-1)/(cls.d*y^2+1))
          return cls(-x if lobit(x) else x, y)
      def torque(self):
          """Return the point (i*y, i*x): coordinates swapped and scaled by i."""
          x, y = self
          return self.__class__(y*self.i, x*self.i)
  class RistrettoPoint(Ed25519Point):
      """Ristretto-style encoding of Ed25519 points.

      Described by the author as "like current decaf but tweaked for
      simplicity". Encodings are encLen little-endian bytes of a field
      element s with lobit(s) == 0.
      """
      encLen = 32
      def __eq__(self,other):
          # Deliberately coarser than EdwardsPoint.__eq__: points that encode
          # identically (e.g. P and P.torque()) compare equal.
          x,y = self
          X,Y = other
          return x*Y == X*y or x*X == y*Y
      @staticmethod
      def sqrt(x,negative=lobit,exn=None):
          """Return a square root s of x, negated if negative(s) is true.

          With the default negative=lobit this is the even square root.

          Raises:
              exn if x is not a square; when exn is None (the default) a
              fresh InvalidEncodingException("Not on curve") is raised.
          """
          if not is_square(x):
              if exn is None:
                  # Fresh instance per call rather than one shared default.
                  exn = InvalidEncodingException("Not on curve")
              raise exn
          s = sqrt(x)  # module-level sqrt, not this staticmethod
          if negative(s): s=-s
          return s
      def encode(self):
          """Encode this point as encLen little-endian bytes."""
          x,y = self
          # Rotate (as torque() does) so that lobit(x*y) == 0 and x != 0.
          if lobit(x*y) or x==0: (x,y) = (self.i*y,self.i*x)
          if lobit(x): x,y = -x,-y
          s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even"))
          return enc_le(s,self.encLen)
      @classmethod
      def decode(cls,s):
          """Decode encLen bytes into a point.

          Raises InvalidEncodingException on wrong length, an out-of-range or
          odd s, a non-square x^2, or a decoded point that fails the x*y
          check.
          """
          if len(s) != cls.encLen:
              raise InvalidEncodingException("wrong length %d" % len(s))
          s = dec_le(s)
          if s < 0 or s >= cls.F.modulus() or lobit(s):
              raise InvalidEncodingException("%d out of range!" % s)
          s = cls.F(s)
          x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2))
          y = (1-s^2) / (1+s^2)
          if lobit(x*y) or x==0:
              raise InvalidEncodingException("x*y has high bit")
          return cls(x,y)
  class OptimizedRistrettoPoint(RistrettoPoint):
      """Like Ristretto but uses isqrt instead.

      encode() and decode() compute the RistrettoPoint encoding with a single
      call to the module-level isqrt. They also run the RistrettoPoint
      versions and assert that both agree, so this class doubles as a
      cross-check.
      """
      # 1/sqrt(d+1); used by encode() on the rotated branch.
      magic = isqrt(RistrettoPoint.d+1)
      # Module-level isqrt captured here: isqrt_and_inv's first parameter is
      # also named isqrt and would otherwise shadow the function.
      _isqrt = staticmethod(isqrt)
      @classmethod
      def isqrt_and_inv(cls,isqrt,inv,*args,**kwargs):
          """Return (an inverse square root of isqrt, 1/inv) with one isqrt call.

          With s = 1/sqrt(isqrt*inv^2): s*inv = 1/sqrt(isqrt) and
          s^2*isqrt*inv = 1/inv. Extra arguments are forwarded to the
          module-level isqrt (e.g. exn).
          """
          s = cls._isqrt(isqrt*inv^2, *args, **kwargs)
          return s*inv, s^2*isqrt*inv
      def encode(self):
          """Encode as encLen bytes; asserts equality with RistrettoPoint.encode."""
          right_answer = super(OptimizedRistrettoPoint,self).encode()
          x,y,z,t = self.xyzt()
          x *= self.i
          u1 = (z+y)*(z-y)
          u2 = x*y # = i*t*z, since x was just multiplied by i
          isr = isqrt(u1 * u2^2)
          i1 = isr*u1
          i2 = isr*u2
          z_inv = i1*i2*t
          rotate = lobit(t*self.i*z_inv)
          if rotate:
              x,y = y,x
              den_inv = self.magic * i1
          else:
              den_inv = i2
          if rotate ^^ lobit(x*z_inv): y = -y
          s = (z-y) * den_inv
          # s == 0 (e.g. the identity, where y == z): the reference encoder
          # yields the even root of 1, i.e. -1, which the lobit step produces.
          if s==0: s = self.F(1)
          if lobit(s): s=-s
          ret = enc_le(s,self.encLen)
          assert ret == right_answer
          return ret
      @classmethod
      def decode(cls,s):
          """Decode encLen bytes; asserts equality with RistrettoPoint.decode.

          Invalid input normally raises InvalidEncodingException from the
          reference decode, which runs first.
          """
          # super(type, cls): the class whose MRO is searched comes first, so
          # this also works when cls is a subclass.
          right_answer = super(OptimizedRistrettoPoint,cls).decode(s)
          # Sanity check s
          if len(s) != cls.encLen:
              raise InvalidEncodingException("wrong length %d" % len(s))
          s = dec_le(s)
          if s < 0 or s >= cls.F.modulus() or lobit(s):
              raise InvalidEncodingException("%d out of range!" % s)
          s = cls.F(s)
          yden = 1+s^2
          ynum = 1-s^2
          yden_sqr = yden^2
          xden_sqr = -cls.d*ynum^2 - yden_sqr
          # One isqrt yields both 1/sqrt(xden_sqr) (times yden) and 1/yden.
          isr = isqrt(xden_sqr * yden_sqr)
          xden_inv = isr * yden
          yden_inv = xden_inv * isr * xden_sqr
          x = 2*s*xden_inv
          if lobit(x): x = -x
          y = ynum * yden_inv
          if lobit(x*y) or x==0:
              raise InvalidEncodingException("x*y has high bit")
          ret = cls(x,y)
          assert ret == right_answer
          return ret
  class DecafPoint(Ed25519Point):
      """Decaf-style encoding of Ed25519 points ("works like current decaf").

      Equality uses the same coarse comparison as RistrettoPoint.__eq__.
      Encodings are encLen little-endian bytes of s with hibit(s) == 0.
      """
      dMont = Ed25519Point.F(-121665)
      magic = sqrt(dMont-1)
      encLen = 32
      def __eq__(self,other):
          x,y = self
          X,Y = other
          return x*Y == X*y or x*X == y*Y
      def encode(self):
          """Encode this point as encLen little-endian bytes.

          Points with x*y == 0 encode as all zeros. Raises a plain Exception
          when (1-y)/(1+y) is not a square (that case is unimplemented).
          """
          x,y = self
          if x*y == 0:
              # This will happen anyway with straightforward square root trick
              return enc_le(0,self.encLen)
          if not is_square((1-y)/(1+y)):
              raise Exception("Unimplemented: odd point in DecafPoint.encode")
          # Choose representative in 4-torsion group
          if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x)
          if hibit(2*self.magic/x): x,y = -x,-y
          s = sqrt((1-y)/(1+y))
          if hibit(s): s = -s
          return enc_le(s,self.encLen)
      @classmethod
      def decode(cls,s):
          """Decode encLen bytes into a point.

          An all-zero encoding decodes to (0, 1). Raises
          InvalidEncodingException on wrong length, s >= (p+1)/2, a non-square
          discriminant, or a failed sign check.
          """
          if len(s) != cls.encLen:
              raise InvalidEncodingException("wrong length %d" % len(s))
          s = dec_le(s)
          if s == 0: return cls(0,1)
          if s < 0 or s >= (cls.F.modulus()+1)/2:
              raise InvalidEncodingException("%d out of range!" % s)
          s = cls.F(s)
          disc = s^4 + (2-4*cls.dMont)*s^2 + 1
          if not is_square(disc):
              raise InvalidEncodingException("Not on curve")
          t = sqrt(disc)/s  # s != 0 here: zero returned early above
          if hibit(t): t = -t
          y = (1-s^2)/(1+s^2)
          x = 2*cls.magic/t
          # Since x = 2*magic/t, lobit(t/y) == hibit(magic/(x*y)): this
          # mirrors the encoder's rotation rule.
          if y == 0 or lobit(t/y):
              raise InvalidEncodingException("t/y has high bit")
          return cls(x,y)
  class TestFailedException(Exception):
      """Raised by test() when a round-trip, torque, or group-law check fails."""
  192. def test(cls,n):
  193. # TODO: test corner cases like 0,1,i
  194. P = cls.base()
  195. Q = cls()
  196. for i in xrange(n):
  197. QQ = cls.decode(Q.encode())
  198. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  199. if Q.encode() != Q.torque().encode():
  200. raise TestFailedException("Can't torque %s" % str(Q))
  201. Q0 = Q + P
  202. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  203. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  204. r = randint(1,1000)
  205. Q1 = Q0*r
  206. Q2 = Q0*(r+1)
  207. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  208. Q = Q1