You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

181 lines
5.4 KiB

  1. class InvalidEncodingException(Exception): pass
  2. class NotOnCurveException(Exception): pass
  3. def lobit(x): return int(x) & 1
  4. def hibit(x): return lobit(2*x)
  5. def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)])
  6. def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x))
  7. class EdwardsPoint(object):
  8. """Abstract class for point an an Edwards curve; needs F,a,d to work"""
  9. def __init__(self,x=0,y=1):
  10. x = self.x = self.F(x)
  11. y = self.y = self.F(y)
  12. if y^2 + self.a*x^2 != 1 + self.d*x^2*y^2:
  13. raise NotOnCurveException()
  14. def __repr__(self):
  15. return "%s(%d,%d)" % (self.__class__.__name__, self.x, self.y)
  16. def __iter__(self):
  17. yield self.x
  18. yield self.y
  19. def __add__(self,other):
  20. x,y = self
  21. X,Y = other
  22. a,d = self.a,self.d
  23. return self.__class__(
  24. (x*Y+y*X)/(1+d*x*y*X*Y),
  25. (y*Y-a*x*X)/(1-d*x*y*X*Y)
  26. )
  27. def __neg__(self): return self.__class__(-self.x,self.y)
  28. def __sub__(self,other): return self + (-other)
  29. def __rmul__(self,other): return self*other
  30. def __eq__(self,other): return tuple(self) == tuple(other)
  31. def __ne__(self,other): return not (self==other)
  32. def __mul__(self,exp):
  33. exp = int(exp)
  34. total = self.__class__()
  35. work = self
  36. while exp != 0:
  37. if exp & 1: total += work
  38. work += work
  39. exp >>= 1
  40. return total
  41. class Ed25519Point(EdwardsPoint):
  42. F = GF(2^255-19)
  43. d = F(-121665/121666)
  44. a = F(-1)
  45. i = sqrt(Ed25519Point.F(-1))
  46. @classmethod
  47. def base(cls):
  48. y = cls.F(4/5)
  49. x = sqrt((y^2-1)/(cls.d*y^2+1))
  50. if lobit(x): x = -x
  51. return cls(x,y)
  52. def torque(self):
  53. return self.__class__(self.y*self.i, self.x*self.i)
  54. class RistrettoPoint(Ed25519Point):
  55. """Like current decaf but tweaked for simplicity"""
  56. encLen = 32
  57. def __eq__(self,other):
  58. x,y = self
  59. X,Y = other
  60. return x*Y == X*y or x*X == y*Y
  61. @staticmethod
  62. def sqrt(x,negative=lobit,exn=InvalidEncodingException("Not on curve")):
  63. if not is_square(x): raise exn
  64. s = sqrt(x)
  65. if negative(s): s=-s
  66. return s
  67. def encode(self):
  68. x,y = self
  69. if lobit(x*y) or x==0: (x,y) = (self.i*y,self.i*x)
  70. if lobit(x): x,y = -x,-y
  71. s = self.sqrt((1-y)/(1+y),exn=Exception("Unimplemented: point is even"))
  72. return enc_le(s,self.encLen)
  73. @classmethod
  74. def decode(cls,s):
  75. if len(s) != cls.encLen:
  76. raise InvalidEncodingException("wrong length %d" % len(s))
  77. s = dec_le(s)
  78. if s < 0 or s >= cls.F.modulus() or lobit(s):
  79. raise InvalidEncodingException("%d out of range!" % s)
  80. s = cls.F(s)
  81. x = cls.sqrt(-4*s^2 / (cls.d*(s^2-1)^2 + (s^2+1)^2))
  82. y = (1-s^2) / (1+s^2)
  83. if lobit(x*y) or x==0:
  84. raise InvalidEncodingException("x*y has high bit")
  85. return cls(x,y)
  86. class DecafPoint(Ed25519Point):
  87. """Works like current decaf"""
  88. dMont = Ed25519Point.F(-121665)
  89. magic = sqrt(dMont-1)
  90. encLen = 32
  91. def __eq__(self,other):
  92. x,y = self
  93. X,Y = other
  94. return x*Y == X*y or x*X == y*Y
  95. def encode(self):
  96. x,y = self
  97. a,d = self.a,self.d
  98. if x*y == 0:
  99. # This will happen anyway with straightforward square root trick
  100. return enc_le(0,self.encLen)
  101. if not is_square((1-y)/(1+y)):
  102. raise Exception("Unimplemented: odd point in RistrettoPoint.encode")
  103. # Choose representative in 4-torsion group
  104. if hibit(self.magic/(x*y)): (x,y) = (self.i*y,self.i*x)
  105. if hibit(2*self.magic/x): x,y = -x,-y
  106. s = sqrt((1-y)/(1+y))
  107. if hibit(s): s = -s
  108. return enc_le(s,self.encLen)
  109. @classmethod
  110. def decode(cls,s):
  111. if len(s) != cls.encLen:
  112. raise InvalidEncodingException("wrong length %d" % len(s))
  113. s = dec_le(s)
  114. if s == 0: return cls(0,1)
  115. if s < 0 or s >= (cls.F.modulus()+1)/2:
  116. raise InvalidEncodingException("%d out of range!" % s)
  117. s = cls.F(s)
  118. if not is_square(s^4 + (2-4*cls.dMont)*s^2 + 1):
  119. raise InvalidEncodingException("Not on curve")
  120. t = sqrt(s^4 + (2-4*cls.dMont)*s^2 + 1)/s
  121. if hibit(t): t = -t
  122. y = (1-s^2)/(1+s^2)
  123. x = 2*cls.magic/t
  124. if y == 0 or lobit(t/y):
  125. raise InvalidEncodingException("t/y has high bit")
  126. return cls(x,y)
  127. class TestFailedException(Exception): pass
  128. def test(cls,n):
  129. # TODO: test corner cases like 0,1,i
  130. P = cls.base()
  131. Q = cls()
  132. for i in xrange(n):
  133. QQ = cls.decode(Q.encode())
  134. if QQ != Q: raise TestFailedException("Round trip %s != %s" % (str(QQ),str(Q)))
  135. if Q.encode() != Q.torque().encode():
  136. raise TestFailedException("Can't torque %s" % str(Q))
  137. Q0 = Q + P
  138. if Q0 == Q: raise TestFailedException("Addition doesn't work")
  139. if Q0-P != Q: raise TestFailedException("Subtraction doesn't work")
  140. r = randint(1,1000)
  141. Q1 = Q0*r
  142. Q2 = Q0*(r+1)
  143. if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work")
  144. Q = Q1