A Pure Python implementation of Shamir's Secret Sharing
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

405 lines
11 KiB

  1. # Copyright 2023 John-Mark Gurney.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # 1. Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. #
  12. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  14. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  15. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  16. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  17. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  18. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  19. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  20. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  21. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  22. # SUCH DAMAGE.
  23. #
  24. #
  25. # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m'
  26. #
  '''
  An implementation of Shamir's Secret Sharing.

  This is over GF(2^8), so unlike some other implementations that are
  over primes, it is valid for ALL byte values, and the output will be
  exactly the same length as the secret. Share indexes are the nonzero
  elements of the field, which limits the number of shares to 255.

  Sample usage:
  ```
  import random
  from shamirss import *
  data = random.SystemRandom().randbytes(32)
  shares = create_shares(data, 3, 5)
  rdata = recover_data([ shares[1], shares[2], shares[4] ], 3)
  print(rdata == data)
  ```
  '''
  43. import functools
  44. import itertools
  45. import operator
  46. import secrets
  47. import unittest.mock
  # Cryptographically secure RNG from the secrets module; it supplies the
  # random polynomial coefficients.
  random = secrets.SystemRandom()
  # Names exported by ``from shamirss import *``.
  __all__ = [ 'create_shares', 'recover_data', 'GF2p8' ]
  54. def _makered(x, y):
  55. '''Make reduction table entry.
  56. given x * 2^8, reduce it assuming polynomial y.
  57. '''
  58. x = x << 8
  59. for i in range(3, -1, -1):
  60. if x & (1 << (i + 8)):
  61. x ^= (0x100 + y) << i
  62. assert x < 256
  63. return x
  64. def evalpoly(polynomial, powers):
  65. return sum(( x * y for x, y in zip(polynomial, powers,
  66. strict=True)), 0)
  def create_shares(data, k, nshares):
      '''Given data, create nshares, such that given any k shares,
      data can be recovered.

      data must be bytes, or able to be converted to bytes, e.g. a list
      of ints in the range [ 0, 255 ].

      k and nshares must satisfy 1 <= k <= nshares <= 255, otherwise
      ValueError is raised: share indexes are the nonzero elements of
      GF(2^8), a threshold above nshares could never be met, and k < 1
      would just hand out copies of the data.

      The return value will be a list of length nshares. Each element
      will be a tuple of (<int in range [ 1, nshares ]>, <bytes>).'''
      data = bytes(data)
      if not 1 <= k <= nshares <= 255:
          raise ValueError('need 1 <= k <= nshares <= 255, got k=%r, '
              'nshares=%r' % (k, nshares))
      indexes = range(1, nshares + 1)
      # For each share index x, precompute x ** 0 .. x ** (k - 1) once.
      powers = { x: GF2p8(x).powerseries(k - 1) for x in indexes }
      # One polynomial of degree < k per byte: the constant term is the
      # secret byte and the other k - 1 coefficients are uniformly random.
      coeffs = [ [ byte ] + [ random.randint(0, 255) for _ in range(k - 1) ]
          for byte in data ]
      # Share x holds every byte's polynomial evaluated at x.
      return [ (x, bytes(int(evalpoly(poly, powers[x])) for poly in coeffs))
          for x in indexes ]
  def recover_data(shares, k):
      '''Recover the value given shares, where k is the number of
      shares needed (the threshold used with create_shares).

      shares must be at least length of k; only the first k are used.
      Each element of shares is from one returned by create_shares,
      that is a tuple of an int and bytes.

      Raises ValueError if k < 1, if there are fewer than k shares, if
      two of the first k shares have the same index, or if their data
      lengths differ.'''
      if k < 1:
          raise ValueError('k must be at least 1')
      if len(shares) < k:
          raise ValueError('not enough shares to recover')
      used = shares[:k]
      xs = [ GF2p8(x) for x, _ in used ]
      # A repeated index would be silently skipped by the Lagrange
      # product below and produce garbage, so reject it explicitly.
      if len(set(xs)) != k:
          raise ValueError('duplicate share index')
      length = len(used[0][1])
      if any(len(y) != length for _, y in used):
          raise ValueError('shares have different lengths')
      # Lagrange basis polynomials evaluated at 0:
      #   w_i = prod_{j != i} x_j / (x_j - x_i)
      # (subtraction is XOR in GF(2^8)). They depend only on the share
      # indexes, so compute them once instead of once per byte.
      weights = [ functools.reduce(operator.mul,
          (xj * (xj - xi) ** -1 for j, xj in enumerate(xs) if j != i),
          GF2p8(1)) for i, xi in enumerate(xs) ]
      # Each secret byte is the interpolating polynomial's value at 0.
      return bytes(int(sum((w * y[idx] for w, (_, y) in zip(weights, used)),
          GF2p8(0))) for idx in range(length))
  class GF2p8:
      '''An implementation of GF(2^8). It uses the polynomial 0x187
      or x^8 + x^7 + x^2 + x + 1.

      Elements wrap an int in [ 0, 255 ]. Addition and subtraction are
      both XOR; multiplication is a carry-less multiply reduced modulo
      the polynomial. The operators are overloaded and plain ints are
      promoted automatically, e.g. GF2p8(5) + 10 works.
      '''

      # _invcache[v] is the multiplicative inverse of v for v in
      # [ 1, 255 ]; 0 has no inverse, hence the None placeholder.
      _invcache = (None, 1, 195, 130, 162, 126, 65, 90, 81, 54, 63, 172, 227, 104, 45, 42, 235, 155, 27, 53, 220, 30, 86, 165, 178, 116, 52, 18, 213, 100, 21, 221, 182, 75, 142, 251, 206, 233, 217, 161, 110, 219, 15, 44, 43, 14, 145, 241, 89, 215, 58, 244, 26, 19, 9, 80, 169, 99, 50, 245, 201, 204, 173, 10, 91, 6, 230, 247, 71, 191, 190, 68, 103, 123, 183, 33, 175, 83, 147, 255, 55, 8, 174, 77, 196, 209, 22, 164, 214, 48, 7, 64, 139, 157, 187, 140, 239, 129, 168, 57, 29, 212, 122, 72, 13, 226, 202, 176, 199, 222, 40, 218, 151, 210, 242, 132, 25, 179, 185, 135, 167, 228, 102, 73, 149, 153, 5, 163, 238, 97, 3, 194, 115, 243, 184, 119, 224, 248, 156, 92, 95, 186, 34, 250, 240, 46, 254, 78, 152, 124, 211, 112, 148, 125, 234, 17, 138, 93, 188, 236, 216, 39, 4, 127, 87, 23, 229, 120, 98, 56, 171, 170, 11, 62, 82, 76, 107, 203, 24, 117, 192, 253, 32, 74, 134, 118, 141, 94, 158, 237, 70, 69, 180, 252, 131, 2, 84, 208, 223, 108, 205, 60, 106, 177, 61, 200, 36, 232, 197, 85, 113, 150, 101, 28, 88, 49, 160, 38, 111, 41, 20, 31, 109, 198, 136, 249, 105, 12, 121, 166, 66, 246, 207, 37, 154, 16, 159, 189, 128, 96, 144, 47, 114, 133, 51, 59, 231, 67, 137, 225, 143, 35, 193, 181, 146, 79)

      @staticmethod
      def _primativemul(a, b):
          '''Return the carry-less (polynomial) product of the low 8 bits
          of a and of b, without any reduction.'''
          masks = [ 0, 0xff ]
          r = 0
          for i in range(0, 8):
              mask = a & 1
              # add (XOR) b shifted by i when bit i of a is set
              r ^= (masks[mask] & b) << i
              a = a >> 1
          return r

      # _reduce[n] is n * 2^8 reduced modulo 0x187, for a 4-bit n.
      # bytes is smaller, 49 vs 168 bytes, so should fit in a cache line
      _reduce = b'\x00\x87\x89\x0e\x95\x12\x1c\x9b\xad*$\xa38\xbf\xb16'

      def __init__(self, v):
          '''v must be an integral value in the range [ 0, 255 ].
          Create an element of GF(2^8).
          The operators have been overloaded, so most normal math works.
          It will also automatically promote non-GF2p8 numbers if
          possible, e.g. GF2p8(5) + 10 works.

          Raises ValueError for out-of-range or non-integral values.
          '''
          if v >= 256 or v < 0:
              raise ValueError('%r is not a member of GF(2^8)' % (v,))
          self._v = int(v)
          # int() truncates, so this rejects values such as 40.5
          if self._v != v:
              raise ValueError('%r is not a member of GF(2^8)' % (v,))

      # basic operations
      def __add__(self, o):
          '''Field addition, which is XOR.'''
          if not isinstance(o, self.__class__):
              o = self.__class__(o)
          return self.__class__(self._v ^ o._v)

      def __radd__(self, o):
          return self.__add__(o)

      def __sub__(self, o):
          # In characteristic 2, subtraction is the same as addition.
          return self.__add__(o)

      def __rsub__(self, o):
          return self.__sub__(o)

      def __mul__(self, o):
          '''Field multiplication modulo 0x187.'''
          if not isinstance(o, self.__class__):
              o = self.__class__(o)
          # Possible speed-up: log/antilog tables. An interactive check
          # found 255 distinct powers of GF2p8(0x87), so it generates the
          # nonzero elements and could serve as the log base.
          # Carry-less product of two 8-bit values: at most bits 0-14.
          r = self._primativemul(self._v, o._v)
          # Reduce bits 12-15: n * 2^12 == _reduce[n] * 2^4 (mod 0x187).
          r ^= self._reduce[r >> 12] << 4
          # Then reduce bits 8-11, which the previous step may have changed.
          r ^= self._reduce[(r >> 8) & 0xf]
          # The XORed table entries stand in for the high bits; drop them.
          r &= 0xff
          return self.__class__(r)

      def __rmul__(self, o):
          return self.__mul__(o)

      def __truediv__(self, o):
          '''self / o, i.e. self times the inverse of o.'''
          if not isinstance(o, self.__class__):
              o = self.__class__(o)
          return self * (o ** -1)

      def __rtruediv__(self, o):
          if not isinstance(o, self.__class__):
              o = self.__class__(o)
          return o * (self ** -1)

      def __pow__(self, x):
          '''Raise to the integer power x.

          Negative powers use the multiplicative inverse. Raising
          GF2p8(0) to a negative power raises ZeroDivisionError, and
          GF2p8(0) ** 0 is 1.'''
          if x == -1 and self._v and self._invcache:
              return self.__class__(self._invcache[self._v])
          x = operator.index(x)
          if not self._v:
              if x < 0:
                  raise ZeroDivisionError('GF2p8(0) has no inverse')
              # 0 ** 0 is the empty product, 1
              return self.__class__(1 if x == 0 else 0)
          # The nonzero elements form a multiplicative group of order 255,
          # so a ** 255 == 1 and the exponent may be reduced modulo 255.
          # This also maps any negative exponent to an equivalent one.
          x %= 255
          result = self.__class__(1)
          base = self
          # square-and-multiply
          while x:
              if x & 1:
                  result = result * base
              base = base * base
              x >>= 1
          return result

      def powerseries(self, cnt):
          '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].'''
          r = [ self.__class__(1) ]
          for i in range(1, cnt + 1):
              r.append(r[-1] * self)
          return r

      def __eq__(self, o):
          '''Equal to another GF2p8 or an int with the same value.'''
          if not isinstance(o, self.__class__):
              try:
                  o = self.__class__(o)
              except (TypeError, ValueError):
                  # Not representable as an element, so it cannot be
                  # equal; let Python fall back to its default handling.
                  return NotImplemented
          return self._v == o._v

      def __int__(self):
          return self._v

      def __hash__(self):
          # Matches hash(int) so GF2p8(n) and n, which compare equal,
          # also hash equal.
          return hash(self._v)

      def __repr__(self):
          return '%s(%d)' % (self.__class__.__name__, self._v)
  192. class TestShamirSS(unittest.TestCase):
  def test_evalpoly(self):
      point = GF2p8(random.randint(0, 255))
      xpow = point.powerseries(4)
      for p in xpow:
          self.assertIsInstance(p, GF2p8)
      coeffs = [ GF2p8(random.randint(0, 255)) for _ in range(5) ]
      # full-length polynomial, expanded by hand
      expected = coeffs[0]
      for i in range(1, 5):
          expected = expected + coeffs[i] * xpow[i]
      self.assertEqual(evalpoly(coeffs, xpow), expected)
      # a shorter prefix of the same polynomial
      self.assertEqual(evalpoly(coeffs[:3], xpow[:3]),
          coeffs[0] + coeffs[1] * xpow[1] + coeffs[2] * xpow[2])
      # mismatched lengths are rejected
      with self.assertRaises(ValueError):
          evalpoly([1], [1, 2])
  def test_create_shares(self):
      # str cannot become bytes without an encoding
      with self.assertRaises(TypeError):
          create_shares('', 1, 1)
      secret = bytes(random.randint(0, 255) for _ in range(100))
      shares = create_shares(secret, 2, 3)
      # no share carries the secret in the clear
      self.assertNotIn(secret, { payload for _, payload in shares })
      # the requested number of shares, each as long as the secret
      self.assertEqual(len(shares), 3)
      self.assertEqual(len(shares[0][1]), len(secret))
      # below the threshold, recovery is refused
      with self.assertRaises(ValueError):
          recover_data([ shares[0] ], 2)
      # any pair, in either order, recovers the secret
      for first, second in itertools.combinations(shares, 2):
          self.assertEqual(secret, recover_data([ first, second ], 2))
          self.assertEqual(secret, recover_data([ second, first ], 2))
      shares = create_shares(secret, 15, 30)
      for _ in range(5):
          picked = random.sample(shares, 15)
          self.assertEqual(secret, recover_data(picked, 15))
  def test_gf2p8_reduce(self):
      expected = bytes(map(lambda n: _makered(n, 0x87), range(16)))
      if expected != GF2p8._reduce: # pragma: no cover
          # emit the regenerated table so it can be pasted into GF2p8
          print('reduce:', repr(expected))
      self.assertEqual(GF2p8._reduce, expected)
  def test_gf2p8_inv(self):
      # 0 has no multiplicative inverse, so only pick nonzero elements;
      # allowing 0 made this test fail about once in 256 runs.
      a = GF2p8(random.randint(1, 255))
      # Disable the lookup table so ** -1 takes the computed path, and
      # regenerate the table from that path for comparison below.
      with unittest.mock.patch.object(GF2p8, '_invcache', ()):
          ainv = a ** -1
          self.assertEqual(a * ainv, 1)
          invcache = (None, ) + \
              tuple(int(GF2p8(x) ** -1) for x in range(1, 256))
      if GF2p8._invcache != invcache: # pragma: no cover
          print('inv cache:', repr(invcache))
      self.assertEqual(GF2p8._invcache, invcache)
  def test_gf2p8_power(self):
      # ** agrees with repeated multiplication
      base = GF2p8(random.randint(0, 255))
      expected = GF2p8(1)
      for exp in range(10):
          self.assertEqual(base ** exp, expected)
          expected *= base
      # powerseries agrees with **
      for _ in range(10):
          base = GF2p8(random.randint(0, 255))
          series = base.powerseries(10)
          for exp in range(11):
              self.assertEqual(series[exp], base ** exp)
  def test_gf2p8_errors(self):
      # too large, non-integral, and negative values are all rejected
      for bad in (1000, 40.5, -1):
          with self.assertRaises(ValueError):
              GF2p8(bad)
  def test_gf2p8_div(self):
      # division by an int and of an int both mean "times the inverse"
      quotient = GF2p8(11) ** -1 * GF2p8(10)
      self.assertEqual(GF2p8(10) / 11, quotient)
      self.assertEqual(10 / GF2p8(11), quotient)
  257. def test_gf2p8(self):
  258. self.assertEqual(int(GF2p8(5)), 5)
  259. self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)')
  260. for i in range(10):
  261. a = GF2p8(random.randint(0, 255))
  262. b = GF2p8(random.randint(0, 255))
  263. c = GF2p8(random.randint(0, 255))
  264. # Hashable
  265. { a, b, c }
  266. self.assertEqual(a * 0, 0)
  267. # Identity
  268. self.assertEqual(a + 0, a)
  269. self.assertEqual(a * 1, a)
  270. self.assertEqual(0 + a, a)
  271. self.assertEqual(1 * a, a)
  272. self.assertEqual(0 - a, a)
  273. # Associativity
  274. self.assertEqual((a + b) + c, a + (b + c))
  275. self.assertEqual((a * b) * c, a * (b * c))
  276. # Communitative
  277. self.assertEqual(a + b, b + a)
  278. self.assertEqual(a * b, b * a)
  279. # Distributive
  280. self.assertEqual(a * (b + c), a * b + a * c)
  281. self.assertEqual((b + c) * a, b * a + c * a)
  282. # Basic mul
  283. self.assertEqual(GF2p8(0x80) * 2, 0x87)
  284. self.assertEqual(GF2p8(0x80) * 6,
  285. (0x80 * 6) ^ (0x187 << 1))
  286. self.assertEqual(GF2p8(0x80) * 8,
  287. (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187)
  288. self.assertEqual(a + b - b, a)