You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

224 lines
6.3 KiB

  1. class InvalidEncodingException(Exception): pass
  2. class NotOnCurveException(Exception): pass
  3. def lobit(x): return int(x) & 1
  4. def hibit(x): return lobit(2*x)
  5. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  6. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  7. def isqrt(x,exn=InvalidEncodingException("Not on curve")):
  8. """Return 1/sqrt(x)"""
  9. if x==0: return 0
  10. if not is_square(x): raise exn
  11. return 1/sqrt(x)
  12. class EdwardsPoint(object):
  13. """Abstract class for point an an Edwards curve; needs F,a,d to work"""
  14. def __init__(self,x=0,y=1):
  15. x = self.x = self.F(x)
  16. y = self.y = self.F(y)
  17. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  18. raise NotOnCurveException()
  19. def __repr__(self):
  20. return "%s(0x%x,0x%x)" % (self.__class__.__name__, self.x, self.y)
  21. def __iter__(self):
  22. yield self.x
  23. yield self.y
  24. def __add__(self,other):
  25. x,y = self
  26. X,Y = other
  27. a,d = self.a,self.d
  28. return self.__class__(
  29. (x*Y+y*X)/(1+d*x*y*X*Y),
  30. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  31. )
  32. def __neg__(self): return self.__class__(-self.x,self.y)
  33. def __sub__(self,other): return self + (-other)
  34. def __rmul__(self,other): return self*other
  35. def __eq__(self,other): return tuple(self) == tuple(other)
  36. def __ne__(self,other): return not (self==other)
  37. def __mul__(self,exp):
  38. exp = int(exp)
  39. total = self.__class__()
  40. work = self
  41. while exp != 0:
  42. if exp & 1: total += work
  43. work += work
  44. exp >>= 1
  45. return total
  46. def xyzt(self):
  47. x,y = self
  48. z = self.F.random_element()
  49. return x*z,y*z,z,x*y*z
