# Copyright 2023 John-Mark Gurney. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF # SUCH DAMAGE. # # # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m' # ''' An implementation of Shamir's Secret Sharing. This is over GF(2^256), so unlike some other implementations that are over primes, it is valid for ALL values, and the output will be exactly the same length as the secret. This limits the number of shares to 255. Sample usage: ``` import random from shamirss import * data = random.SystemRandom().randbytes(32) shares = create_shares(data, 3, 5) rdata = recover_data([ shares[1], shares[2], shares[4] ], 3) print(rdata == data) ``` ''' import functools import itertools import operator import secrets import unittest.mock random = secrets.SystemRandom() __all__ = [ 'create_shares', 'recover_data', 'GF2p8', ] def _makered(x, y): '''Make reduction table entry. given x * 2^8, reduce it assuming polynomial y. ''' x = x << 8 for i in range(3, -1, -1): if x & (1 << (i + 8)): x ^= (0x100 + y) << i assert x < 256 return x def evalpoly(polynomial, powers): return sum(( x * y for x, y in zip(polynomial, powers, strict=True)), 0) def create_shares(data, k, nshares): '''Given data, create nshares, such that given any k shares, data can be recovered. data must be bytes, or able to be converted to bytes, e.g. a list of ints in the range [ 0, 255 ]. The return value will be a list of length nshares. Each element will be a tuple of (, ).''' data = bytes(data) #print(repr(data), repr(k), repr(nshares)) powers = (None, ) + tuple(GF2p8(x).powerseries(k - 1) for x in range(1, nshares + 1)) coeffs = [ [ x ] + [ random.randint(0, 255) for y in range(k - 1) ] for idx, x in enumerate(data) ] return [ (x, bytes([ int(evalpoly(coeffs[idx], powers[x])) for idx, val in enumerate(data) ])) for x in range(1, nshares + 1) ] def recover_data(shares, k): '''Recover the value given shares, where k is the number of shares needed. shares must be as least length of k. Each element of shares is from one returned by create_shares, that is a tuple of an int and bytes.''' if len(shares) < k: raise ValueError('not enough shares to recover') return bytes([ int(sum(( GF2p8(y[idx]) * functools.reduce(operator.mul, ( pix * ((GF2p8(pix) - x) ** -1) for pix, piy in shares[:k] if pix != x ), 1) for x, y in shares[:k] ), 0)) for idx in range(len(shares[0][1]))]) class GF2p8: # polynomial 0x187 '''An implementation of GF(2^8). It uses the polynomial 0x187 or x^8 + x^7 + x^2 + x + 1. ''' _invcache = (None, 1, 195, 130, 162, 126, 65, 90, 81, 54, 63, 172, 227, 104, 45, 42, 235, 155, 27, 53, 220, 30, 86, 165, 178, 116, 52, 18, 213, 100, 21, 221, 182, 75, 142, 251, 206, 233, 217, 161, 110, 219, 15, 44, 43, 14, 145, 241, 89, 215, 58, 244, 26, 19, 9, 80, 169, 99, 50, 245, 201, 204, 173, 10, 91, 6, 230, 247, 71, 191, 190, 68, 103, 123, 183, 33, 175, 83, 147, 255, 55, 8, 174, 77, 196, 209, 22, 164, 214, 48, 7, 64, 139, 157, 187, 140, 239, 129, 168, 57, 29, 212, 122, 72, 13, 226, 202, 176, 199, 222, 40, 218, 151, 210, 242, 132, 25, 179, 185, 135, 167, 228, 102, 73, 149, 153, 5, 163, 238, 97, 3, 194, 115, 243, 184, 119, 224, 248, 156, 92, 95, 186, 34, 250, 240, 46, 254, 78, 152, 124, 211, 112, 148, 125, 234, 17, 138, 93, 188, 236, 216, 39, 4, 127, 87, 23, 229, 120, 98, 56, 171, 170, 11, 62, 82, 76, 107, 203, 24, 117, 192, 253, 32, 74, 134, 118, 141, 94, 158, 237, 70, 69, 180, 252, 131, 2, 84, 208, 223, 108, 205, 60, 106, 177, 61, 200, 36, 232, 197, 85, 113, 150, 101, 28, 88, 49, 160, 38, 111, 41, 20, 31, 109, 198, 136, 249, 105, 12, 121, 166, 66, 246, 207, 37, 154, 16, 159, 189, 128, 96, 144, 47, 114, 133, 51, 59, 231, 67, 137, 225, 143, 35, 193, 181, 146, 79) @staticmethod def _primativemul(a, b): masks = [ 0, 0xff ] r = 0 for i in range(0, 8): mask = a & 1 r ^= (masks[mask] & b) << i a = a >> 1 return r # bytes is smaller, 49 vs 168 bytes, so should fit in a cache line _reduce = b'\x00\x87\x89\x0e\x95\x12\x1c\x9b\xad*$\xa38\xbf\xb16' def __init__(self, v): '''v must be in the range [ 0, 255 ]. Create an element of GF(2^8). The operators have been overloaded, so most normal math works. It will also automatically promote non-GF2p8 numbers if possible, e.g. GF2p8(5) + 10 works. ''' if v >= 256 or v < 0: raise ValueError('%d is not a member of GF(2^8)' % v) self._v = int(v) if self._v != v: raise ValueError('%d is not a member of GF(2^8)' % v) # basic operations def __add__(self, o): if not isinstance(o, self.__class__): o = self.__class__(o) return self.__class__(self._v ^ o._v) def __radd__(self, o): return self.__add__(o) def __sub__(self, o): return self.__add__(o) def __rsub__(self, o): return self.__sub__(o) def __mul__(self, o): if not isinstance(o, self.__class__): o = self.__class__(o) m = o._v # possibly use log tables: # a = GF2p8(0x87) # logtbl = { idx: a ** idx for idx in range(256) } # invlogtbl = { v: k for k, v in logtbl.items() } # len(invlogtbl) # 255 # invlogtbl[GF2p8(6)] + invlogtbl[GF2p8(213)] # 254 # logtbl[254] # GF2p8(119) # multiply r = self._primativemul(self._v, m) # reduce r ^= self._reduce[r >> 12] << 4 r ^= self._reduce[(r >> 8) & 0xf ] r &= 0xff return self.__class__(r) def __rmul__(self, o): return self.__mul__(o) def __truediv__(self, o): if not isinstance(o, self.__class__): o = self.__class__(o) return self * (o ** -1) def __rtruediv__(self, o): if not isinstance(o, self.__class__): o = self.__class__(o) return o * (self ** -1) def __pow__(self, x): if x == -1 and self._invcache: return self.__class__(self._invcache[self._v]) if x < 0: x += 255 v = self.__class__(1) # TODO - make faster via caching and squaring for i in range(x): v *= self return v def powerseries(self, cnt): '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].''' r = [ self.__class__(1) ] for i in range(1, cnt + 1): r.append(r[-1] * self) return r def __eq__(self, o): if not isinstance(o, self.__class__): o = self.__class__(o) return self._v == o._v def __int__(self): return self._v def __hash__(self): return hash(self._v) def __repr__(self): return '%s(%d)' % (self.__class__.__name__, self._v) class TestShamirSS(unittest.TestCase): def test_evalpoly(self): a = GF2p8(random.randint(0, 255)) powers = a.powerseries(4) self.assertTrue(all(isinstance(x, GF2p8) for x in powers)) vals = [ GF2p8(random.randint(0, 255)) for x in range(5) ] r = evalpoly(vals, powers) self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] * powers[2] + vals[3] * powers[3] + vals[4] * powers[4]) r = evalpoly(vals[:3], powers[:3]) self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] * powers[2]) self.assertRaises(ValueError, evalpoly, [1], [1, 2]) def test_create_shares(self): self.assertRaises(TypeError, create_shares, '', 1, 1) val = bytes([ random.randint(0, 255) for x in range(100) ]) #val = b'this is a test of english text.' #val = b'1234' a = create_shares(val, 2, 3) self.assertNotIn(val, set(x[1] for x in a)) # that it has the number of shares self.assertEqual(len(a), 3) # that the length of the share data matches passed in data self.assertEqual(len(a[0][1]), len(val)) # that one share isn't enough self.assertRaises(ValueError, recover_data, [ a[0] ], 2) for i, j in itertools.combinations(range(3), 2): self.assertEqual(val, recover_data([ a[i], a[j] ], 2)) self.assertEqual(val, recover_data([ a[j], a[i] ], 2)) a = create_shares(val, 15, 30) for i in range(5): self.assertEqual(val, recover_data([ a[j] for j in random.sample(range(30), 15) ], 15)) def test_gf2p8_reduce(self): reduce = bytes((_makered(x, 0x87) for x in range(0, 16))) if GF2p8._reduce != reduce: # pragma: no cover print('reduce:', repr(reduce)) self.assertEqual(GF2p8._reduce, reduce) def test_gf2p8_inv(self): a = GF2p8(random.randint(0, 255)) with unittest.mock.patch.object(GF2p8, '_invcache', ()) as pinvc: ainv = a ** -1 self.assertEqual(a * ainv, 1) invcache = (None, ) + \ tuple(int(GF2p8(x) ** -1) for x in range(1, 256)) if GF2p8._invcache != invcache: # pragma: no cover print('inv cache:', repr(invcache)) self.assertEqual(GF2p8._invcache, invcache) def test_gf2p8_power(self): a = GF2p8(random.randint(0, 255)) v = GF2p8(1) for i in range(10): self.assertEqual(a ** i, v) v = v * a for i in range(10): a = GF2p8(random.randint(0, 255)) powers = a.powerseries(10) for j in range(11): self.assertEqual(powers[j], a ** j) def test_gf2p8_errors(self): self.assertRaises(ValueError, GF2p8, 1000) self.assertRaises(ValueError, GF2p8, 40.5) self.assertRaises(ValueError, GF2p8, -1) def test_gf2p8_div(self): self.assertEqual(GF2p8(10) / 11, GF2p8(11) ** -1 * GF2p8(10)) self.assertEqual(10 / GF2p8(11), GF2p8(11) ** -1 * GF2p8(10)) def test_gf2p8(self): self.assertEqual(int(GF2p8(5)), 5) self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)') for i in range(10): a = GF2p8(random.randint(0, 255)) b = GF2p8(random.randint(0, 255)) c = GF2p8(random.randint(0, 255)) # Hashable { a, b, c } self.assertEqual(a * 0, 0) # Identity self.assertEqual(a + 0, a) self.assertEqual(a * 1, a) self.assertEqual(0 + a, a) self.assertEqual(1 * a, a) self.assertEqual(0 - a, a) # Associativity self.assertEqual((a + b) + c, a + (b + c)) self.assertEqual((a * b) * c, a * (b * c)) # Communitative self.assertEqual(a + b, b + a) self.assertEqual(a * b, b * a) # Distributive self.assertEqual(a * (b + c), a * b + a * c) self.assertEqual((b + c) * a, b * a + c * a) # Basic mul self.assertEqual(GF2p8(0x80) * 2, 0x87) self.assertEqual(GF2p8(0x80) * 6, (0x80 * 6) ^ (0x187 << 1)) self.assertEqual(GF2p8(0x80) * 8, (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187) self.assertEqual(a + b - b, a)