|
- # Copyright 2021 John-Mark Gurney.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
- # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
- # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- # SUCH DAMAGE.
- #
-
- import os
- import unittest
-
- from ctypes import Array, Structure, POINTER, CFUNCTYPE, pointer, sizeof
- from ctypes import c_uint8, c_uint16, c_ssize_t, c_size_t, c_uint64, c_int
- from ctypes import CDLL
-
class StructureRepr(object):
    '''Mixin that gives ctypes Structure subclasses a readable repr.

    Each (name, type) pair in _fields_ is rendered as name=value; ctypes
    array fields are rendered as a bracketed list of hex values.
    '''

    @staticmethod
    def __specialrepr(obj):
        # Non-array values use their normal repr.
        if not isinstance(obj, Array):
            return repr(obj)

        # Arrays are printed element by element in hex.
        hexitems = [hex(elem) for elem in obj]
        return '[ ' + ', '.join(hexitems) + ' ]'

    def __repr__(self): #pragma: no cover
        rendered = []
        for name, _ctype in self._fields_:
            value = self.__specialrepr(getattr(self, name))
            rendered.append('%s=%s' % (name, value))

        return '%s(%s)' % (self.__class__.__name__, ', '.join(rendered))
-
class PktBuf(Structure):
    '''ctypes mirror of the C packet buffer: a pointer to the packet
    bytes (pkt) and their length (pktlen, an unsigned 16-bit value).'''

    _fields_ = [
        ('pkt', POINTER(c_uint8)),
        ('pktlen', c_uint16),
    ]

    def _from(self):
        '''Return a bytes copy of the first pktlen bytes at pkt.'''

        return bytes(self.pkt[:self.pktlen])

    def __repr__(self): #pragma: no cover
        return 'PktBuf(pkt=%s, pktlen=%s)' % (repr(self._from()),
            self.pktlen)

def make_pktbuf(s):
    '''Wrap the byte sequence s in a PktBuf.

    A bytearray is shared rather than copied: pkt points directly at its
    storage, so writes made through the PktBuf (e.g. by C code) are
    visible in s.  Any other byte sequence is copied into a private
    ctypes array.

    References to the backing storage are kept on the returned object
    (as _make_pktbuf_ref) so it stays alive as long as the PktBuf does.

    Raises ValueError if s is longer than the 16-bit pktlen field can
    describe.
    '''

    # pktlen is a c_uint16 and ctypes silently truncates out-of-range
    # values, which would describe the wrong amount of data.
    if len(s) > 0xffff:
        raise ValueError('buffer too large for PktBuf: %d bytes' % len(s))

    pb = PktBuf()

    if isinstance(s, bytearray):
        # Map the whole bytearray with an array type.  Unlike
        # c_uint8.from_buffer(), this also accepts an empty bytearray
        # (a single c_uint8 needs at least one byte of buffer).
        obj = (c_uint8 * len(s)).from_buffer(s)
    else:
        obj = (c_uint8 * len(s))(*s)

    pb.pkt = obj
    pb.pktlen = len(s)

    # Hold explicit references to the ctypes array and the source object
    # for the lifetime of pb.
    pb._make_pktbuf_ref = (obj, s)

    return pb
-
# C callback type: returns nothing and takes a PktBuf by value plus a
# pointer to a second PktBuf (presumably the response -- TODO confirm
# against the C header).
process_msgfunc_t = CFUNCTYPE(None, PktBuf, POINTER(PktBuf))

# The native library is optional; if it cannot be loaded, _lib is None.
try:
    _lib = CDLL('libsyote_test.dylib')
except OSError:
    _lib = None

# Number of 64-bit words needed to hold the C strobe state, rounded up.
# Without the library a single word is used so that CommsSession can
# still be declared.
if _lib is None:
    _strobe_state_u64_cnt = 1