class RistrettoPoint(EdwardsPoint):
    """Like current decaf but tweaked for simplicity.

    Points are compared and encoded as equivalence classes rather than as
    exact coordinates (see __eq__, encode and torque). Besides F, a and d
    (needed by EdwardsPoint), subclasses must define:

      cofactor -- the code branches on the values 8 and 4;
      encLen   -- length in bytes of an encoding;
      i        -- read only when cofactor == 8 (Ed25519Point sets it to
                  sqrt(F(-1))).

    F must also provide modulus(), which decodeSpec/decode use for range checks.
    encode()/decode() are optimised versions that assert agreement with the
    unoptimised encodeSpec()/decodeSpec().
    """
    def __eq__(self,other):
        # Coarser than EdwardsPoint.__eq__: points are equal when
        # x1*y2 == x2*y1 or x1*x2 == y1*y2. For example, a point and its
        # torque() compare equal under this test.
        x,y = self
        X,Y = other
        return x*Y == X*y or x*X == y*Y
    @staticmethod
    def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")):
        """Return a square root s of x, negated when negative(s) is truthy.

        With the default negative=lobit the result has low bit 0 (for a
        nonzero root in an odd prime field, s and -s differ in parity).
        Raises exn if x is not a square. Note the default exn is one instance
        created when the class is defined and re-raised on every failure.
        """
        if not is_square(x): raise exn
        s = sqrt(x)  # the global (Sage) sqrt, not this staticmethod
        if negative(s): s=-s
        return s
    def encodeSpec(self):
        """Unoptimized specification for encoding.

        Picks a representative of the point's class, then encodes the
        even square root s of (1-y)/(1+y) as encLen little-endian bytes.
        """
        x,y = self
        # cofactor 8: if x*y is odd or x == 0, switch to (i*y, i*x)
        # (the same map torque() applies).
        if self.cofactor==8 and (lobit(x*y) or x==0):
            (x,y) = (self.i*y,self.i*x)
        elif self.cofactor==4 and y==-1:
            y = 1 # Doesn't affect impl
        if lobit(x): y=-y
        # A non-square here raises a plain Exception, not InvalidEncodingException.
        s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even"))
        return enc_le(s,self.encLen)
    @classmethod
    def decodeSpec(cls,s):
        """Unoptimized specification for decoding.

        s must be encLen bytes that decode (little-endian) to an even integer
        below F.modulus(). The point is
        y = (1-s^2)/(1+s^2), and x = the even square root of
        -4*s^2 / (d*(s^2-1)^2 + (s^2+1)^2).
        Raises InvalidEncodingException for a wrong length, an out-of-range or
        odd value, a non-square x^2, or (cofactor 8) an odd x*y or x == 0.
        """
        if len(s) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(s))
        s = dec_le(s)
        if s < 0 or s >= cls.F.modulus() or lobit(s):
            raise InvalidEncodingException("%d out of range!" % s)
        s = cls.F(s)
        x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2))
        y = (1-s^2) / (1+s^2)
        if cls.cofactor==8 and (lobit(x*y) or x==0):
            raise InvalidEncodingException("x*y has high bit")
        return cls(x,y)
    def encode(self):
        """Optimised encoding; returns encLen little-endian bytes.

        Works on randomised projective coordinates from xyzt() (presumably so
        the formulas are exercised independently of Z) and asserts that the
        result equals encodeSpec().
        """
        x,y,z,t = self.xyzt()
        u1 = (z+y)*(z-y)
        u2 = x*y # = t*z
        # One inverse square root shared by the inversions below.
        isr = isqrt(u1 * u2^2)
        i1 = isr*u1
        i2 = isr*u2
        # i1*i2*t == isr^2*u1*u2*t == t/u2 == 1/z (when isr is nonzero).
        z_inv = i1*i2*t
        # t*z_inv is the affine x*y, i.e. encodeSpec's lobit(x*y) test.
        rotate = self.cofactor==8 and lobit(t*z_inv)
        if rotate:
            magic = isqrt(-self.d-1)
            x,y = y*self.i,x*self.i
            den_inv = magic * i1
        else:
            den_inv = i2
        if lobit(x*z_inv): y = -y
        s = (z-y) * den_inv
        # When affine x == 0, t and u2 are 0, so isqrt returns 0 and s is 0.
        # Presumably this fix-up reproduces encodeSpec's value for those
        # points; the assert below checks it.
        if self.cofactor==8 and s==0: s += 1
        if lobit(s): s=-s
        ret = enc_le(s,self.encLen)
        assert ret == self.encodeSpec()
        return ret
    @classmethod
    def decode(cls,s):
        """Optimised decoding; asserts agreement with decodeSpec().

        decodeSpec runs first, so it is what raises InvalidEncodingException
        for bad input. The checks below repeat its length/range conditions.
        """
        right_answer = cls.decodeSpec(s)
        # Sanity check s
        if len(s) != cls.encLen:
            raise InvalidEncodingException("wrong length %d" % len(s))
        s = dec_le(s)
        if s < 0 or s >= cls.F.modulus() or lobit(s):
            raise InvalidEncodingException("%d out of range!" % s)
        s = cls.F(s)
        yden = 1+s^2
        ynum = 1-s^2
        yden_sqr = yden^2
        xden_sqr = -cls.d*ynum^2 - yden_sqr
        # One inverse square root yields both 1/sqrt(xden_sqr) (up to sign)
        # and, via yden_inv == isr^2*yden*xden_sqr == 1/yden, the y denominator.
        isr = isqrt(xden_sqr * yden_sqr)
        xden_inv = isr * yden
        yden_inv = xden_inv * isr * xden_sqr
        x = 2*s*xden_inv
        if lobit(x): x = -x
        y = ynum * yden_inv
        if cls.cofactor==8 and (lobit(x*y) or x==0):
            raise InvalidEncodingException("x*y has high bit")
        ret = cls(x,y)
        assert ret == right_answer
        return ret
    def torque(self):
        """Return a different point in the same class as self.

        test() checks that it encodes identically. For cofactor 8 the result
        is (y*i, x*i); otherwise it is (-x, -y).
        """
        if self.cofactor == 8:
            return self.__class__(self.y*self.i, self.x*self.i)
        else:
            return self.__class__(-self.x, -self.y)
  138. class Ed25519Point(RistrettoPoint):
  139. F = GF(2^255-19)
  140. d = F(-121665/121666)
  141. a = F(-1)
  142. i = sqrt(F(-1))
  143. cofactor = 8
  144. encLen = 32
  145. @classmethod
  146. def base(cls):
  147. y = cls.F(4/5)
  148. x = sqrt((y^2-1)/(cls.d*y^2+1))
  149. if lobit(x): x = -x
  150. return cls(x,y)
  151. class Ed448Point(RistrettoPoint):
  152. F = GF(2^448-2^224-1)
  153. d = F(-39082)
  154. a = F(-1)
  155. cofactor = 4
  156. encLen = 56
  157. @classmethod
  158. def base(cls):
  159. y = cls.F(6) # FIXME: no it isn't
  160. x = sqrt((y^2-1)/(cls.d*y^2+1))
  161. if lobit(x): x = -x
  162. return cls(x,y)
  163. class TestFailedException(Exception): pass
  164. def test(cls,n):
  165. # TODO: test corner cases like 0,1,i
  166. P = cls.base()
  167. Q = cls()
  168. for i in xrange(n):
  169. QQ = cls.decode(Q.encode())
  170. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  171. if Q.encode() != Q.torque().encode():
  172. raise TestFailedException("Can't torque %s" % str(Q))
  173. Q0 = Q + P
  174. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  175. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  176. r = randint(1,1000)
  177. Q1 = Q0*r
  178. Q2 = Q0*(r+1)
  179. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  180. Q = Q1