| @@ -1,15 +1,21 @@ | |||||
| from noise.connection import NoiseConnection, Keypair | |||||
| from cryptography.hazmat.primitives.kdf.hkdf import HKDF | |||||
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |||||
| from cryptography.hazmat.primitives import hashes | |||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives.asymmetric import x448 | |||||
| from cryptography.hazmat.primitives import hashes | |||||
| from cryptography.hazmat.primitives import serialization | from cryptography.hazmat.primitives import serialization | ||||
| from cryptography.hazmat.primitives.asymmetric import x448 | |||||
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |||||
| from cryptography.hazmat.primitives.kdf.hkdf import HKDF | |||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | |||||
| from noise.connection import NoiseConnection, Keypair | |||||
| import tracemalloc; tracemalloc.start() | |||||
| import argparse | |||||
| import asyncio | import asyncio | ||||
| import base64 | |||||
| import os.path | import os.path | ||||
| import shutil | import shutil | ||||
| import socket | import socket | ||||
| import sys | |||||
| import tempfile | import tempfile | ||||
| import threading | import threading | ||||
| import unittest | import unittest | ||||
| @@ -231,6 +237,52 @@ class Tests_misc(unittest.TestCase): | |||||
| msg = os.urandom(i) | msg = os.urandom(i) | ||||
| self.assertEqual(len(msg), dec(enc(msg) + msg)) | self.assertEqual(len(msg), dec(enc(msg) + msg)) | ||||
| def cmd_genkey(args): | |||||
| keypair = genkeypair() | |||||
| key = x448.X448PrivateKey.generate() | |||||
| # public key part | |||||
| enc = serialization.Encoding.Raw | |||||
| pubformat = serialization.PublicFormat.Raw | |||||
| pub = key.public_key().public_bytes(encoding=enc, format=pubformat) | |||||
| try: | |||||
| fname = args.fname + '.pub' | |||||
| with open(fname, 'x', encoding='ascii') as fp: | |||||
| print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp) | |||||
| except FileExistsError: | |||||
| print('failed to create %s, file exists.' % fname, file=sys.stderr) | |||||
| sys.exit(1) | |||||
| enc = serialization.Encoding.PEM | |||||
| format = serialization.PrivateFormat.PKCS8 | |||||
| encalgo = serialization.NoEncryption() | |||||
| with open(args.fname, 'x', encoding='ascii') as fp: | |||||
| fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii')) | |||||
| def main(): | |||||
| parser = argparse.ArgumentParser() | |||||
| subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help') | |||||
| parser_gk = subparsers.add_parser('genkey', help='generate keys') | |||||
| parser_gk.add_argument('fname', type=str, help='file name for the key') | |||||
| parser_gk.set_defaults(func=cmd_genkey) | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| fun = args.func | |||||
| except AttributeError: | |||||
| parser.print_usage() | |||||
| sys.exit(5) | |||||
| fun(args) | |||||
| if __name__ == '__main__': # pragma: no cover | |||||
| main() | |||||
| def _asyncsockpair(): | def _asyncsockpair(): | ||||
| '''Create a pair of sockets that are bound to each other. | '''Create a pair of sockets that are bound to each other. | ||||
| The function will return a tuple of two coroutine's, that | The function will return a tuple of two coroutine's, that | ||||
| @@ -241,7 +293,64 @@ def _asyncsockpair(): | |||||
| return asyncio.open_connection(sock=socka), \ | return asyncio.open_connection(sock=socka), \ | ||||
| asyncio.open_connection(sock=sockb) | asyncio.open_connection(sock=sockb) | ||||
| class Tests(unittest.TestCase): | |||||
| class TestMain(unittest.TestCase): | |||||
| def setUp(self): | |||||
| # setup temporary directory | |||||
| d = os.path.realpath(tempfile.mkdtemp()) | |||||
| self.basetempdir = d | |||||
| self.tempdir = os.path.join(d, 'subdir') | |||||
| os.mkdir(self.tempdir) | |||||
| # Generate key pairs | |||||
| self.server_key_pair = genkeypair() | |||||
| self.client_key_pair = genkeypair() | |||||
| os.chdir(self.tempdir) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.basetempdir) | |||||
| self.tempdir = None | |||||
| def test_noargs(self): | |||||
| sys.argv = [ 'prog' ] | |||||
| with self.assertRaises(SystemExit) as cm: | |||||
| main() | |||||
| # XXX - not checking error message | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(5, cm.exception.code) | |||||
| def test_genkey(self): | |||||
| # that it can generate a key | |||||
| sys.argv = [ 'prog', 'genkey', 'somefile' ] | |||||
| main() | |||||
| with open('somefile.pub', encoding='ascii') as fp: | |||||
| lines = fp.readlines() | |||||
| self.assertEqual(len(lines), 1) | |||||
| keytype, keyvalue = lines[0].split() | |||||
| self.assertEqual(keytype, 'ntun-x448') | |||||
| key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) | |||||
| with open('somefile', encoding='ascii') as fp: | |||||
| data = fp.read().encode('ascii') | |||||
| key = load_pem_private_key(data, password=None, backend=default_backend()) | |||||
| self.assertIsInstance(key, x448.X448PrivateKey) | |||||
| # that a second call fails | |||||
| with self.assertRaises(SystemExit) as cm: | |||||
| main() | |||||
| # XXX - not checking error message | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(1, cm.exception.code) | |||||
| class TestNoiseFowarder(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| # setup temporary directory | # setup temporary directory | ||||
| d = os.path.realpath(tempfile.mkdtemp()) | d = os.path.realpath(tempfile.mkdtemp()) | ||||