| @@ -1,15 +1,21 @@ | |||
| from noise.connection import NoiseConnection, Keypair | |||
| from cryptography.hazmat.primitives.kdf.hkdf import HKDF | |||
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |||
| from cryptography.hazmat.primitives import hashes | |||
| from cryptography.hazmat.backends import default_backend | |||
| from cryptography.hazmat.primitives.asymmetric import x448 | |||
| from cryptography.hazmat.primitives import hashes | |||
| from cryptography.hazmat.primitives import serialization | |||
| from cryptography.hazmat.primitives.asymmetric import x448 | |||
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |||
| from cryptography.hazmat.primitives.kdf.hkdf import HKDF | |||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | |||
| from noise.connection import NoiseConnection, Keypair | |||
| import tracemalloc; tracemalloc.start() | |||
| import argparse | |||
| import asyncio | |||
| import base64 | |||
| import os.path | |||
| import shutil | |||
| import socket | |||
| import sys | |||
| import tempfile | |||
| import threading | |||
| import unittest | |||
| @@ -231,6 +237,52 @@ class Tests_misc(unittest.TestCase): | |||
| msg = os.urandom(i) | |||
| self.assertEqual(len(msg), dec(enc(msg) + msg)) | |||
| def cmd_genkey(args): | |||
| keypair = genkeypair() | |||
| key = x448.X448PrivateKey.generate() | |||
| # public key part | |||
| enc = serialization.Encoding.Raw | |||
| pubformat = serialization.PublicFormat.Raw | |||
| pub = key.public_key().public_bytes(encoding=enc, format=pubformat) | |||
| try: | |||
| fname = args.fname + '.pub' | |||
| with open(fname, 'x', encoding='ascii') as fp: | |||
| print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp) | |||
| except FileExistsError: | |||
| print('failed to create %s, file exists.' % fname, file=sys.stderr) | |||
| sys.exit(1) | |||
| enc = serialization.Encoding.PEM | |||
| format = serialization.PrivateFormat.PKCS8 | |||
| encalgo = serialization.NoEncryption() | |||
| with open(args.fname, 'x', encoding='ascii') as fp: | |||
| fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii')) | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help') | |||
| parser_gk = subparsers.add_parser('genkey', help='generate keys') | |||
| parser_gk.add_argument('fname', type=str, help='file name for the key') | |||
| parser_gk.set_defaults(func=cmd_genkey) | |||
| args = parser.parse_args() | |||
| try: | |||
| fun = args.func | |||
| except AttributeError: | |||
| parser.print_usage() | |||
| sys.exit(5) | |||
| fun(args) | |||
| if __name__ == '__main__': # pragma: no cover | |||
| main() | |||
| def _asyncsockpair(): | |||
| '''Create a pair of sockets that are bound to each other. | |||
| The function will return a tuple of two coroutine's, that | |||
| @@ -241,7 +293,64 @@ def _asyncsockpair(): | |||
| return asyncio.open_connection(sock=socka), \ | |||
| asyncio.open_connection(sock=sockb) | |||
| class Tests(unittest.TestCase): | |||
| class TestMain(unittest.TestCase): | |||
| def setUp(self): | |||
| # setup temporary directory | |||
| d = os.path.realpath(tempfile.mkdtemp()) | |||
| self.basetempdir = d | |||
| self.tempdir = os.path.join(d, 'subdir') | |||
| os.mkdir(self.tempdir) | |||
| # Generate key pairs | |||
| self.server_key_pair = genkeypair() | |||
| self.client_key_pair = genkeypair() | |||
| os.chdir(self.tempdir) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.basetempdir) | |||
| self.tempdir = None | |||
| def test_noargs(self): | |||
| sys.argv = [ 'prog' ] | |||
| with self.assertRaises(SystemExit) as cm: | |||
| main() | |||
| # XXX - not checking error message | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(5, cm.exception.code) | |||
| def test_genkey(self): | |||
| # that it can generate a key | |||
| sys.argv = [ 'prog', 'genkey', 'somefile' ] | |||
| main() | |||
| with open('somefile.pub', encoding='ascii') as fp: | |||
| lines = fp.readlines() | |||
| self.assertEqual(len(lines), 1) | |||
| keytype, keyvalue = lines[0].split() | |||
| self.assertEqual(keytype, 'ntun-x448') | |||
| key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) | |||
| with open('somefile', encoding='ascii') as fp: | |||
| data = fp.read().encode('ascii') | |||
| key = load_pem_private_key(data, password=None, backend=default_backend()) | |||
| self.assertIsInstance(key, x448.X448PrivateKey) | |||
| # that a second call fails | |||
| with self.assertRaises(SystemExit) as cm: | |||
| main() | |||
| # XXX - not checking error message | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(1, cm.exception.code) | |||
| class TestNoiseFowarder(unittest.TestCase): | |||
| def setUp(self): | |||
| # setup temporary directory | |||
| d = os.path.realpath(tempfile.mkdtemp()) | |||