|
- # Copyright 2023 John-Mark Gurney.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
- # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
- # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- # SUCH DAMAGE.
- #
-
- #
- # ls shamirss.py | entr sh -c ' date; python -m coverage run -m unittest shamirss && coverage report -m'
- #
-
- '''
- An implementation of Shamir's Secret Sharing.
-
- This is over GF(2^8), so unlike some other implementations that are
- over primes, it is valid for ALL values, and the output will be exactly
- the same length as the secret. This limits the number of shares to
- 255.
-
- Sample usage:
- ```
- import random
- from shamirss import *
- data = random.SystemRandom().randbytes(32)
- shares = create_shares(data, 3, 5)
- rdata = recover_data([ shares[1], shares[2], shares[4] ], 3)
- print(rdata == data)
- ```
- '''
-
- import functools
- import itertools
- import operator
- import secrets
- import unittest.mock
-
# Names exported by "from shamirss import *".
__all__ = [ 'create_shares', 'recover_data', 'GF2p8' ]

# Module-wide random source.  Note that this name shadows the stdlib
# random module: every random.randint()/random.sample() call in this
# file goes through this SystemRandom instance.
random = secrets.SystemRandom()
-
- def _makered(x, y):
- '''Make reduction table entry.
-
- given x * 2^8, reduce it assuming polynomial y.
- '''
-
- x = x << 8
-
- for i in range(3, -1, -1):
- if x & (1 << (i + 8)):
- x ^= (0x100 + y) << i
-
- assert x < 256
-
- return x
-
- def evalpoly(polynomial, powers):
- return sum(( x * y for x, y in zip(polynomial, powers,
- strict=True)), 0)
-
def create_shares(data, k, nshares):
    '''Given data, create nshares, such that given any k shares,
    data can be recovered.

    data must be bytes, or able to be converted to bytes, e.g. a list
    of ints in the range [ 0, 255 ].

    k and nshares must satisfy 1 <= k <= nshares <= 255, otherwise
    ValueError is raised.  Anything bytes() rejects (e.g. a str)
    raises TypeError.

    The return value will be a list of length nshares.  Each element
    will be a tuple of (<int in range [ 1, nshares ]>, <bytes>), where
    the bytes value has the same length as data.'''

    data = bytes(data)

    # Reject parameters that could only produce useless output: more
    # required shares than exist, a non-positive threshold, or more
    # shares than there are non-zero field elements to evaluate at.
    if not 1 <= k <= nshares <= 255:
        raise ValueError(
            'need 1 <= k <= nshares <= 255, got k=%r, nshares=%r' %
            (k, nshares))

    xs = range(1, nshares + 1)

    # powers[x] is [ x ** 0, x ** 1, ..., x ** (k - 1) ] in GF(2^8).
    powers = { x: GF2p8(x).powerseries(k - 1) for x in xs }

    # One polynomial per byte of data: the constant term is the secret
    # byte, the remaining k - 1 coefficients are random.
    coeffs = [ [ byte ] + [ random.randint(0, 255) for _ in
        range(k - 1) ] for byte in data ]

    # Share x is every polynomial evaluated at x.
    return [ (x, bytes(int(evalpoly(poly, powers[x])) for poly in coeffs))
        for x in xs ]
-
def recover_data(shares, k):
    '''Recover the value given shares, where k is the number of
    shares needed.

    shares must be at least length of k; only the first k are used.

    Each element of shares is from one returned by create_shares,
    that is a tuple of an int and bytes.

    Raises ValueError if there are fewer than k shares, if k is less
    than 1, or if the shares used do not have distinct indices.

    Returns the recovered data as bytes.'''

    if len(shares) < k:
        raise ValueError('not enough shares to recover')

    if k < 1:
        raise ValueError('k must be at least 1')

    used = shares[:k]
    xs = [ x for x, _ in used ]

    # A repeated index adds no information, and would otherwise be
    # dropped from the products below and give a wrong answer.
    if len(set(xs)) != len(xs):
        raise ValueError('duplicate share index')

    # Lagrange basis polynomials evaluated at zero.  They depend only
    # on the share indices, so compute them once for all bytes.
    basis = [ functools.reduce(operator.mul,
        ( pix * ((GF2p8(pix) - x) ** -1) for pix in xs if pix != x ), 1)
        for x in xs ]

    return bytes(
        int(sum(( GF2p8(y[idx]) * b for (_, y), b in zip(used, basis) ), 0))
        for idx in range(len(used[0][1])))
