A Pure Python implementation of Shamir's Secret Sharing

344 lines
9.6 KiB

  1. # Copyright 2023 John-Mark Gurney.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # 1. Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. #
  12. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  14. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  15. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  16. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  17. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  18. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  19. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  20. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  21. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  22. # SUCH DAMAGE.
  23. #
  24. #
  25. # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m'
  26. #
  27. '''
  28. An implementation of Shamir's Secret Sharing.
  29. This is over GF(2^8), so unlike some other implementations that are
  30. over primes, it is valid for ALL values, and the output will be exactly
  31. the same size as the secret. This also limits the number of shares to
  32. 255.
  33. Sample usage:
  34. ```
  35. import random
  36. from shamirss import *
  37. data = random.SystemRandom().randbytes(32)
  38. shares = create_shares(data, 3, 5)
  39. rdata = recover_data([ shares[1], shares[2], shares[4] ], 3)
  40. print(rdata == data)
  41. ```
  42. '''
  43. import functools
  44. import itertools
  45. import operator
  46. import secrets
  47. import unittest.mock
  48. random = secrets.SystemRandom()
  49. __all__ = [
  50. 'create_shares',
  51. 'recover_data',
  52. 'GF2p8',
  53. ]
  54. def _makered(x, y):
  55. '''Make reduction table entry.
  56. given x * 2^8, reduce it assuming polynomial y.
  57. '''
  58. x = x << 8
  59. for i in range(3, -1, -1):
  60. if x & (1 << (i + 8)):
  61. x ^= (0x100 + y) << i
  62. assert x < 256
  63. return x
  64. def evalpoly(polynomial, powers):
  65. return sum(( x * y for x, y in zip(polynomial, powers,
  66. strict=True)), 0)
  67. def create_shares(data, k, nshares):
  68. '''Given data, create nshares, such that given any k shares,
  69. data can be recovered.
  70. data must be bytes, or able to be converted to bytes, e.g. a list
  71. of ints in the range [ 0, 255 ].
  72. The return value will be a list of length nshares. Each element
  73. will be a tuple of (<int in range [ 1, nshares ]>, <bytes>).'''
  74. data = bytes(data)
  75. #print(repr(data), repr(k), repr(nshares))
  76. powers = (None, ) + tuple(GF2p8(x).powerseries(k - 1) for x in
  77. range(1, nshares + 1))
  78. coeffs = [ [ x ] + [ random.randint(0, 255) for y in
  79. range(k - 1) ] for idx, x in enumerate(data) ]
  80. return [ (x, bytes([ int(evalpoly(coeffs[idx],
  81. powers[x])) for idx, val in enumerate(data) ])) for x in
  82. range(1, nshares + 1) ]
  83. def recover_data(shares, k):
  84. '''Recover the value given shares, where k is the number of
  85. shares needed.
  86. shares must be as least length of k.
  87. Each element of shares is from one returned by create_shares,
  88. that is a tuple of an int and bytes.'''
  89. if len(shares) < k:
  90. raise ValueError('not enough shares to recover')
  91. return bytes([ int(sum([ GF2p8(y[idx]) *
  92. functools.reduce(operator.mul, [ pix * ((GF2p8(pix) - x) ** -1) for
  93. pix, piy in shares[:k] if pix != x ], 1) for x, y in shares[:k] ],
  94. 0)) for idx in range(len(shares[0][1]))])
  95. class GF2p8:
  96. _invcache = (None, 1, 195, 130, 162, 126, 65, 90, 81, 54, 63, 172, 227, 104, 45, 42, 235, 155, 27, 53, 220, 30, 86, 165, 178, 116, 52, 18, 213, 100, 21, 221, 182, 75, 142, 251, 206, 233, 217, 161, 110, 219, 15, 44, 43, 14, 145, 241, 89, 215, 58, 244, 26, 19, 9, 80, 169, 99, 50, 245, 201, 204, 173, 10, 91, 6, 230, 247, 71, 191, 190, 68, 103, 123, 183, 33, 175, 83, 147, 255, 55, 8, 174, 77, 196, 209, 22, 164, 214, 48, 7, 64, 139, 157, 187, 140, 239, 129, 168, 57, 29, 212, 122, 72, 13, 226, 202, 176, 199, 222, 40, 218, 151, 210, 242, 132, 25, 179, 185, 135, 167, 228, 102, 73, 149, 153, 5, 163, 238, 97, 3, 194, 115, 243, 184, 119, 224, 248, 156, 92, 95, 186, 34, 250, 240, 46, 254, 78, 152, 124, 211, 112, 148, 125, 234, 17, 138, 93, 188, 236, 216, 39, 4, 127, 87, 23, 229, 120, 98, 56, 171, 170, 11, 62, 82, 76, 107, 203, 24, 117, 192, 253, 32, 74, 134, 118, 141, 94, 158, 237, 70, 69, 180, 252, 131, 2, 84, 208, 223, 108, 205, 60, 106, 177, 61, 200, 36, 232, 197, 85, 113, 150, 101, 28, 88, 49, 160, 38, 111, 41, 20, 31, 109, 198, 136, 249, 105, 12, 121, 166, 66, 246, 207, 37, 154, 16, 159, 189, 128, 96, 144, 47, 114, 133, 51, 59, 231, 67, 137, 225, 143, 35, 193, 181, 146, 79)
  97. @staticmethod
  98. def _primativemul(a, b):
  99. masks = [ 0, 0xff ]
  100. r = 0
  101. for i in range(0, 8):
  102. mask = a & 1
  103. r ^= (masks[mask] & b) << i
  104. a = a >> 1
  105. return r
  106. # polynomial 0x187
  107. _reduce = tuple(_makered(x, 0x87) for x in range(0, 16))
  108. def __init__(self, v):
  109. if v >= 256:
  110. raise ValueError('%d is not a member of GF(2^8)' % v)
  111. self._v = v
  112. # basic operations
  113. def __add__(self, o):
  114. if isinstance(o, int):
  115. return self + self.__class__(o)
  116. return self.__class__(self._v ^ o._v)
  117. def __radd__(self, o):
  118. return self.__add__(o)
  119. def __sub__(self, o):
  120. return self.__add__(o)
  121. def __rsub__(self, o):
  122. return self.__sub__(o)
  123. def __mul__(self, o):
  124. if isinstance(o, int):
  125. o = self.__class__(o)
  126. m = o._v
  127. # multiply
  128. r = self._primativemul(self._v, m)
  129. # reduce
  130. r ^= self._reduce[r >> 12] << 4
  131. r ^= self._reduce[(r >> 8) & 0xf ]
  132. r &= 0xff
  133. return self.__class__(r)
  134. def __rmul__(self, o):
  135. return self.__mul__(o)
  136. def __pow__(self, x):
  137. if x == -1 and self._invcache:
  138. return GF2p8(self._invcache[self._v])
  139. if x < 0:
  140. x += 255
  141. v = self.__class__(1)
  142. # TODO - make faster via caching and squaring
  143. for i in range(x):
  144. v *= self
  145. return v
  146. def powerseries(self, cnt):
  147. '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].'''
  148. r = [ self.__class__(1) ]
  149. for i in range(1, cnt + 1):
  150. r.append(r[-1] * self)
  151. return r
  152. def __eq__(self, o):
  153. if isinstance(o, int):
  154. return self._v == o
  155. return self._v == o._v
  156. def __int__(self):
  157. return self._v
  158. def __repr__(self):
  159. return '%s(%d)' % (self.__class__.__name__, self._v)
  160. class TestShamirSS(unittest.TestCase):
  161. def test_evalpoly(self):
  162. a = GF2p8(random.randint(0, 255))
  163. powers = a.powerseries(4)
  164. self.assertTrue(all(isinstance(x, GF2p8) for x in powers))
  165. vals = [ GF2p8(random.randint(0, 255)) for x in range(5) ]
  166. r = evalpoly(vals, powers)
  167. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  168. powers[2] + vals[3] * powers[3] + vals[4] * powers[4])
  169. r = evalpoly(vals[:3], powers[:3])
  170. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  171. powers[2])
  172. self.assertRaises(ValueError, evalpoly, [1], [1, 2])
  173. def test_create_shares(self):
  174. self.assertRaises(TypeError, create_shares, '', 1, 1)
  175. val = bytes([ random.randint(0, 255) for x in range(100) ])
  176. #val = b'this is a test of english text.'
  177. #val = b'1234'
  178. a = create_shares(val, 2, 3)
  179. self.assertNotIn(val, set(x[1] for x in a))
  180. # that it has the number of shares
  181. self.assertEqual(len(a), 3)
  182. # that the length of the share data matches passed in data
  183. self.assertEqual(len(a[0][1]), len(val))
  184. # that one share isn't enough
  185. self.assertRaises(ValueError, recover_data, [ a[0] ], 2)
  186. for i, j in itertools.combinations(range(3), 2):
  187. self.assertEqual(val, recover_data([ a[i], a[j] ], 2))
  188. self.assertEqual(val, recover_data([ a[j], a[i] ], 2))
  189. a = create_shares(val, 15, 30)
  190. for i in range(5):
  191. self.assertEqual(val, recover_data([ a[j] for j in random.sample(range(30), 15) ], 15))
  192. def test_gf2p8_inv(self):
  193. a = GF2p8(random.randint(0, 255))
  194. with unittest.mock.patch.object(GF2p8, '_invcache', []) as pinvc:
  195. ainv = a ** -1
  196. self.assertEqual(a * ainv, 1)
  197. invcache = (None, ) + \
  198. tuple(int(GF2p8(x) ** -1) for x in range(1, 256))
  199. if GF2p8._invcache != invcache: # pragma: no cover
  200. print('inv cache:', repr(invcache))
  201. self.assertEqual(GF2p8._invcache, invcache)
  202. def test_gf2p8_power(self):
  203. a = GF2p8(random.randint(0, 255))
  204. v = GF2p8(1)
  205. for i in range(10):
  206. self.assertEqual(a ** i, v)
  207. v = v * a
  208. for i in range(10):
  209. a = GF2p8(random.randint(0, 255))
  210. powers = a.powerseries(10)
  211. for j in range(11):
  212. self.assertEqual(powers[j], a ** j)
  213. def test_gf2p8_errors(self):
  214. self.assertRaises(ValueError, GF2p8, 1000)
  215. def test_gf2p8(self):
  216. self.assertEqual(int(GF2p8(5)), 5)
  217. self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)')
  218. for i in range(10):
  219. a = GF2p8(random.randint(0, 255))
  220. b = GF2p8(random.randint(0, 255))
  221. c = GF2p8(random.randint(0, 255))
  222. self.assertEqual(a * 0, 0)
  223. # Identity
  224. self.assertEqual(a + 0, a)
  225. self.assertEqual(a * 1, a)
  226. self.assertEqual(0 + a, a)
  227. self.assertEqual(1 * a, a)
  228. self.assertEqual(0 - a, a)
  229. # Associativity
  230. self.assertEqual((a + b) + c, a + (b + c))
  231. self.assertEqual((a * b) * c, a * (b * c))
  232. # Communitative
  233. self.assertEqual(a + b, b + a)
  234. self.assertEqual(a * b, b * a)
  235. # Distributive
  236. self.assertEqual(a * (b + c), a * b + a * c)
  237. self.assertEqual((b + c) * a, b * a + c * a)
  238. # Basic mul
  239. self.assertEqual(GF2p8(0x80) * 2, 0x87)
  240. self.assertEqual(GF2p8(0x80) * 6,
  241. (0x80 * 6) ^ (0x187 << 1))
  242. self.assertEqual(GF2p8(0x80) * 8,
  243. (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187)
  244. self.assertEqual(a + b - b, a)