A stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

965 lines
27 KiB

  1. from cryptography.hazmat.backends import default_backend
  2. from cryptography.hazmat.primitives import hashes
  3. from cryptography.hazmat.primitives import serialization
  4. from cryptography.hazmat.primitives.asymmetric import x448
  5. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  6. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  7. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  8. from noise.connection import NoiseConnection, Keypair
  9. #import tracemalloc; tracemalloc.start()
  10. import argparse
  11. import asyncio
  12. import base64
  13. import os.path
  14. import shutil
  15. import socket
  16. import sys
  17. import tempfile
  18. import time
  19. import threading
  20. import unittest
# Shared cryptography backend object, passed to the HKDF and Cipher
# constructors in _genciphfun.
_backend = default_backend()
  22. def loadprivkey(fname):
  23. with open(fname, encoding='ascii') as fp:
  24. data = fp.read().encode('ascii')
  25. key = load_pem_private_key(data, password=None, backend=default_backend())
  26. return key
  27. def loadprivkeyraw(fname):
  28. key = loadprivkey(fname)
  29. enc = serialization.Encoding.Raw
  30. privformat = serialization.PrivateFormat.Raw
  31. encalgo = serialization.NoEncryption()
  32. return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  33. def loadpubkeyraw(fname):
  34. with open(fname, encoding='ascii') as fp:
  35. lines = fp.readlines()
  36. # XXX
  37. #self.assertEqual(len(lines), 1)
  38. keytype, keyvalue = lines[0].split()
  39. if keytype != 'ntun-x448':
  40. raise RuntimeError
  41. return base64.urlsafe_b64decode(keyvalue)
  42. def genkeypair():
  43. '''Generates a keypair, and returns a tuple of (public, private).
  44. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  45. key = x448.X448PrivateKey.generate()
  46. enc = serialization.Encoding.Raw
  47. pubformat = serialization.PublicFormat.Raw
  48. privformat = serialization.PrivateFormat.Raw
  49. encalgo = serialization.NoEncryption()
  50. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  51. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  52. return pub, priv
  53. def _makefut(obj):
  54. loop = asyncio.get_running_loop()
  55. fut = loop.create_future()
  56. fut.set_result(obj)
  57. return fut
  58. def _makeunix(path):
  59. '''Make a properly formed unix path socket string.'''
  60. return 'unix:%s' % path
  61. def _parsesockstr(sockstr):
  62. proto, rem = sockstr.split(':', 1)
  63. return proto, rem
  64. async def connectsockstr(sockstr):
  65. proto, rem = _parsesockstr(sockstr)
  66. reader, writer = await asyncio.open_unix_connection(rem)
  67. return reader, writer
  68. async def listensockstr(sockstr, cb):
  69. '''Wrapper for asyncio.start_x_server.
  70. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  71. If the proto has a default parameter, the value can be used
  72. directly, like: 'proto:value'. This is only allowed when the
  73. value can unambiguously be determined not to be a param.
  74. The characters that define 'param' must be all lower case ascii
  75. characters and may contain an underscore. The first character
  76. must not be and underscore.
  77. Supported protocols:
  78. unix:
  79. Default parameter is path.
  80. The path parameter specifies the path to the
  81. unix domain socket. The path MUST start w/ a
  82. slash if it is used as a default parameter.
  83. '''
  84. proto, rem = _parsesockstr(sockstr)
  85. server = await asyncio.start_unix_server(cb, path=rem)
  86. return server
# !!python makemessagelengths.py
# Sizes in bytes of the three Noise_XK handshake messages, in order.
# NoiseForwarder reads exactly this many bytes for each incoming
# handshake message: the responder reads [0] and [2], the initiator
# reads [1]. They must match the protocol name used there
# (Noise_XK_448_ChaChaPoly_SHA256). The directive above suggests they
# are generated by makemessagelengths.py -- TODO confirm before changing
# the cipher suite.
_handshakelens = \
    [72, 72, 88]
  90. def _genciphfun(hash, ad):
  91. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  92. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  93. key = hkdf.derive(hash)
  94. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  95. backend=_backend)
  96. enctor = cipher.encryptor()
  97. def encfun(data):
  98. # Returns the two bytes for length
  99. val = len(data)
  100. encbytes = enctor.update(data[:16])
  101. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  102. return (val ^ mask).to_bytes(length=2, byteorder='big')
  103. def decfun(data):
  104. # takes off the data and returns the total
  105. # length
  106. val = int.from_bytes(data[:2], byteorder='big')
  107. encbytes = enctor.update(data[2:2 + 16])
  108. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  109. return val ^ mask
  110. return encfun, decfun
  111. async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
  112. '''A function that forwards data between the plain text pair of
  113. streams to the encrypted session.
  114. The mode paramater must be one of 'init' or 'resp' for initiator
  115. and responder.
  116. The encrdrwrr is an await object that will return a tunle of the
  117. reader and writer streams for the encrypted side of the
  118. connection.
  119. The ptpairfun parameter is a function that will be passed the
  120. public key bytes for the remote client. This can be used to
  121. both validate that the correct client is connecting, and to
  122. pass back the correct plain text reader/writer objects that
  123. match the provided static key. The function must be an async
  124. function.
  125. In the case of the initiator, pub_key must be provided and will
  126. be used to authenticate the responder side of the connection.
  127. The priv_key parameter is used to authenticate this side of the
  128. session.
  129. Both priv_key and pub_key parameters must be 56 bytes. For example,
  130. the pair that is returned by genkeypair.
  131. '''
  132. rdr, wrr = await encrdrwrr
  133. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  134. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  135. if pub_key is not None:
  136. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  137. pub_key)
  138. if mode == 'resp':
  139. proto.set_as_responder()
  140. proto.start_handshake()
  141. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  142. wrr.write(proto.write_message())
  143. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  144. elif mode == 'init':
  145. proto.set_as_initiator()
  146. proto.start_handshake()
  147. wrr.write(proto.write_message())
  148. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  149. wrr.write(proto.write_message())
  150. if not proto.handshake_finished: # pragma: no cover
  151. raise RuntimeError('failed to finish handshake')
  152. reader, writer = await ptpairfun(getattr(proto.get_keypair(
  153. Keypair.REMOTE_STATIC), 'public_bytes', None))
  154. # generate the keys for lengths
  155. # XXX - get_handshake_hash is probably not the best option, but
  156. # this is only to obscure lengths, it is not required to be secure
  157. # as the underlying NoiseProtocol securely validates everything.
  158. # It is marginally useful as writing patterns likely expose the
  159. # true length. Adding padding could marginally help w/ this.
  160. if mode == 'resp':
  161. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  162. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  163. elif mode == 'init':
  164. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  165. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  166. async def decses():
  167. try:
  168. while True:
  169. try:
  170. msg = await rdr.readexactly(2 + 16)
  171. except asyncio.streams.IncompleteReadError:
  172. if rdr.at_eof():
  173. return 'dec'
  174. tlen = declenfun(msg)
  175. rmsg = await rdr.readexactly(tlen - 16)
  176. tmsg = msg[2:] + rmsg
  177. writer.write(proto.decrypt(tmsg))
  178. await writer.drain()
  179. #except:
  180. # import traceback
  181. # traceback.print_exc()
  182. # raise
  183. finally:
  184. try:
  185. writer.write_eof()
  186. except OSError as e:
  187. if e.errno != 57:
  188. raise
  189. async def encses():
  190. try:
  191. while True:
  192. # largest message
  193. ptmsg = await reader.read(65535 - 16)
  194. if not ptmsg:
  195. # eof
  196. return 'enc'
  197. encmsg = proto.encrypt(ptmsg)
  198. wrr.write(enclenfun(encmsg))
  199. wrr.write(encmsg)
  200. await wrr.drain()
  201. #except:
  202. # import traceback
  203. # traceback.print_exc()
  204. # raise
  205. finally:
  206. wrr.write_eof()
  207. return await asyncio.gather(decses(), encses())
  208. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  209. # Slightly modified to timeout and to print trace back when canceled.
  210. # This makes it easier to figure out what "froze".
  211. def async_test(f):
  212. def wrapper(*args, **kwargs):
  213. async def tbcapture():
  214. try:
  215. return await f(*args, **kwargs)
  216. except asyncio.CancelledError as e:
  217. # if we are going to be cancelled, print out a tb
  218. import traceback
  219. traceback.print_exc()
  220. raise
  221. loop = asyncio.get_event_loop()
  222. # timeout after 4 seconds
  223. loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
  224. return wrapper
  225. class Tests_misc(unittest.TestCase):
  226. def test_listensockstr(self):
  227. # XXX write test
  228. pass
  229. def test_genciphfun(self):
  230. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  231. msg = b'this is a bunch of data'
  232. tb = enc(msg)
  233. self.assertEqual(len(msg), dec(tb + msg))
  234. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  235. msg = os.urandom(i)
  236. self.assertEqual(len(msg), dec(enc(msg) + msg))
  237. def cmd_client(args):
  238. privkey = loadprivkeyraw(args.clientkey)
  239. pubkey = loadpubkeyraw(args.servkey)
  240. async def runnf(rdr, wrr):
  241. encpair = asyncio.create_task(connectsockstr(args.clienttarget))
  242. a = await NoiseForwarder('init',
  243. encpair, lambda x: _makefut((rdr, wrr)),
  244. priv_key=privkey, pub_key=pubkey)
  245. # Setup client listener
  246. ssock = listensockstr(args.clientlisten, runnf)
  247. loop = asyncio.get_event_loop()
  248. obj = loop.run_until_complete(ssock)
  249. loop.run_until_complete(obj.serve_forever())
  250. def cmd_server(args):
  251. privkey = loadprivkeyraw(args.servkey)
  252. pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]
  253. async def runnf(rdr, wrr):
  254. async def checkclientfun(clientkey):
  255. if clientkey not in pubkeys:
  256. raise RuntimeError('invalid key provided')
  257. return await connectsockstr(args.servtarget)
  258. a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
  259. checkclientfun, priv_key=privkey)
  260. # Setup server listener
  261. ssock = listensockstr(args.servlisten, runnf)
  262. loop = asyncio.get_event_loop()
  263. obj = loop.run_until_complete(ssock)
  264. loop.run_until_complete(obj.serve_forever())
  265. def cmd_genkey(args):
  266. keypair = genkeypair()
  267. key = x448.X448PrivateKey.generate()
  268. # public key part
  269. enc = serialization.Encoding.Raw
  270. pubformat = serialization.PublicFormat.Raw
  271. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  272. try:
  273. fname = args.fname + '.pub'
  274. with open(fname, 'x', encoding='ascii') as fp:
  275. print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
  276. except FileExistsError:
  277. print('failed to create %s, file exists.' % fname, file=sys.stderr)
  278. sys.exit(2)
  279. enc = serialization.Encoding.PEM
  280. format = serialization.PrivateFormat.PKCS8
  281. encalgo = serialization.NoEncryption()
  282. with open(args.fname, 'x', encoding='ascii') as fp:
  283. fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii'))
  284. def main():
  285. parser = argparse.ArgumentParser()
  286. subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')
  287. parser_gk = subparsers.add_parser('genkey', help='generate keys')
  288. parser_gk.add_argument('fname', type=str, help='file name for the key')
  289. parser_gk.set_defaults(func=cmd_genkey)
  290. parser_serv = subparsers.add_parser('server', help='run a server')
  291. parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
  292. parser_serv.add_argument('servkey', type=str, help='file name for the server key')
  293. parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
  294. parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
  295. parser_serv.set_defaults(func=cmd_server)
  296. parser_client = subparsers.add_parser('client', help='run a client')
  297. parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
  298. parser_client.add_argument('servkey', type=str, help='file name for the server public key')
  299. parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
  300. parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
  301. parser_client.set_defaults(func=cmd_client)
  302. args = parser.parse_args()
  303. try:
  304. fun = args.func
  305. except AttributeError:
  306. parser.print_usage()
  307. sys.exit(5)
  308. fun(args)
# Script entry point. It sits above the test helpers and test classes
# further down, so those are not yet defined while main() runs.
if __name__ == '__main__': # pragma: no cover
    main()
  311. def _asyncsockpair():
  312. '''Create a pair of sockets that are bound to each other.
  313. The function will return a tuple of two coroutine's, that
  314. each, when await'ed upon, will return the reader/writer pair.'''
  315. socka, sockb = socket.socketpair()
  316. return asyncio.open_connection(sock=socka), \
  317. asyncio.open_connection(sock=sockb)
  318. async def _awaitfile(fname):
  319. while not os.path.exists(fname):
  320. await asyncio.sleep(.01)
  321. return True
  322. class TestMain(unittest.TestCase):
  323. def setUp(self):
  324. # setup temporary directory
  325. d = os.path.realpath(tempfile.mkdtemp())
  326. self.basetempdir = d
  327. self.tempdir = os.path.join(d, 'subdir')
  328. os.mkdir(self.tempdir)
  329. # Generate key pairs
  330. self.server_key_pair = genkeypair()
  331. self.client_key_pair = genkeypair()
  332. os.chdir(self.tempdir)
  333. def tearDown(self):
  334. #print('td:', time.time())
  335. shutil.rmtree(self.basetempdir)
  336. self.tempdir = None
  337. @async_test
  338. async def test_noargs(self):
  339. proc = await self.run_with_args()
  340. await proc.wait()
  341. # XXX - not checking error message
  342. # And that it exited w/ the correct code
  343. self.assertEqual(proc.returncode, 5)
  344. def run_with_args(self, *args, pipes=True):
  345. kwargs = {}
  346. if pipes:
  347. kwargs.update(dict(
  348. stdout=asyncio.subprocess.PIPE,
  349. stderr=asyncio.subprocess.PIPE))
  350. return asyncio.create_subprocess_exec(sys.executable,
  351. # XXX - figure out how to add coverage data on these runs
  352. #'-m', 'coverage', 'run', '-p',
  353. __file__, *args, **kwargs)
  354. async def genkey(self, name):
  355. proc = await self.run_with_args('genkey', name, pipes=False)
  356. await proc.wait()
  357. self.assertEqual(proc.returncode, 0)
  358. @async_test
  359. async def test_loadpubkey(self):
  360. keypath = os.path.join(self.tempdir, 'loadpubkeytest')
  361. await self.genkey(keypath)
  362. privkey = loadprivkey(keypath)
  363. enc = serialization.Encoding.Raw
  364. pubformat = serialization.PublicFormat.Raw
  365. pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)
  366. pubkey = loadpubkeyraw(keypath + '.pub')
  367. self.assertEqual(pubkeybytes, pubkey)
  368. privrawkey = loadprivkeyraw(keypath)
  369. enc = serialization.Encoding.Raw
  370. privformat = serialization.PrivateFormat.Raw
  371. encalgo = serialization.NoEncryption()
  372. rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  373. self.assertEqual(rprivrawkey, privrawkey)
  374. @async_test
  375. async def test_clientkeymismatch(self):
  376. # make sure that if there's a client key mismatch, we
  377. # don't connect
  378. # Generate necessar keys
  379. servkeypath = os.path.join(self.tempdir, 'server_key')
  380. await self.genkey(servkeypath)
  381. clientkeypath = os.path.join(self.tempdir, 'client_key')
  382. await self.genkey(clientkeypath)
  383. badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
  384. await self.genkey(badclientkeypath)
  385. # forwards connectsion to this socket (created by client)
  386. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  387. ptclientstr = _makeunix(ptclientpath)
  388. # this is the socket server listen to
  389. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  390. incservstr = _makeunix(incservpath)
  391. # to this socket, opened by server
  392. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  393. servtargstr = _makeunix(servtargpath)
  394. # Setup server target listener
  395. ptsock = []
  396. ptsockevent = asyncio.Event()
  397. def ptsockaccept(reader, writer, ptsock=ptsock):
  398. ptsock.append((reader, writer))
  399. ptsockevent.set()
  400. # Bind to pt listener
  401. lsock = await listensockstr(servtargstr, ptsockaccept)
  402. # Startup the server
  403. server = await self.run_with_args('server',
  404. '-c', clientkeypath + '.pub',
  405. servkeypath, incservstr, servtargstr)
  406. # Startup the client with the "bad" key
  407. client = await self.run_with_args('client',
  408. badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)
  409. # wait for server target to be created
  410. await _awaitfile(servtargpath)
  411. # wait for server to start
  412. await _awaitfile(incservpath)
  413. # wait for client to start
  414. await _awaitfile(ptclientpath)
  415. # Connect to the client
  416. reader, writer = await connectsockstr(ptclientstr)
  417. with self.assertRaises(asyncio.futures.TimeoutError):
  418. # make sure that we don't get the conenction
  419. await asyncio.wait_for(ptsockevent.wait(), 1)
  420. # Make sure that when the server is terminated
  421. server.terminate()
  422. # that it's stderr
  423. stdout, stderr = await server.communicate()
  424. #print('s:', repr((stdout, stderr)))
  425. # doesn't have an exceptions never retrieved
  426. # even the example echo server has this same leak
  427. #self.assertNotIn(b'Task exception was never retrieved', stderr)
  428. @async_test
  429. async def test_end2end(self):
  430. # Generate necessar keys
  431. servkeypath = os.path.join(self.tempdir, 'server_key')
  432. await self.genkey(servkeypath)
  433. clientkeypath = os.path.join(self.tempdir, 'client_key')
  434. await self.genkey(clientkeypath)
  435. # forwards connectsion to this socket (created by client)
  436. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  437. ptclientstr = _makeunix(ptclientpath)
  438. # this is the socket server listen to
  439. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  440. incservstr = _makeunix(incservpath)
  441. # to this socket, opened by server
  442. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  443. servtargstr = _makeunix(servtargpath)
  444. # Setup server target listener
  445. ptsock = []
  446. ptsockevent = asyncio.Event()
  447. def ptsockaccept(reader, writer, ptsock=ptsock):
  448. ptsock.append((reader, writer))
  449. ptsockevent.set()
  450. # Bind to pt listener
  451. lsock = await listensockstr(servtargstr, ptsockaccept)
  452. # Startup the server
  453. server = await self.run_with_args('server',
  454. '-c', clientkeypath + '.pub',
  455. servkeypath, incservstr, servtargstr,
  456. pipes=False)
  457. # Startup the client
  458. client = await self.run_with_args('client',
  459. clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
  460. pipes=False)
  461. # wait for server target to be created
  462. await _awaitfile(servtargpath)
  463. # wait for server to start
  464. await _awaitfile(incservpath)
  465. # wait for client to start
  466. await _awaitfile(ptclientpath)
  467. # Connect to the client
  468. reader, writer = await connectsockstr(ptclientstr)
  469. # send a message
  470. ptmsg = b'this is a message for testing'
  471. writer.write(ptmsg)
  472. # make sure that we got the conenction
  473. await ptsockevent.wait()
  474. # get the connection
  475. endrdr, endwrr = ptsock[0]
  476. # make sure we can read back what we sent
  477. self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))
  478. # test some additional messages
  479. for i in [ 129, 1287, 28792, 129872 ]:
  480. # in on direction
  481. msg = os.urandom(i)
  482. writer.write(msg)
  483. self.assertEqual(msg, await endrdr.readexactly(len(msg)))
  484. # and the other
  485. endwrr.write(msg)
  486. self.assertEqual(msg, await reader.readexactly(len(msg)))
  487. @async_test
  488. async def test_genkey(self):
  489. # that it can generate a key
  490. proc = await self.run_with_args('genkey', 'somefile')
  491. await proc.wait()
  492. #print(await proc.communicate())
  493. self.assertEqual(proc.returncode, 0)
  494. with open('somefile.pub', encoding='ascii') as fp:
  495. lines = fp.readlines()
  496. self.assertEqual(len(lines), 1)
  497. keytype, keyvalue = lines[0].split()
  498. self.assertEqual(keytype, 'ntun-x448')
  499. key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))
  500. key = loadprivkey('somefile')
  501. self.assertIsInstance(key, x448.X448PrivateKey)
  502. # that a second call fails
  503. proc = await self.run_with_args('genkey', 'somefile')
  504. await proc.wait()
  505. stdoutdata, stderrdata = await proc.communicate()
  506. self.assertFalse(stdoutdata)
  507. self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)
  508. # And that it exited w/ the correct code
  509. self.assertEqual(proc.returncode, 2)