-
class GF2p8:
    # polynomial 0x187
    '''An implementation of GF(2^8).  It uses the polynomial 0x187
    or x^8 + x^7 + x^2 + x + 1.

    Instances wrap an int in the range [ 0, 255 ].  Addition and
    subtraction are both XOR; multiplication is a carry-less multiply
    reduced by the polynomial above.
    '''

    # Multiplicative inverses, indexed by the element's integer value.
    # Entry 0 is only a placeholder: zero has no inverse and __pow__
    # raises before looking it up.
    _invcache = b'\x00\x01\xc3\x82\xa2~AZQ6?\xac\xe3h-*\xeb\x9b\x1b5\xdc\x1eV\xa5\xb2t4\x12\xd5d\x15\xdd\xb6K\x8e\xfb\xce\xe9\xd9\xa1n\xdb\x0f,+\x0e\x91\xf1Y\xd7:\xf4\x1a\x13\tP\xa9c2\xf5\xc9\xcc\xad\n[\x06\xe6\xf7G\xbf\xbeDg{\xb7!\xafS\x93\xff7\x08\xaeM\xc4\xd1\x16\xa4\xd60\x07@\x8b\x9d\xbb\x8c\xef\x81\xa89\x1d\xd4zH\r\xe2\xca\xb0\xc7\xde(\xda\x97\xd2\xf2\x84\x19\xb3\xb9\x87\xa7\xe4fI\x95\x99\x05\xa3\xeea\x03\xc2s\xf3\xb8w\xe0\xf8\x9c\\_\xba"\xfa\xf0.\xfeN\x98|\xd3p\x94}\xea\x11\x8a]\xbc\xec\xd8\'\x04\x7fW\x17\xe5xb8\xab\xaa\x0b>RLk\xcb\x18u\xc0\xfd J\x86v\x8d^\x9e\xedFE\xb4\xfc\x83\x02T\xd0\xdfl\xcd<j\xb1=\xc8$\xe8\xc5Uq\x96e\x1cX1\xa0&o)\x14\x1fm\xc6\x88\xf9i\x0cy\xa6B\xf6\xcf%\x9a\x10\x9f\xbd\x80`\x90/r\x853;\xe7C\x89\xe1\x8f#\xc1\xb5\x92O'

    @staticmethod
    def _primativemul(a, b):
        '''Carry-less (XOR) multiply of a and b, without reduction.

        Only the low 8 bits of a are examined, and only the low
        8 bits of b contribute.
        '''

        # Indexing by the low bit of a selects b or nothing without
        # branching.
        masks = [ 0, 0xff ]

        r = 0

        for i in range(0, 8):
            mask = a & 1
            r ^= (masks[mask] & b) << i
            a = a >> 1

        return r

    # _reduce[n] is n * 2^8 reduced by the polynomial, for each 4-bit n.
    # bytes is smaller, 49 vs 168 bytes, so should fit in a cache line
    _reduce = b'\x00\x87\x89\x0e\x95\x12\x1c\x9b\xad*$\xa38\xbf\xb16'

    def __init__(self, v):
        '''v must be in the range [ 0, 255 ].

        Create an element of GF(2^8).

        The operators have been overloaded, so most normal math works.

        It will also automatically promote non-GF2p8 numbers if
        possible, e.g. GF2p8(5) + 10 works.

        Raises ValueError if v is out of range or is not integral.
        '''

        # %r rather than %d so that a rejected float such as 40.5 is
        # reported as given instead of truncated to 40.
        if v >= 256 or v < 0:
            raise ValueError('%r is not a member of GF(2^8)' % (v,))

        self._v = int(v)

        if self._v != v:
            raise ValueError('%r is not a member of GF(2^8)' % (v,))

    # basic operations
    def __add__(self, o):
        '''Add two elements; in GF(2^8) this is XOR.'''

        if not isinstance(o, self.__class__):
            o = self.__class__(o)

        return self.__class__(self._v ^ o._v)

    def __radd__(self, o):
        return self.__add__(o)

    def __sub__(self, o):
        # Subtraction is implemented as addition (XOR).
        return self.__add__(o)

    def __rsub__(self, o):
        # XOR is symmetric, so o - self is the same as self - o.
        return self.__sub__(o)

    def __mul__(self, o):
        '''Multiply two elements, reducing by the field polynomial.'''

        if not isinstance(o, self.__class__):
            o = self.__class__(o)

        m = o._v

        # possibly use log tables:
        # a = GF2p8(0x87)
        # logtbl = { idx: a ** idx for idx in range(256) }
        # invlogtbl = { v: k for k, v in logtbl.items() }
        # len(invlogtbl)
        # 255
        # invlogtbl[GF2p8(6)] + invlogtbl[GF2p8(213)]
        # 254
        # logtbl[254]
        # GF2p8(119)

        # multiply
        r = self._primativemul(self._v, m)

        # reduce: first fold bits 12 and up into bits 4 to 11, then
        # fold bits 8 to 11 into the low byte.
        r ^= self._reduce[r >> 12] << 4
        r ^= self._reduce[(r >> 8) & 0xf ]

        r &= 0xff

        return self.__class__(r)

    def __rmul__(self, o):
        return self.__mul__(o)

    def __truediv__(self, o):
        '''Divide by o, i.e. multiply by its inverse.'''

        if not isinstance(o, self.__class__):
            o = self.__class__(o)

        return self * (o ** -1)

    def __rtruediv__(self, o):
        if not isinstance(o, self.__class__):
            o = self.__class__(o)

        return o * (self ** -1)

    def __pow__(self, x):
        '''Raise to the power x, which may be negative.

        Zero raised to a negative power raises ZeroDivisionError.
        Zero raised to the power 0 is 1, matching powerseries().
        '''

        if self._v == 0:
            if x < 0:
                raise ZeroDivisionError

            if x == 0:
                return self.__class__(1)

            return self

        if x == -1 and self._invcache:
            return self.__class__(self._invcache[self._v])

        # we loop after 255, so no need to do extra work
        x %= 255

        # Note: not constant time, also, not optimal algorithm

        # The art of computer programming vol 2. § 4.6.3 Algorithm A
        # https://archive.org/details/artofcomputerpro0000knut/page/400/mode/2up

        # A1
        n = x
        y = self.__class__(1)
        z = self

        while n:
            # A2
            n, isodd = divmod(n, 2)

            if not isodd:
                # A5
                z *= z
                continue

            # A3
            y *= z

            # A4
            if not n:
                break

            # A5
            z *= z

        return y

    def powerseries(self, cnt):
        '''Generate [ self ** 0, self ** 1, ..., self ** cnt ].'''

        r = [ self.__class__(1) ]

        for i in range(1, cnt + 1):
            r.append(r[-1] * self)

        return r

    def __eq__(self, o):
        '''Compare with another element, or with a value that can be
        promoted to one.

        Values that cannot be promoted compare unequal instead of
        raising, so membership tests and mixed comparisons work.
        '''

        if not isinstance(o, self.__class__):
            try:
                o = self.__class__(o)
            except (TypeError, ValueError):
                return NotImplemented

        return self._v == o._v

    def __int__(self):
        return self._v

    def __hash__(self):
        # Same hash as the wrapped int, consistent with __eq__
        # accepting plain ints.
        return hash(self._v)

    def __repr__(self):
        return '%s(%d)' % (self.__class__.__name__, self._v)