else:
    _lib._strobe_state_size.argtypes = ()
    _lib._strobe_state_size.restype = c_size_t
    _strobe_state_u64_cnt = -(-_lib._strobe_state_size() // 8)
-
class CommsSession(Structure, StructureRepr):
    '''ctypes mirror of a C comms session: the strobe state stored as an
    array of 64-bit words, followed by an int state field.'''

    _fields_ = [
        # Word count comes from the library's _strobe_state_size() when
        # it is loaded.
        ('cs_crypto', c_uint64 * _strobe_state_u64_cnt),
        ('cs_state', c_int),
    ]
-
# Byte lengths of X25519 public keys and private scalars (256 bits each).
EC_PUBLIC_BYTES = EC_PRIVATE_BYTES = 256 // 8
-
class CommsState(Structure, StructureRepr):
    '''ctypes mirror of the C comms state.

    Field order and types must match the C definition; when the library
    is loaded, its _comms_state_size() is compared with
    sizeof(CommsState) at import time.
    '''

    _fields_ = [
        # NOTE: the alignment of these may be off; only the total size
        # is checked against the C side.
        ('cs_active', CommsSession),
        ('cs_pending', CommsSession),

        # presumably the responder private/public key and the initiator
        # public key -- TODO confirm against the C header
        ('cs_respkey', c_uint8 * EC_PRIVATE_BYTES),
        ('cs_resppubkey', c_uint8 * EC_PUBLIC_BYTES),
        ('cs_initpubkey', c_uint8 * EC_PUBLIC_BYTES),

        ('cs_start', CommsSession),

        # message-processing callback
        ('cs_procmsg', process_msgfunc_t),

        ('cs_prevmsg', PktBuf),
        ('cs_prevmsgresp', PktBuf),

        # 64-byte buffers, presumably the storage behind the two PktBufs
        # above -- TODO confirm
        ('cs_prevmsgbuf', c_uint8 * 64),
        ('cs_prevmsgrespbuf', c_uint8 * 64),
    ]
-
if _lib is not None:
    _lib._comms_state_size.restype = c_size_t
    _lib._comms_state_size.argtypes = ()

    # The ctypes layout of CommsState must be exactly as large as the C
    # structure the library was built with, or passing it is unsafe.
    if _lib._comms_state_size() != sizeof(CommsState): # pragma: no cover
        raise RuntimeError('CommsState structure size mismatch!')

    X25519_BASE_POINT = (c_uint8 * (256//8)).in_dll(_lib, 'X25519_BASE_POINT')

    # Declare the C prototype of each native entry point and export it
    # under the same name at module level.
    for func, ret, args in [
        ('comms_init', c_int, (POINTER(CommsState), process_msgfunc_t,
            POINTER(PktBuf), POINTER(PktBuf), POINTER(PktBuf))),
        ('comms_process', None, (POINTER(CommsState), PktBuf, POINTER(PktBuf))),
        ('strobe_seed_prng', None, (POINTER(c_uint8), c_ssize_t)),
        ('x25519', c_int, (c_uint8 * EC_PUBLIC_BYTES, c_uint8 * EC_PRIVATE_BYTES, c_uint8 * EC_PUBLIC_BYTES, c_int)),
    ]:
        f = getattr(_lib, func)
        f.restype = ret
        f.argtypes = args
        # Write through globals(), which is explicitly the module
        # namespace.  The docs say the dict returned by locals() should
        # not be modified; it only happens to be globals() at module scope.
        globals()[func] = f
-
def x25519_wrap(out, scalar, base, clamp):
    '''Call the native x25519() on Python byte buffers.

    out is copied into the output buffer before the call and must be at
    least EC_PUBLIC_BYTES long.  scalar must be exactly EC_PRIVATE_BYTES
    long and base exactly EC_PUBLIC_BYTES long.  clamp is passed straight
    through to x25519().

    Returns the contents of the output buffer after the call, as bytes.

    Raises ValueError for wrong-length inputs and RuntimeError if
    x25519() returns non-zero.
    '''

    # from_buffer_copy() rejects buffers that are too short but ignores
    # any excess, so an over-long scalar or point would be silently
    # truncated.  Require the exact sizes instead.
    if memoryview(scalar).nbytes != EC_PRIVATE_BYTES:
        raise ValueError('x25519 scalar must be %d bytes' % EC_PRIVATE_BYTES)
    if memoryview(base).nbytes != EC_PUBLIC_BYTES:
        raise ValueError('x25519 base point must be %d bytes' % EC_PUBLIC_BYTES)

    outptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer_copy(out)
    scalarptr = (c_uint8 * EC_PRIVATE_BYTES).from_buffer_copy(scalar)
    # base is a public point: size it like x25519()'s third argtype.
    baseptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer_copy(base)

    r = x25519(outptr, scalarptr, baseptr, clamp)

    if r != 0:
        raise RuntimeError('x25519 failed')

    return bytes(outptr)
-
def x25519_genkey():
    '''Return a new random private key: EC_PRIVATE_BYTES bytes from
    os.urandom().'''

    keylen = EC_PRIVATE_BYTES
    return os.urandom(keylen)
-
def x25519_base(scalar, clamp):
    '''Return x25519(scalar, X25519_BASE_POINT) as bytes.

    This is the public key for the private scalar.  scalar must be
    exactly EC_PRIVATE_BYTES long; clamp is passed straight through to
    x25519().

    Raises ValueError for a wrong-length scalar and RuntimeError if
    x25519() returns non-zero.
    '''

    # from_buffer_copy() only rejects short input.  Reject over-long
    # scalars too, instead of silently using just their first bytes.
    if memoryview(scalar).nbytes != EC_PRIVATE_BYTES:
        raise ValueError('x25519 scalar must be %d bytes' % EC_PRIVATE_BYTES)

    out = bytearray(EC_PUBLIC_BYTES)
    # outptr shares memory with out, so x25519() writes straight into it.
    outptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer(out)
    scalarptr = (c_uint8 * EC_PRIVATE_BYTES).from_buffer_copy(scalar)

    r = x25519(outptr, scalarptr, X25519_BASE_POINT, clamp)

    if r != 0:
        raise RuntimeError('x25519 failed')

    return bytes(out)
-
class X25519:
    '''Class to wrap the x25519 functions into something a bit more
    usable.  Keys are validated on ingestion: the private key must be
    exactly EC_PRIVATE_BYTES long.

    Use either the gen method to generate a random key, or the frombytes
    method.

    a = X25519.gen()
    b = X25519.gen()

    a.dh(b.getpub()) == b.dh(a.getpub())

    That is, each party generates a key and sends its public part to the
    other party.  Each party then passes the public part it received to
    the dh method.  The resulting value will be shared between the two
    parties.
    '''

    def __init__(self, key):
        # Snapshot the key as immutable bytes so later changes to a
        # caller-owned buffer (e.g. a bytearray) cannot desynchronise
        # privkey from the pubkey derived below.  Going through
        # memoryview() keeps non-buffer inputs such as an int (which
        # bytes() would turn into zero bytes) failing with TypeError.
        key = bytes(memoryview(key))
        if len(key) != EC_PRIVATE_BYTES:
            raise ValueError('X25519 private key must be %d bytes, got %d'
                % (EC_PRIVATE_BYTES, len(key)))

        self.privkey = key
        self.pubkey = x25519_base(key, 1)

    def dh(self, pub):
        '''Perform a DH operation using the public part pub and return
        the result as bytes.'''

        # self.pubkey is only the initial content of the output buffer
        # that x25519_wrap() hands to the native code.
        return x25519_wrap(self.pubkey, self.privkey, pub, 1)

    def getpub(self):
        '''Get the public part of the key. This is to be sent
        to the other party for key exchange.'''

        return self.pubkey

    def getpriv(self):
        '''Get the private key bytes.'''

        return self.privkey

    @classmethod
    def gen(cls):
        '''Generate a random X25519 key.'''

        return cls(x25519_genkey())

    @classmethod
    def frombytes(cls, key):
        '''Generate an X25519 key from 32 bytes.'''

        return cls(key)
-
def comms_process_wrap(state, input):
    '''A wrapper around comms_process that converts the argument into a
    buffer, and then returns the response message as a bytes string.

    The response is written into a fresh 64-byte buffer.  The first
    pktlen bytes of that buffer after the call are returned.
    '''

    request = make_pktbuf(input)

    # A bytearray is shared (not copied) by make_pktbuf, so whatever
    # comms_process writes lands in this buffer.
    response = make_pktbuf(bytearray(64))

    comms_process(state, request, response)

    return response._from()
-
if __name__ == '__main__':
    # Generate a key and print its private part, then its public part,
    # each in hex.
    key = X25519.gen()

    for half in (key.getpriv(), key.getpub()):
        print(half.hex())
-
class TestX25519(unittest.TestCase):
    # Tests for the X25519 helpers; they call into the native library.

    PUBLIC_BYTES = EC_PUBLIC_BYTES
    PRIVATE_BYTES = EC_PRIVATE_BYTES

    def test_class(self):
        key = X25519.gen()
        priv, pub = key.getpriv(), key.getpub()

        # The class must agree with the free function and with a key
        # rebuilt from the same private bytes.
        self.assertEqual(x25519_base(priv, 1), pub)
        self.assertEqual(X25519.frombytes(priv).getpub(), pub)

        # A short private key is rejected.
        with self.assertRaises(ValueError):
            X25519(b'0'*31)

    def test_rfc7748_6_1(self):
        # KAT from https://datatracker.ietf.org/doc/html/rfc7748#section-6.1
        h = bytes.fromhex

        akey = X25519(h('77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a'))
        self.assertEqual(akey.getpub(), h('8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a'))

        bkey = X25519(h('5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb'))
        self.assertEqual(bkey.getpub(), h('de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f'))

        shared = h('4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742')
        self.assertEqual(akey.dh(bkey.getpub()), shared)
        self.assertEqual(bkey.dh(akey.getpub()), shared)

    def test_basic_ops(self):
        apriv = x25519_genkey()
        apub = x25519_base(apriv, 1)

        bpriv = x25519_genkey()
        bpub = x25519_base(bpriv, 1)

        # Two independently generated keys must differ.
        self.assertNotEqual(apriv, bpriv)
        self.assertNotEqual(apub, bpub)

        # Both sides of the exchange derive the same shared value.
        self.assertEqual(x25519_wrap(apub, apriv, bpub, 1),
            x25519_wrap(bpub, bpriv, apub, 1))
|