You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

193 lines
5.7 KiB

  1. class InvalidEncodingException(Exception): pass
  2. class NotOnCurveException(Exception): pass
  3. def lobit(x): return int(x) & 1
  4. def hibit(x): return lobit(2*x)
  5. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  6. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  7. class EdwardsPoint(object):
  8. """Abstract class for point an an Edwards curve; needs F,a,d to work"""
  9. def __init__(self,x=0,y=1):
  10. x = self.x = self.F(x)
  11. y = self.y = self.F(y)
  12. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  13. raise NotOnCurveException()
  14. def __repr__(self):
  15. return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y)
  16. def __iter__(self):
  17. yield self.x
  18. yield self.y
  19. def __add__(self,other):
  20. x,y = self
  21. X,Y = other
  22. a,d = self.a,self.d
  23. return self.__class__(
  24. (x*Y+y*X)/(1+d*x*y*X*Y),
  25. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  26. )
  27. def __neg__(self): return self.__class__(-self.x,self.y)
  28. def __sub__(self,other): return self + (-other)
  29. def __rmul__(self,other): return self*other
  30. def __eq__(self,other): return tuple(self) == tuple(other)
  31. def __ne__(self,other): return not (self==other)
  32. def __mul__(self,exp):
  33. exp = int(exp)
  34. total = self.__class__()
  35. work = self
  36. while exp != 0:
  37. if exp & 1: total += work
  38. work += work
  39. exp >>= 1
  40. return total
  41. class Ed25519Point(EdwardsPoint):
  42. F = GF(2^255-19)
  43. d = F(-121665/121666)
  44. a = F(-1)
  45. i = sqrt(Ed25519Point.F(-1))
  46. @classmethod
  47. def base(cls):
  48. y = cls.F(4/5)
  49. x = sqrt((y^2-1)/(cls.d*y^2+1))
  50. if lobit(x): x = -x
  51. return cls(x,y)
  52. def torque(self):
  53. return self.__class__(self.y*self.i, self.x*self.i)
  54. class RistrettoOption1Point(Ed25519Point):
  55. """Like current decaf but tweaked for simplicity"""
  56. dMont = Ed25519Point.F(-121665)
  57. encLen = 32
  58. def __eq__(self,other):
  59. x,y = self
  60. X,Y = other
  61. return x*Y == X*y or x*X == y*Y
  62. def encode(self):
  63. x,y = self
  64. a,d = self.a,self.d
  65. if x*y == 0:
  66. # This happens anyway with straightforward impl
  67. return enc_le(0,self.encLen)
  68. if not is_square((1-y)/(1+y)):
  69. raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
  70. # Choose representative in 4-torsion group
  71. if lobit(x*y): (x,y) = (self.i*y,self.i*x)
  72. if lobit(x): x,y = -x,-y
  73. s = sqrt((1-y)/(1+y))
  74. if lobit(s): s = -s
  75. return enc_le(s,self.encLen)
  76. @classmethod
  77. def decode(cls,s):
  78. if len(s) != cls.encLen:
  79. raise InvalidEncodingException("wrong length %d" % len(s))
  80. s = dec_le(s)
  81. if s == 0: return cls(0,1)
  82. if s < 0 or s >= cls.F.modulus() or lobit(s):
  83. raise InvalidEncodingException("%d out of range!" % s)
  84. s = cls.F(s)
  85. magic = 4*cls.dMont-4
  86. if not is_square(magic*s^2 / ((s^2-1)^2 - s^2 * magic)):
  87. raise InvalidEncodingException("Not on curve")
  88. x = sqrt(magic*s^2 / ((s^2-1)^2 - magic * s^2))
  89. if lobit(x): x=-x
  90. y = (1-s^2)/(1+s^2)
  91. if lobit(x*y):
  92. raise InvalidEncodingException("x*y has high bit")
  93. return cls(x,y)
  94. class RistrettoOption2Point(Ed25519Point):
  95. """Works like current decaf"""
  96. dMont = Ed25519Point.F(-121665)
  97. magic = sqrt(dMont-1)
  98. encLen = 32
  99. def __eq__(self,other):
  100. x,y = self
  101. X,Y = other
  102. return x*Y == X*y or x*X == y*Y
  103. def encode(self):
  104. x,y = self
  105. a,d = self.a,self.d
  106. if x*y == 0:
  107. # This will happen anyway with straightforward square root trick
  108. return enc_le(0,self.encLen)
  109. if not is_square((1-y)/(1+y)):
  110. raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
  111. # Choose representative in 4-torsion group
  112. if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x)
  113. if hibit(2*self.magic/x): x,y = -x,-y
  114. s = sqrt((1-y)/(1+y))
  115. if hibit(s): s = -s
  116. return enc_le(s,self.encLen)
  117. @classmethod
  118. def decode(cls,s):
  119. if len(s) != cls.encLen:
  120. raise InvalidEncodingException("wrong length %d" % len(s))
  121. s = dec_le(s)
  122. if s == 0: return cls(0,1)
  123. if s < 0 or s >= (cls.F.modulus()+1)/2:
  124. raise InvalidEncodingException("%d out of range!" % s)
  125. s = cls.F(s)
  126. if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1):
  127. raise InvalidEncodingException("Not on curve")
  128. t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s
  129. if hibit(t): t = -t
  130. y = (1-s^2)/(1+s^2)
  131. x = 2*cls.magic/t
  132. if y == 0 or lobit(t/y):
  133. raise InvalidEncodingException("t/y has high bit")
  134. return cls(x,y)
  135. class TestFailedException(Exception): pass
  136. def test(cls,n):
  137. # TODO: test corner cases like 0,1,i
  138. P = cls.base()
  139. Q = cls()
  140. for i in xrange(n):
  141. QQ = cls.decode(Q.encode())
  142. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  143. if Q.encode() != Q.torque().encode():
  144. raise TestFailedException("Can't torque %s" % str(Q))
  145. Q0 = Q + P
  146. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  147. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  148. r = randint(1,1000)
  149. Q1 = Q0*r
  150. Q2 = Q0*(r+1)
  151. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  152. Q = Q1