A Pure Python implementation of Shamir's Secret Sharing
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

312 lines
8.7 KiB

  1. # Copyright 2023 John-Mark Gurney.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # 1. Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. #
  12. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  14. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  15. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  16. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  17. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  18. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  19. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  20. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  21. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  22. # SUCH DAMAGE.
  23. #
  24. #
  25. # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m'
  26. #
  27. import functools
  28. import operator
  29. import secrets
  30. import unittest.mock
  31. random = secrets.SystemRandom()
  32. __all__ = [
  33. 'create_shares',
  34. 'recover_data',
  35. 'GF2p8',
  36. ]
  37. def _makered(x, y):
  38. '''Make reduction table entry.
  39. given x * 2^8, reduce it assuming polynomial y.
  40. '''
  41. x = x << 8
  42. for i in range(3, -1, -1):
  43. if x & (1 << (i + 8)):
  44. x ^= (0x100 + y) << i
  45. assert x < 256
  46. return x
  47. def evalpoly(polynomial, powers):
  48. return sum(( x * y for x, y in zip(polynomial, powers,
  49. strict=True)), 0)
  50. def create_shares(data, k, nshares):
  51. '''Given data, create nshares, such that given any k shares,
  52. data can be recovered.
  53. data must be bytes, or able to be converted to bytes.
  54. The return value will be a list of length nshares. Each element
  55. will be a tuple of (<int in range [1, nshares + 1)>, <bytes>).'''
  56. data = bytes(data)
  57. #print(repr(data), repr(k), repr(nshares))
  58. powers = (None, ) + tuple(GF2p8(x).powerseries(k - 1) for x in
  59. range(1, nshares + 1))
  60. coeffs = [ [ x ] + [ random.randint(0, 255) for y in
  61. range(k - 1) ] for idx, x in enumerate(data) ]
  62. return [ (x, bytes([ int(evalpoly(coeffs[idx],
  63. powers[x])) for idx, val in enumerate(data) ])) for x in
  64. range(1, nshares + 1) ]
  65. def recover_data(shares, k):
  66. '''Recover the value given shares, where k is needed.
  67. shares must be as least length of k.'''
  68. if len(shares) < k:
  69. raise ValueError('not enough shares to recover')
  70. return bytes([ int(sum([ GF2p8(y[idx]) *
  71. functools.reduce(operator.mul, [ pix * ((GF2p8(pix) - x) ** -1) for
  72. pix, piy in shares[:k] if pix != x ], 1) for x, y in shares[:k] ],
  73. 0)) for idx in range(len(shares[0][1]))])
  74. class GF2p8:
  75. _invcache = (None, 1, 195, 130, 162, 126, 65, 90, 81, 54, 63, 172, 227, 104, 45, 42, 235, 155, 27, 53, 220, 30, 86, 165, 178, 116, 52, 18, 213, 100, 21, 221, 182, 75, 142, 251, 206, 233, 217, 161, 110, 219, 15, 44, 43, 14, 145, 241, 89, 215, 58, 244, 26, 19, 9, 80, 169, 99, 50, 245, 201, 204, 173, 10, 91, 6, 230, 247, 71, 191, 190, 68, 103, 123, 183, 33, 175, 83, 147, 255, 55, 8, 174, 77, 196, 209, 22, 164, 214, 48, 7, 64, 139, 157, 187, 140, 239, 129, 168, 57, 29, 212, 122, 72, 13, 226, 202, 176, 199, 222, 40, 218, 151, 210, 242, 132, 25, 179, 185, 135, 167, 228, 102, 73, 149, 153, 5, 163, 238, 97, 3, 194, 115, 243, 184, 119, 224, 248, 156, 92, 95, 186, 34, 250, 240, 46, 254, 78, 152, 124, 211, 112, 148, 125, 234, 17, 138, 93, 188, 236, 216, 39, 4, 127, 87, 23, 229, 120, 98, 56, 171, 170, 11, 62, 82, 76, 107, 203, 24, 117, 192, 253, 32, 74, 134, 118, 141, 94, 158, 237, 70, 69, 180, 252, 131, 2, 84, 208, 223, 108, 205, 60, 106, 177, 61, 200, 36, 232, 197, 85, 113, 150, 101, 28, 88, 49, 160, 38, 111, 41, 20, 31, 109, 198, 136, 249, 105, 12, 121, 166, 66, 246, 207, 37, 154, 16, 159, 189, 128, 96, 144, 47, 114, 133, 51, 59, 231, 67, 137, 225, 143, 35, 193, 181, 146, 79)
  76. @staticmethod
  77. def _primativemul(a, b):
  78. masks = [ 0, 0xff ]
  79. r = 0
  80. for i in range(0, 8):
  81. mask = a & 1
  82. r ^= (masks[mask] & b) << i
  83. a = a >> 1
  84. return r
  85. # polynomial 0x187
  86. _reduce = tuple(_makered(x, 0x87) for x in range(0, 16))
  87. def __init__(self, v):
  88. if v >= 256:
  89. raise ValueError('%d is not a member of GF(2^8)' % v)
  90. self._v = v
  91. # basic operations
  92. def __add__(self, o):
  93. if isinstance(o, int):
  94. return self + self.__class__(o)
  95. return self.__class__(self._v ^ o._v)
  96. def __radd__(self, o):
  97. return self.__add__(o)
  98. def __sub__(self, o):
  99. return self.__add__(o)
  100. def __rsub__(self, o):
  101. return self.__sub__(o)
  102. def __mul__(self, o):
  103. if isinstance(o, int):
  104. o = self.__class__(o)
  105. m = o._v
  106. # multiply
  107. r = self._primativemul(self._v, m)
  108. # reduce
  109. r ^= self._reduce[r >> 12] << 4
  110. r ^= self._reduce[(r >> 8) & 0xf ]
  111. r &= 0xff
  112. return self.__class__(r)
  113. def __rmul__(self, o):
  114. return self.__mul__(o)
  115. def __pow__(self, x):
  116. if x == -1 and self._invcache:
  117. return GF2p8(self._invcache[self._v])
  118. if x < 0:
  119. x += 255
  120. v = self.__class__(1)
  121. # TODO - make faster via caching and squaring
  122. for i in range(x):
  123. v *= self
  124. return v
  125. def powerseries(self, cnt):
  126. '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].'''
  127. r = [ self.__class__(1) ]
  128. for i in range(1, cnt + 1):
  129. r.append(r[-1] * self)
  130. return r
  131. def __eq__(self, o):
  132. if isinstance(o, int):
  133. return self._v == o
  134. return self._v == o._v
  135. def __int__(self):
  136. return self._v
  137. def __repr__(self):
  138. return '%s(%d)' % (self.__class__.__name__, self._v)
  139. class TestShamirSS(unittest.TestCase):
  140. def test_evalpoly(self):
  141. a = GF2p8(random.randint(0, 255))
  142. powers = a.powerseries(4)
  143. self.assertTrue(all(isinstance(x, GF2p8) for x in powers))
  144. vals = [ GF2p8(random.randint(0, 255)) for x in range(5) ]
  145. r = evalpoly(vals, powers)
  146. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  147. powers[2] + vals[3] * powers[3] + vals[4] * powers[4])
  148. r = evalpoly(vals[:3], powers[:3])
  149. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  150. powers[2])
  151. self.assertRaises(ValueError, evalpoly, [1], [1, 2])
  152. def test_create_shares(self):
  153. self.assertRaises(TypeError, create_shares, '', 1, 1)
  154. val = bytes([ random.randint(0, 255) for x in range(100) ])
  155. #val = b'this is a test of english text.'
  156. #val = b'1234'
  157. a = create_shares(val, 2, 3)
  158. self.assertNotIn(val, set(x[1] for x in a))
  159. # that it has the number of shares
  160. self.assertEqual(len(a), 3)
  161. # that the length of the share data matches passed in data
  162. self.assertEqual(len(a[0][1]), len(val))
  163. # that one share isn't enough
  164. self.assertRaises(ValueError, recover_data, [ a[0] ], 2)
  165. self.assertEqual(val, recover_data(a[:2], 2))
  166. def test_gf2p8_inv(self):
  167. a = GF2p8(random.randint(0, 255))
  168. with unittest.mock.patch.object(GF2p8, '_invcache', []) as pinvc:
  169. ainv = a ** -1
  170. self.assertEqual(a * ainv, 1)
  171. invcache = (None, ) + \
  172. tuple(int(GF2p8(x) ** -1) for x in range(1, 256))
  173. if GF2p8._invcache != invcache: # pragma: no cover
  174. print('inv cache:', repr(invcache))
  175. self.assertEqual(GF2p8._invcache, invcache)
  176. def test_gf2p8_power(self):
  177. a = GF2p8(random.randint(0, 255))
  178. v = GF2p8(1)
  179. for i in range(10):
  180. self.assertEqual(a ** i, v)
  181. v = v * a
  182. for i in range(10):
  183. a = GF2p8(random.randint(0, 255))
  184. powers = a.powerseries(10)
  185. for j in range(11):
  186. self.assertEqual(powers[j], a ** j)
  187. def test_gf2p8_errors(self):
  188. self.assertRaises(ValueError, GF2p8, 1000)
  189. def test_gf2p8(self):
  190. self.assertEqual(int(GF2p8(5)), 5)
  191. self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)')
  192. for i in range(10):
  193. a = GF2p8(random.randint(0, 255))
  194. b = GF2p8(random.randint(0, 255))
  195. c = GF2p8(random.randint(0, 255))
  196. self.assertEqual(a * 0, 0)
  197. # Identity
  198. self.assertEqual(a + 0, a)
  199. self.assertEqual(a * 1, a)
  200. self.assertEqual(0 + a, a)
  201. self.assertEqual(1 * a, a)
  202. self.assertEqual(0 - a, a)
  203. # Associativity
  204. self.assertEqual((a + b) + c, a + (b + c))
  205. self.assertEqual((a * b) * c, a * (b * c))
  206. # Communitative
  207. self.assertEqual(a + b, b + a)
  208. self.assertEqual(a * b, b * a)
  209. # Distributive
  210. self.assertEqual(a * (b + c), a * b + a * c)
  211. self.assertEqual((b + c) * a, b * a + c * a)
  212. # Basic mul
  213. self.assertEqual(GF2p8(0x80) * 2, 0x87)
  214. self.assertEqual(GF2p8(0x80) * 6,
  215. (0x80 * 6) ^ (0x187 << 1))
  216. self.assertEqual(GF2p8(0x80) * 8,
  217. (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187)
  218. self.assertEqual(a + b - b, a)