-
class TestShamirSS(unittest.TestCase):
    '''Unit tests for the share functions and the GF2p8 arithmetic.'''

    def test_evalpoly(self):
        a = GF2p8(random.randint(0, 255))

        powers = a.powerseries(4)

        self.assertTrue(all(isinstance(x, GF2p8) for x in powers))

        vals = [ GF2p8(random.randint(0, 255)) for x in range(5) ]

        r = evalpoly(vals, powers)
        self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
            powers[2] + vals[3] * powers[3] + vals[4] * powers[4])

        r = evalpoly(vals[:3], powers[:3])
        self.assertEqual(r, vals[0] + vals[1] * powers[1] + vals[2] *
            powers[2])

        # mismatched lengths are rejected
        self.assertRaises(ValueError, evalpoly, [1], [1, 2])

    def test_create_shares(self):
        self.assertRaises(TypeError, create_shares, '', 1, 1)

        val = bytes([ random.randint(0, 255) for x in range(100) ])
        #val = b'this is a test of english text.'
        #val = b'1234'

        a = create_shares(val, 2, 3)

        self.assertNotIn(val, set(x[1] for x in a))

        # that it has the number of shares
        self.assertEqual(len(a), 3)

        # that the length of the share data matches passed in data
        self.assertEqual(len(a[0][1]), len(val))

        # that one share isn't enough
        self.assertRaises(ValueError, recover_data, [ a[0] ], 2)

        # that any two shares, in either order, recover the data
        for i, j in itertools.combinations(range(3), 2):
            self.assertEqual(val, recover_data([ a[i], a[j] ], 2))
            self.assertEqual(val, recover_data([ a[j], a[i] ], 2))

        a = create_shares(val, 15, 30)

        for i in range(5):
            picked = [ a[j] for j in random.sample(range(30), 15) ]
            self.assertEqual(val, recover_data(picked, 15))

    def test_gf2p8_reduce(self):
        reduce = bytes((_makered(x, 0x87) for x in range(0, 16)))

        # print the regenerated table to ease updating the constant
        if GF2p8._reduce != reduce: # pragma: no cover
            print('reduce:', repr(reduce))
        self.assertEqual(GF2p8._reduce, reduce)

    def test_gf2p8_inv(self):
        # Zero has no inverse, so draw from the non-zero elements
        # only; otherwise this test fails whenever 0 is drawn.
        a = GF2p8(random.randint(1, 255))

        # With the cache emptied, ** -1 has to go through the generic
        # exponentiation code, so the table is checked against an
        # independent computation rather than against itself.
        with unittest.mock.patch.object(GF2p8, '_invcache', ()):
            ainv = a ** -1

            invcache = bytes((0, ) +
                tuple(int(GF2p8(x) ** -1) for x in range(1, 256)))

        self.assertEqual(a * ainv, 1)

        # print the regenerated table to ease updating the constant
        if GF2p8._invcache != invcache: # pragma: no cover
            print('inv cache:', repr(invcache))
        self.assertEqual(GF2p8._invcache, invcache)

    def test_gf2p8_power(self):
        zero = GF2p8(0)
        self.assertEqual(zero ** 5, zero)
        with self.assertRaises(ZeroDivisionError):
            zero ** -1

        # Zero is pinned explicitly above.  The random draws below
        # are non-zero, as zero has no negative powers and would make
        # this test fail intermittently.
        a = GF2p8(random.randint(1, 255))

        v = GF2p8(1)
        for i in range(260):
            self.assertEqual(a ** i, v)

            v = v * a

        for i in range(10):
            neg = random.randint(-600, -1)

            # equivalent non-negative exponent, powers repeat every 255
            p = neg
            while p < 0:
                p += 255

            self.assertEqual(a ** neg, a ** p)

        for i in range(10):
            a = GF2p8(random.randint(1, 255))

            powers = a.powerseries(10)
            for j in range(11):
                self.assertEqual(powers[j], a ** j)

    def test_gf2p8_errors(self):
        self.assertRaises(ValueError, GF2p8, 1000)
        self.assertRaises(ValueError, GF2p8, 40.5)
        self.assertRaises(ValueError, GF2p8, -1)

    def test_gf2p8_div(self):
        self.assertEqual(GF2p8(10) / 11, GF2p8(11) ** -1 * GF2p8(10))
        self.assertEqual(10 / GF2p8(11), GF2p8(11) ** -1 * GF2p8(10))

    def test_gf2p8(self):
        self.assertEqual(int(GF2p8(5)), 5)
        self.assertEqual(repr(GF2p8(5)), 'GF2p8(5)')

        for i in range(10):
            a = GF2p8(random.randint(0, 255))
            b = GF2p8(random.randint(0, 255))
            c = GF2p8(random.randint(0, 255))

            # Hashable
            { a, b, c }

            self.assertEqual(a * 0, 0)

            # Identity
            self.assertEqual(a + 0, a)
            self.assertEqual(a * 1, a)
            self.assertEqual(0 + a, a)
            self.assertEqual(1 * a, a)
            self.assertEqual(0 - a, a)

            # Associativity
            self.assertEqual((a + b) + c, a + (b + c))
            self.assertEqual((a * b) * c, a * (b * c))

            # Commutative
            self.assertEqual(a + b, b + a)
            self.assertEqual(a * b, b * a)

            # Distributive
            self.assertEqual(a * (b + c), a * b + a * c)
            self.assertEqual((b + c) * a, b * a + c * a)

            # Subtraction undoes addition
            self.assertEqual(a + b - b, a)

        # Basic mul
        self.assertEqual(GF2p8(0x80) * 2, 0x87)
        self.assertEqual(GF2p8(0x80) * 6,
            (0x80 * 6) ^ (0x187 << 1))
        self.assertEqual(GF2p8(0x80) * 8,
            (0x80 * 8) ^ (0x187 << 2) ^ (0x187 << 1) ^ 0x187)
|