diff --git a/ntunnel.py b/ntunnel.py index 3736243..29df1fe 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -1,15 +1,21 @@ -from noise.connection import NoiseConnection, Keypair -from cryptography.hazmat.primitives.kdf.hkdf import HKDF -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives import hashes from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.asymmetric import x448 +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import x448 +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from noise.connection import NoiseConnection, Keypair + +import tracemalloc; tracemalloc.start() +import argparse import asyncio +import base64 import os.path import shutil import socket +import sys import tempfile import threading import unittest @@ -231,6 +237,52 @@ class Tests_misc(unittest.TestCase): msg = os.urandom(i) self.assertEqual(len(msg), dec(enc(msg) + msg)) +def cmd_genkey(args): + keypair = genkeypair() + + key = x448.X448PrivateKey.generate() + + # public key part + enc = serialization.Encoding.Raw + pubformat = serialization.PublicFormat.Raw + pub = key.public_key().public_bytes(encoding=enc, format=pubformat) + + try: + fname = args.fname + '.pub' + with open(fname, 'x', encoding='ascii') as fp: + print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp) + except FileExistsError: + print('failed to create %s, file exists.' % fname, file=sys.stderr) + sys.exit(1) + + enc = serialization.Encoding.PEM + format = serialization.PrivateFormat.PKCS8 + encalgo = serialization.NoEncryption() + with open(args.fname, 'x', encoding='ascii') as fp: + fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii')) + +def main(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help') + + parser_gk = subparsers.add_parser('genkey', help='generate keys') + parser_gk.add_argument('fname', type=str, help='file name for the key') + parser_gk.set_defaults(func=cmd_genkey) + + args = parser.parse_args() + + try: + fun = args.func + except AttributeError: + parser.print_usage() + sys.exit(5) + + fun(args) + +if __name__ == '__main__': # pragma: no cover + main() + def _asyncsockpair(): '''Create a pair of sockets that are bound to each other. The function will return a tuple of two coroutine's, that @@ -241,7 +293,64 @@ def _asyncsockpair(): return asyncio.open_connection(sock=socka), \ asyncio.open_connection(sock=sockb) -class Tests(unittest.TestCase): +class TestMain(unittest.TestCase): + def setUp(self): + # setup temporary directory + d = os.path.realpath(tempfile.mkdtemp()) + self.basetempdir = d + self.tempdir = os.path.join(d, 'subdir') + os.mkdir(self.tempdir) + + # Generate key pairs + self.server_key_pair = genkeypair() + self.client_key_pair = genkeypair() + + os.chdir(self.tempdir) + + def tearDown(self): + shutil.rmtree(self.basetempdir) + self.tempdir = None + + def test_noargs(self): + sys.argv = [ 'prog' ] + + with self.assertRaises(SystemExit) as cm: + main() + + # XXX - not checking error message + + # And that it exited w/ the correct code + self.assertEqual(5, cm.exception.code) + + def test_genkey(self): + # that it can generate a key + sys.argv = [ 'prog', 'genkey', 'somefile' ] + main() + + with open('somefile.pub', encoding='ascii') as fp: + lines = fp.readlines() + self.assertEqual(len(lines), 1) + + keytype, keyvalue = lines[0].split() + + self.assertEqual(keytype, 'ntun-x448') + key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) + + with open('somefile', encoding='ascii') as fp: + data = fp.read().encode('ascii') + key = load_pem_private_key(data, password=None, backend=default_backend()) + self.assertIsInstance(key, x448.X448PrivateKey) + + # that a second call fails + with self.assertRaises(SystemExit) as cm: + main() + + # XXX - not checking error message + + # And that it exited w/ the correct code + self.assertEqual(1, cm.exception.code) + +class TestNoiseFowarder(unittest.TestCase): def setUp(self): # setup temporary directory d = os.path.realpath(tempfile.mkdtemp())