A Pure Python implementation of Shamir's Secret Sharing
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

371 lines
10 KiB

  1. # Copyright 2023 John-Mark Gurney.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # 1. Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. #
  12. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  14. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  15. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  16. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  17. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  18. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  19. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  20. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  21. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  22. # SUCH DAMAGE.
  23. #
  24. #
  25. # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m'
  26. #
  27. '''
  28. An implementation of Shamir's Secret Sharing.
  29. This is over GF(2^8), so unlike some other implementations that are
  30. over primes, it is valid for ALL values, and the output will be exactly
  31. the same length as the secret. This limits the number of shares to
  32. 255.
  33. Sample usage:
  34. ```
  35. import random
  36. from shamirss import *
  37. data = random.SystemRandom().randbytes(32)
  38. shares = create_shares(data, 3, 5)
  39. rdata = recover_data([ shares[1], shares[2], shares[4] ], 3)
  40. print(rdata == data)
  41. ```
  42. '''
  43. import functools
  44. import itertools
  45. import operator
  46. import secrets
  47. import unittest.mock
  48. random = secrets.SystemRandom()
  49. __all__ = [
  50. 'create_shares',
  51. 'recover_data',
  52. 'GF2p8',
  53. ]
  54. def _makered(x, y):
  55. '''Make reduction table entry.
  56. given x * 2^8, reduce it assuming polynomial y.
  57. '''
  58. x = x << 8
  59. for i in range(3, -1, -1):
  60. if x & (1 << (i + 8)):
  61. x ^= (0x100 + y) << i
  62. assert x < 256
  63. return x
  64. def evalpoly(polynomial, powers):
  65. return sum(( x * y for x, y in zip(polynomial, powers,
  66. strict=True)), 0)
  67. def create_shares(data, k, nshares):
  68. '''Given data, create nshares, such that given any k shares,
  69. data can be recovered.
  70. data must be bytes, or able to be converted to bytes, e.g. a list
  71. of ints in the range [ 0, 255 ].
  72. The return value will be a list of length nshares. Each element
  73. will be a tuple of (<int in range [ 1, nshares ]>, <bytes>).'''
  74. data = bytes(data)
  75. #print(repr(data), repr(k), repr(nshares))
  76. powers = (None, ) + tuple(GF2p8(x).powerseries(k - 1) for x in
  77. range(1, nshares + 1))
  78. coeffs = [ [ x ] + [ random.randint(0, 255) for y in
  79. range(k - 1) ] for idx, x in enumerate(data) ]
  80. return [ (x, bytes([ int(evalpoly(coeffs[idx],
  81. powers[x])) for idx, val in enumerate(data) ])) for x in
  82. range(1, nshares + 1) ]
  83. def recover_data(shares, k):
  84. '''Recover the value given shares, where k is the number of
  85. shares needed.
  86. shares must be as least length of k.
  87. Each element of shares is from one returned by create_shares,
  88. that is a tuple of an int and bytes.'''
  89. if len(shares) < k:
  90. raise ValueError('not enough shares to recover')
  91. return bytes([ int(sum(( GF2p8(y[idx]) *
  92. functools.reduce(operator.mul, ( pix * ((GF2p8(pix) - x) ** -1) for
  93. pix, piy in shares[:k] if pix != x ), 1) for x, y in shares[:k] ),
  94. 0)) for idx in range(len(shares[0][1]))])
  95. class GF2p8:
  96. # polynomial 0x187
  97. '''An implementation of GF(2^8). It uses the polynomial 0x187
  98. or x^8 + x^7 + x^2 + x + 1.
  99. '''
  100. _invcache = (None, 1, 195, 130, 162, 126, 65, 90, 81, 54, 63, 172, 227, 104, 45, 42, 235, 155, 27, 53, 220, 30, 86, 165, 178, 116, 52, 18, 213, 100, 21, 221, 182, 75, 142, 251, 206, 233, 217, 161, 110, 219, 15, 44, 43, 14, 145, 241, 89, 215, 58, 244, 26, 19, 9, 80, 169, 99, 50, 245, 201, 204, 173, 10, 91, 6, 230, 247, 71, 191, 190, 68, 103, 123, 183, 33, 175, 83, 147, 255, 55, 8, 174, 77, 196, 209, 22, 164, 214, 48, 7, 64, 139, 157, 187, 140, 239, 129, 168, 57, 29, 212, 122, 72, 13, 226, 202, 176, 199, 222, 40, 218, 151, 210, 242, 132, 25, 179, 185, 135, 167, 228, 102, 73, 149, 153, 5, 163, 238, 97, 3, 194, 115, 243, 184, 119, 224, 248, 156, 92, 95, 186, 34, 250, 240, 46, 254, 78, 152, 124, 211, 112, 148, 125, 234, 17, 138, 93, 188, 236, 216, 39, 4, 127, 87, 23, 229, 120, 98, 56, 171, 170, 11, 62, 82, 76, 107, 203, 24, 117, 192, 253, 32, 74, 134, 118, 141, 94, 158, 237, 70, 69, 180, 252, 131, 2, 84, 208, 223, 108, 205, 60, 106, 177, 61, 200, 36, 232, 197, 85, 113, 150, 101, 28, 88, 49, 160, 38, 111, 41, 20, 31, 109, 198, 136, 249, 105, 12, 121, 166, 66, 246, 207, 37, 154, 16, 159, 189, 128, 96, 144, 47, 114, 133, 51, 59, 231, 67, 137, 225, 143, 35, 193, 181, 146, 79)
  101. @staticmethod
  102. def _primativemul(a, b):
  103. masks = [ 0, 0xff ]
  104. r = 0
  105. for i in range(0, 8):
  106. mask = a & 1
  107. r ^= (masks[mask] & b) << i
  108. a = a >> 1
  109. return r
  110. _reduce = (0, 135, 137, 14, 149, 18, 28, 155, 173, 42, 36, 163, 56, 191, 177, 54)
  111. def __init__(self, v):
  112. '''v must be in the range [ 0, 255 ].
  113. Create an element of GF(2^8).
  114. The operators have been overloaded, so most normal math works.
  115. It will also automatically promote non-GF2p8 numbers if
  116. possible, e.g. GF2p8(5) + 10 works.
  117. '''
  118. if v >= 256 or v < 0:
  119. raise ValueError('%d is not a member of GF(2^8)' % v)
  120. self._v = int(v)
  121. if self._v != v:
  122. raise ValueError('%d is not a member of GF(2^8)' % v)
  123. # basic operations
  124. def __add__(self, o):
  125. if isinstance(o, int):
  126. return self + self.__class__(o)
  127. return self.__class__(self._v ^ o._v)
  128. def __radd__(self, o):
  129. return self.__add__(o)
  130. def __sub__(self, o):
  131. return self.__add__(o)
  132. def __rsub__(self, o):
  133. return self.__sub__(o)
  134. def __mul__(self, o):
  135. if isinstance(o, int):
  136. o = self.__class__(o)
  137. m = o._v
  138. # multiply
  139. r = self._primativemul(self._v, m)
  140. # reduce
  141. r ^= self._reduce[r >> 12] << 4
  142. r ^= self._reduce[(r >> 8) & 0xf ]
  143. r &= 0xff
  144. return self.__class__(r)
  145. def __rmul__(self, o):
  146. return self.__mul__(o)
  147. def __pow__(self, x):
  148. if x == -1 and self._invcache:
  149. return GF2p8(self._invcache[self._v])
  150. if x < 0:
  151. x += 255
  152. v = self.__class__(1)
  153. # TODO - make faster via caching and squaring
  154. for i in range(x):
  155. v *= self
  156. return v
  157. def powerseries(self, cnt):
  158. '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].'''
  159. r = [ self.__class__(1) ]
  160. for i in range(1, cnt + 1):
  161. r.append(r[-1] * self)
  162. return r
  163. def __eq__(self, o):
  164. if isinstance(o, int):
  165. return self._v == o
  166. return self._v == o._v
  167. def __int__(self):
  168. return self._v
  169. def __repr__(self):
  170. return '%s(%d)' % (self.__class__.__name__, self._v)
  171. class TestShamirSS(unittest.TestCase):
  172. def test_evalpoly(self):
  173. a = GF2p8(random.randint(0, 255))
  174. powers = a.powerseries(4)
  175. self.assertTrue(all(isinstance(x, GF2p8) for x in powers))
  176. vals = [ GF2p8(random.randint(0, 255)) for x in range(5) ]
  177. r = evalpoly(vals, powers)
  178. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  179. powers[2] + vals[3] * powers[3] + vals[4] * powers[4])
  180. r = evalpoly(vals[:3], powers[:3])
  181. self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
  182. powers[2])
  183. self.assertRaises(ValueError, evalpoly, [1], [1, 2])
  184. def test_create_shares(self):
  185. self.assertRaises(TypeError, create_shares, '', 1, 1)
  186. val = bytes([ random.randint(0, 255) for x in range(100) ])
  187. #val = b'this is a test of english text.'
  188. #val = b'1234'
  189. a = create_shares(val, 2, 3)
  190. self.assertNotIn(val, set(x[1] for x in a))
  191. # that it has the number of shares
  192. self.assertEqual(len(a), 3)
  193. # that the length of the share data matches passed in data
  194. self.assertEqual(len(a[0][1]), len(val))
  195. # that one share isn't enough
  196. self.assertRaises(ValueError, recover_data, [ a[0] ], 2)
  197. for i, j in itertools.combinations(range(3), 2):
  198. self.assertEqual(val, recover_data([ a[i], a[j] ], 2))
  199. self.assertEqual(val, recover_data([ a[j], a[i] ], 2))
  200. a = create_shares(val, 15, 30)
  201. for i in range(5):
  202. self.assertEqual(val, recover_data([ a[j] for j in random.sample(range(30), 15) ], 15))
  203. def test_gf2p8_reduce(self):
  204. reduce = tuple(_makered(x, 0x87) for x in range(0, 16))
  205. if GF2p8._reduce != reduce: # pragma: no cover
  206. print('reduce:', repr(reduce))
  207. self.assertEqual(GF2p8._reduce, reduce)
  208. def test_gf2p8_inv(self):
  209. a = GF2p8(random.randint(0, 255))
  210. with unittest.mock.patch.object(GF2p8, '_invcache', ()) as pinvc:
  211. ainv = a ** -1
  212. self.assertEqual(a * ainv, 1)
  213. invcache = (None, ) + \
  214. tuple(int(GF2p8(x) ** -1) for x in range(1, 256))
  215. if GF2p8._invcache != invcache: # pragma: no cover
  216. print('inv cache:', repr(invcache))
  217. self.assertEqual(GF2p8._invcache, invcache)
  218. def test_gf2p8_power(self):
  219. a = GF2p8(random.randint(0, 255))
  220. v = GF2p8(1)
  221. for i in range(10):
  222. self.assertEqual(a ** i, v)
  223. v = v * a
  224. for i in range(10):
  225. a = GF2p8(random.randint(0, 255))
  226. powers = a.powerseries(10)
  227. for j in range(11):
  228. self.assertEqual(powers[j], a ** j)
  229. def test_gf2p8_errors(self):
  230. self.assertRaises(ValueError, GF2p8, 1000)
  231. self.assertRaises(ValueError, GF2p8, 40.5)
  232. self.assertRaises(ValueError, GF2p8, -1)
  233. def test_gf2p8(self):
  234. self.assertEqual(int(GF2p8(5)), 5)
  235. self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)')
  236. for i in range(10):
  237. a = GF2p8(random.randint(0, 255))
  238. b = GF2p8(random.randint(0, 255))
  239. c = GF2p8(random.randint(0, 255))
  240. self.assertEqual(a * 0, 0)
  241. # Identity
  242. self.assertEqual(a + 0, a)
  243. self.assertEqual(a * 1, a)
  244. self.assertEqual(0 + a, a)
  245. self.assertEqual(1 * a, a)
  246. self.assertEqual(0 - a, a)
  247. # Associativity
  248. self.assertEqual((a + b) + c, a + (b + c))
  249. self.assertEqual((a * b) * c, a * (b * c))
  250. # Communitative
  251. self.assertEqual(a + b, b + a)
  252. self.assertEqual(a * b, b * a)
  253. # Distributive
  254. self.assertEqual(a * (b + c), a * b + a * c)
  255. self.assertEqual((b + c) * a, b * a + c * a)
  256. # Basic mul
  257. self.assertEqual(GF2p8(0x80) * 2, 0x87)
  258. self.assertEqual(GF2p8(0x80) * 6,
  259. (0x80 * 6) ^ (0x187 << 1))
  260. self.assertEqual(GF2p8(0x80) * 8,
  261. (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187)
  262. self.assertEqual(a + b - b, a)