class TestNoiseFowarder(unittest.TestCase):
    '''Tests that drive NoiseForwarder directly, in process, over socket
    pairs and unix sockets, with the test acting as the Noise peer.'''

    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)
        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_clientkeymissmatch(self):
        '''An exception raised by ptpairfun (here: key rejected) must
        propagate out of the responder NoiseForwarder.'''
        # generate a key that is incorrect
        wrongclient_key_pair = genkeypair()
        # the secure socket
        clssockapair, clssockbpair = _asyncsockpair()
        reader, writer = await clssockapair
        async def wrongkey(v):
            raise ValueError('no key matches')
        # create the server
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, wrongkey,
            priv_key=self.server_key_pair[1]))
        # Create client
        proto = NoiseConnection.from_name(
            b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()
        # Setup wrong client key
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            wrongclient_key_pair[1])
        # but the correct server key
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])
        proto.start_handshake()
        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)
        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)
        # Send final reply
        message = proto.write_message()
        writer.write(message)
        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)
        # the handshake itself succeeds; the key check in ptpairfun
        # is what rejects the client
        with self.assertRaises(ValueError):
            await servnf

    @async_test
    async def test_server(self):
        '''Run a responder NoiseForwarder behind a unix socket and speak
        the framed protocol to it by hand, in both directions, then
        check EOF handling and the forwarder's return value.'''
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        # (rdr, wrr) NoiseForward (reader, writer) ->
        # servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)
        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')
        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()
        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)
        # results of NoiseForwarder, and an event set once it returns
        nfs = []
        event = asyncio.Event()
        async def runnf(rdr, wrr):
            ptpairfun = asyncio.create_task(connectsockstr(pttarg))
            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), lambda x: ptpairfun,
                priv_key=self.server_key_pair[1])
            nfs.append(a)
            event.set()
        # Setup server listener
        ssock = await listensockstr(servarg, runnf)
        # Connect to server
        reader, writer = await connectsockstr(servarg)
        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()
        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])
        proto.start_handshake()
        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)
        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)
        # Send final reply
        message = proto.write_message()
        writer.write(message)
        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)
        # generate the keys for lengths (same labels NoiseForwarder uses
        # on the initiator side)
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
            b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(),
            b'toinit')
        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)
        # wait for the connection to arrive
        await ptsockevent.wait()
        ptreader, ptwriter = ptsock[0]
        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))
        self.assertEqual(rptmsg, ptmsg)
        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)
        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))
        self.assertEqual(rptmsg, ptmsg)
        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)
        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)
        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)
        self.assertEqual(rptmsg, ptmsg)
        # shut down sending
        writer.write_eof()
        # so pt reader should be shut down
        self.assertEqual(b'', await ptreader.read(1))
        self.assertTrue(ptreader.at_eof())
        # shut down pt
        ptwriter.write_eof()
        # make sure the enc reader is eof
        self.assertEqual(b'', await reader.read(1))
        self.assertTrue(reader.at_eof())
        await event.wait()
        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

    @async_test
    async def test_serverclient(self):
        '''Connect an initiator and a responder NoiseForwarder back to
        back and check data flows through both, in both directions.'''
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #
        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair
        async def validateclientkey(pubkey):
            if pubkey != self.client_key_pair[0]:
                raise ValueError('invalid key')
            return await ptssockapair
        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, lambda x: ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, validateclientkey,
            priv_key=self.server_key_pair[1]))
        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)
        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)
        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)
        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)
        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()
        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())
        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())
        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)