A stunnel-like program that utilizes the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

811 lines
22 KiB

  1. from cryptography.hazmat.backends import default_backend
  2. from cryptography.hazmat.primitives import hashes
  3. from cryptography.hazmat.primitives import serialization
  4. from cryptography.hazmat.primitives.asymmetric import x448
  5. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  6. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  7. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  8. from noise.connection import NoiseConnection, Keypair
  9. #import tracemalloc; tracemalloc.start()
  10. import argparse
  11. import asyncio
  12. import base64
  13. import os.path
  14. import shutil
  15. import socket
  16. import sys
  17. import tempfile
  18. import time
  19. import threading
  20. import unittest
# Cryptography backend object shared by the HKDF and AES Cipher instances
# that _genciphfun builds.
_backend = default_backend()
  22. def loadprivkey(fname):
  23. with open(fname, encoding='ascii') as fp:
  24. data = fp.read().encode('ascii')
  25. key = load_pem_private_key(data, password=None, backend=default_backend())
  26. return key
  27. def loadprivkeyraw(fname):
  28. key = loadprivkey(fname)
  29. enc = serialization.Encoding.Raw
  30. privformat = serialization.PrivateFormat.Raw
  31. encalgo = serialization.NoEncryption()
  32. return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  33. def loadpubkeyraw(fname):
  34. with open(fname, encoding='ascii') as fp:
  35. lines = fp.readlines()
  36. # XXX
  37. #self.assertEqual(len(lines), 1)
  38. keytype, keyvalue = lines[0].split()
  39. if keytype != 'ntun-x448':
  40. raise RuntimeError
  41. return base64.urlsafe_b64decode(keyvalue)
  42. def genkeypair():
  43. '''Generates a keypair, and returns a tuple of (public, private).
  44. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  45. key = x448.X448PrivateKey.generate()
  46. enc = serialization.Encoding.Raw
  47. pubformat = serialization.PublicFormat.Raw
  48. privformat = serialization.PrivateFormat.Raw
  49. encalgo = serialization.NoEncryption()
  50. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  51. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  52. return pub, priv
  53. def _makefut(obj):
  54. loop = asyncio.get_running_loop()
  55. fut = loop.create_future()
  56. fut.set_result(obj)
  57. return fut
  58. def _makeunix(path):
  59. '''Make a properly formed unix path socket string.'''
  60. return 'unix:%s' % path
  61. def _parsesockstr(sockstr):
  62. proto, rem = sockstr.split(':', 1)
  63. return proto, rem
  64. async def connectsockstr(sockstr):
  65. proto, rem = _parsesockstr(sockstr)
  66. reader, writer = await asyncio.open_unix_connection(rem)
  67. return reader, writer
  68. async def listensockstr(sockstr, cb):
  69. '''Wrapper for asyncio.start_x_server.
  70. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  71. If the proto has a default parameter, the value can be used
  72. directly, like: 'proto:value'. This is only allowed when the
  73. value can unambiguously be determined not to be a param.
  74. The characters that define 'param' must be all lower case ascii
  75. characters and may contain an underscore. The first character
  76. must not be and underscore.
  77. Supported protocols:
  78. unix:
  79. Default parameter is path.
  80. The path parameter specifies the path to the
  81. unix domain socket. The path MUST start w/ a
  82. slash if it is used as a default parameter.
  83. '''
  84. proto, rem = _parsesockstr(sockstr)
  85. server = await asyncio.start_unix_server(cb, path=rem)
  86. return server
# !!python makemessagelengths.py
# Byte lengths of the three handshake messages, in the order they are
# exchanged:
#   [0] initiator -> responder
#   [1] responder -> initiator
#   [2] initiator -> responder
# NoiseForwarder passes them to readexactly() so each handshake message is
# received whole.  The directive above presumably names the script that
# regenerates these values.  They presumably depend on the protocol name
# given to NoiseConnection.from_name; recompute if that changes.
_handshakelens = \
[72, 72, 88]
  90. def _genciphfun(hash, ad):
  91. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  92. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  93. key = hkdf.derive(hash)
  94. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  95. backend=_backend)
  96. enctor = cipher.encryptor()
  97. def encfun(data):
  98. # Returns the two bytes for length
  99. val = len(data)
  100. encbytes = enctor.update(data[:16])
  101. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  102. return (val ^ mask).to_bytes(length=2, byteorder='big')
  103. def decfun(data):
  104. # takes off the data and returns the total
  105. # length
  106. val = int.from_bytes(data[:2], byteorder='big')
  107. encbytes = enctor.update(data[2:2 + 16])
  108. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  109. return val ^ mask
  110. return encfun, decfun
  111. async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
  112. rdr, wrr = await rdrwrr
  113. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  114. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  115. if pub_key is not None:
  116. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  117. pub_key)
  118. if mode == 'resp':
  119. proto.set_as_responder()
  120. proto.start_handshake()
  121. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  122. wrr.write(proto.write_message())
  123. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  124. elif mode == 'init':
  125. proto.set_as_initiator()
  126. proto.start_handshake()
  127. wrr.write(proto.write_message())
  128. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  129. wrr.write(proto.write_message())
  130. if not proto.handshake_finished: # pragma: no cover
  131. raise RuntimeError('failed to finish handshake')
  132. # generate the keys for lengths
  133. # XXX - get_handshake_hash is probably not the best option, but
  134. # this is only to obscure lengths, it is not required to be secure
  135. # as the underlying NoiseProtocol securely validates everything.
  136. # It is marginally useful as writing patterns likely expose the
  137. # true length. Adding padding could marginally help w/ this.
  138. if mode == 'resp':
  139. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  140. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  141. elif mode == 'init':
  142. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  143. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  144. reader, writer = await ptpair
  145. async def decses():
  146. try:
  147. while True:
  148. try:
  149. msg = await rdr.readexactly(2 + 16)
  150. except asyncio.streams.IncompleteReadError:
  151. if rdr.at_eof():
  152. return 'dec'
  153. tlen = declenfun(msg)
  154. rmsg = await rdr.readexactly(tlen - 16)
  155. tmsg = msg[2:] + rmsg
  156. writer.write(proto.decrypt(tmsg))
  157. await writer.drain()
  158. #except:
  159. # import traceback
  160. # traceback.print_exc()
  161. # raise
  162. finally:
  163. try:
  164. writer.write_eof()
  165. except OSError as e:
  166. if e.errno != 57:
  167. raise
  168. async def encses():
  169. try:
  170. while True:
  171. # largest message
  172. ptmsg = await reader.read(65535 - 16)
  173. if not ptmsg:
  174. # eof
  175. return 'enc'
  176. encmsg = proto.encrypt(ptmsg)
  177. wrr.write(enclenfun(encmsg))
  178. wrr.write(encmsg)
  179. await wrr.drain()
  180. #except:
  181. # import traceback
  182. # traceback.print_exc()
  183. # raise
  184. finally:
  185. wrr.write_eof()
  186. return await asyncio.gather(decses(), encses())
  187. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  188. # Slightly modified to timeout and to print trace back when canceled.
  189. # This makes it easier to figure out what "froze".
  190. def async_test(f):
  191. def wrapper(*args, **kwargs):
  192. async def tbcapture():
  193. try:
  194. return await f(*args, **kwargs)
  195. except asyncio.CancelledError as e:
  196. # if we are going to be cancelled, print out a tb
  197. import traceback
  198. traceback.print_exc()
  199. raise
  200. loop = asyncio.get_event_loop()
  201. # timeout after 4 seconds
  202. loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
  203. return wrapper
  204. class Tests_misc(unittest.TestCase):
  205. def test_listensockstr(self):
  206. # XXX write test
  207. pass
  208. def test_genciphfun(self):
  209. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  210. msg = b'this is a bunch of data'
  211. tb = enc(msg)
  212. self.assertEqual(len(msg), dec(tb + msg))
  213. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  214. msg = os.urandom(i)
  215. self.assertEqual(len(msg), dec(enc(msg) + msg))
  216. def cmd_client(args):
  217. privkey = loadprivkeyraw(args.clientkey)
  218. pubkey = loadpubkeyraw(args.servkey)
  219. async def runnf(rdr, wrr):
  220. encpair = asyncio.create_task(connectsockstr(args.clienttarget))
  221. a = await NoiseForwarder('init',
  222. encpair, _makefut((rdr, wrr)),
  223. priv_key=privkey, pub_key=pubkey)
  224. # Setup client listener
  225. ssock = listensockstr(args.clientlisten, runnf)
  226. loop = asyncio.get_event_loop()
  227. obj = loop.run_until_complete(ssock)
  228. loop.run_until_complete(obj.serve_forever())
  229. def cmd_server(args):
  230. privkey = loadprivkeyraw(args.servkey)
  231. async def runnf(rdr, wrr):
  232. ptpair = asyncio.create_task(connectsockstr(args.servtarget))
  233. a = await NoiseForwarder('resp',
  234. _makefut((rdr, wrr)), ptpair,
  235. priv_key=privkey)
  236. # Setup server listener
  237. ssock = listensockstr(args.servlisten, runnf)
  238. loop = asyncio.get_event_loop()
  239. obj = loop.run_until_complete(ssock)
  240. loop.run_until_complete(obj.serve_forever())
  241. def cmd_genkey(args):
  242. keypair = genkeypair()
  243. key = x448.X448PrivateKey.generate()
  244. # public key part
  245. enc = serialization.Encoding.Raw
  246. pubformat = serialization.PublicFormat.Raw
  247. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  248. try:
  249. fname = args.fname + '.pub'
  250. with open(fname, 'x', encoding='ascii') as fp:
  251. print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
  252. except FileExistsError:
  253. print('failed to create %s, file exists.' % fname, file=sys.stderr)
  254. sys.exit(2)
  255. enc = serialization.Encoding.PEM
  256. format = serialization.PrivateFormat.PKCS8
  257. encalgo = serialization.NoEncryption()
  258. with open(args.fname, 'x', encoding='ascii') as fp:
  259. fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii'))
  260. def main():
  261. parser = argparse.ArgumentParser()
  262. subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')
  263. parser_gk = subparsers.add_parser('genkey', help='generate keys')
  264. parser_gk.add_argument('fname', type=str, help='file name for the key')
  265. parser_gk.set_defaults(func=cmd_genkey)
  266. parser_serv = subparsers.add_parser('server', help='run a server')
  267. parser_serv.add_argument('-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
  268. parser_serv.add_argument('servkey', type=str, help='file name for the server key')
  269. parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
  270. parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
  271. parser_serv.set_defaults(func=cmd_server)
  272. parser_client = subparsers.add_parser('client', help='run a client')
  273. parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
  274. parser_client.add_argument('servkey', type=str, help='file name for the server public key')
  275. parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
  276. parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
  277. parser_client.set_defaults(func=cmd_client)
  278. args = parser.parse_args()
  279. try:
  280. fun = args.func
  281. except AttributeError:
  282. parser.print_usage()
  283. sys.exit(5)
  284. fun(args)
# Script entry point.  NOTE: this guard sits above the test helpers and
# test classes defined below.  When the file is run as a script, main()
# therefore runs before those exist; main() and the cmd_* functions do not
# depend on them.
if __name__ == '__main__': # pragma: no cover
    main()
  287. def _asyncsockpair():
  288. '''Create a pair of sockets that are bound to each other.
  289. The function will return a tuple of two coroutine's, that
  290. each, when await'ed upon, will return the reader/writer pair.'''
  291. socka, sockb = socket.socketpair()
  292. return asyncio.open_connection(sock=socka), \
  293. asyncio.open_connection(sock=sockb)
  294. async def _awaitfile(fname):
  295. while not os.path.exists(fname):
  296. await asyncio.sleep(.01)
  297. return True
  298. class TestMain(unittest.TestCase):
  299. def setUp(self):
  300. # setup temporary directory
  301. d = os.path.realpath(tempfile.mkdtemp())
  302. self.basetempdir = d
  303. self.tempdir = os.path.join(d, 'subdir')
  304. os.mkdir(self.tempdir)
  305. # Generate key pairs
  306. self.server_key_pair = genkeypair()
  307. self.client_key_pair = genkeypair()
  308. os.chdir(self.tempdir)
  309. def tearDown(self):
  310. #print('td:', time.time())
  311. shutil.rmtree(self.basetempdir)
  312. self.tempdir = None
  313. @async_test
  314. async def test_noargs(self):
  315. proc = await self.run_with_args()
  316. await proc.wait()
  317. # XXX - not checking error message
  318. # And that it exited w/ the correct code
  319. self.assertEqual(proc.returncode, 5)
  320. def run_with_args(self, *args, pipes=True):
  321. kwargs = {}
  322. if pipes:
  323. kwargs.update(dict(
  324. stdout=asyncio.subprocess.PIPE,
  325. stderr=asyncio.subprocess.PIPE))
  326. return asyncio.create_subprocess_exec(sys.executable,
  327. # XXX - figure out how to add coverage data on these runs
  328. #'-m', 'coverage', 'run', '-p',
  329. __file__, *args, **kwargs)
  330. async def genkey(self, name):
  331. proc = await self.run_with_args('genkey', name, pipes=False)
  332. await proc.wait()
  333. self.assertEqual(proc.returncode, 0)
  334. @async_test
  335. async def test_loadpubkey(self):
  336. keypath = os.path.join(self.tempdir, 'loadpubkeytest')
  337. await self.genkey(keypath)
  338. privkey = loadprivkey(keypath)
  339. enc = serialization.Encoding.Raw
  340. pubformat = serialization.PublicFormat.Raw
  341. pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)
  342. pubkey = loadpubkeyraw(keypath + '.pub')
  343. self.assertEqual(pubkeybytes, pubkey)
  344. privrawkey = loadprivkeyraw(keypath)
  345. enc = serialization.Encoding.Raw
  346. privformat = serialization.PrivateFormat.Raw
  347. encalgo = serialization.NoEncryption()
  348. rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  349. self.assertEqual(rprivrawkey, privrawkey)
  350. @async_test
  351. async def test_end2end(self):
  352. # Generate necessar keys
  353. servkeypath = os.path.join(self.tempdir, 'server_key')
  354. await self.genkey(servkeypath)
  355. clientkeypath = os.path.join(self.tempdir, 'client_key')
  356. await self.genkey(clientkeypath)
  357. await asyncio.sleep(.1)
  358. #import pdb; pdb.set_trace()
  359. # forwards connectsion to this socket (created by client)
  360. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  361. ptclientstr = _makeunix(ptclientpath)
  362. # this is the socket server listen to
  363. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  364. incservstr = _makeunix(incservpath)
  365. # to this socket, opened by server
  366. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  367. servtargstr = _makeunix(servtargpath)
  368. # Setup server target listener
  369. ptsock = []
  370. ptsockevent = asyncio.Event()
  371. def ptsockaccept(reader, writer, ptsock=ptsock):
  372. ptsock.append((reader, writer))
  373. ptsockevent.set()
  374. # Bind to pt listener
  375. lsock = await listensockstr(servtargstr, ptsockaccept)
  376. # Startup the server
  377. server = await self.run_with_args('server',
  378. '-c', clientkeypath + '.pub',
  379. servkeypath, incservstr, servtargstr,
  380. pipes=False)
  381. # Startup the client
  382. client = await self.run_with_args('client',
  383. clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
  384. pipes=False)
  385. # wait for server target to be created
  386. await _awaitfile(servtargpath)
  387. # wait for server to start
  388. await _awaitfile(incservpath)
  389. # wait for client to start
  390. await _awaitfile(ptclientpath)
  391. await asyncio.sleep(.1)
  392. # Connect to the client
  393. reader, writer = await connectsockstr(ptclientstr)
  394. # send a message
  395. ptmsg = b'this is a message for testing'
  396. writer.write(ptmsg)
  397. # make sure that we got the conenction
  398. await ptsockevent.wait()
  399. # get the connection
  400. endrdr, endwrr = ptsock[0]
  401. # make sure we can read back what we sent
  402. self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))
  403. # test some additional messages
  404. for i in [ 129, 1287, 28792, 129872 ]:
  405. # in on direction
  406. msg = os.urandom(i)
  407. writer.write(msg)
  408. self.assertEqual(msg, await endrdr.readexactly(len(msg)))
  409. # and the other
  410. endwrr.write(msg)
  411. self.assertEqual(msg, await reader.readexactly(len(msg)))
  412. @async_test
  413. async def test_genkey(self):
  414. # that it can generate a key
  415. proc = await self.run_with_args('genkey', 'somefile')
  416. await proc.wait()
  417. #print(await proc.communicate())
  418. self.assertEqual(proc.returncode, 0)
  419. with open('somefile.pub', encoding='ascii') as fp:
  420. lines = fp.readlines()
  421. self.assertEqual(len(lines), 1)
  422. keytype, keyvalue = lines[0].split()
  423. self.assertEqual(keytype, 'ntun-x448')
  424. key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))
  425. key = loadprivkey('somefile')
  426. self.assertIsInstance(key, x448.X448PrivateKey)
  427. # that a second call fails
  428. proc = await self.run_with_args('genkey', 'somefile')
  429. await proc.wait()
  430. stdoutdata, stderrdata = await proc.communicate()
  431. self.assertFalse(stdoutdata)
  432. self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)
  433. # And that it exited w/ the correct code
  434. self.assertEqual(proc.returncode, 2)
  435. class TestNoiseFowarder(unittest.TestCase):
  436. def setUp(self):
  437. # setup temporary directory
  438. d = os.path.realpath(tempfile.mkdtemp())
  439. self.basetempdir = d
  440. self.tempdir = os.path.join(d, 'subdir')
  441. os.mkdir(self.tempdir)
  442. # Generate key pairs
  443. self.server_key_pair = genkeypair()
  444. self.client_key_pair = genkeypair()
  445. def tearDown(self):
  446. shutil.rmtree(self.basetempdir)
  447. self.tempdir = None
  448. @async_test
  449. async def test_server(self):
  450. # Test is plumbed:
  451. # (reader, writer) -> servsock ->
  452. # (rdr, wrr) NoiseForward (reader, writer) ->
  453. # servptsock -> (ptsock[0], ptsock[1])
  454. # Path that the server will sit on
  455. servsockpath = os.path.join(self.tempdir, 'servsock')
  456. servarg = _makeunix(servsockpath)
  457. # Path that the server will send pt data to
  458. servptpath = os.path.join(self.tempdir, 'servptsock')
  459. # Setup pt target listener
  460. pttarg = _makeunix(servptpath)
  461. ptsock = []
  462. def ptsockaccept(reader, writer, ptsock=ptsock):
  463. ptsock.append((reader, writer))
  464. # Bind to pt listener
  465. lsock = await listensockstr(pttarg, ptsockaccept)
  466. nfs = []
  467. event = asyncio.Event()
  468. async def runnf(rdr, wrr):
  469. ptpair = asyncio.create_task(connectsockstr(pttarg))
  470. a = await NoiseForwarder('resp',
  471. _makefut((rdr, wrr)), ptpair,
  472. priv_key=self.server_key_pair[1])
  473. nfs.append(a)
  474. event.set()
  475. # Setup server listener
  476. ssock = await listensockstr(servarg, runnf)
  477. # Connect to server
  478. reader, writer = await connectsockstr(servarg)
  479. # Create client
  480. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  481. proto.set_as_initiator()
  482. # Setup required keys
  483. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  484. self.client_key_pair[1])
  485. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  486. self.server_key_pair[0])
  487. proto.start_handshake()
  488. # Send first message
  489. message = proto.write_message()
  490. self.assertEqual(len(message), _handshakelens[0])
  491. writer.write(message)
  492. # Get response
  493. respmsg = await reader.readexactly(_handshakelens[1])
  494. proto.read_message(respmsg)
  495. # Send final reply
  496. message = proto.write_message()
  497. writer.write(message)
  498. # Make sure handshake has completed
  499. self.assertTrue(proto.handshake_finished)
  500. # generate the keys for lengths
  501. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  502. b'toresp')
  503. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  504. b'toinit')
  505. # write a test message
  506. ptmsg = b'this is a test message that should be a little in length'
  507. encmsg = proto.encrypt(ptmsg)
  508. writer.write(enclenfun(encmsg))
  509. writer.write(encmsg)
  510. # XXX - how to sync?
  511. await asyncio.sleep(.1)
  512. ptreader, ptwriter = ptsock[0]
  513. # read the test message
  514. rptmsg = await ptreader.readexactly(len(ptmsg))
  515. self.assertEqual(rptmsg, ptmsg)
  516. # write a different message
  517. ptmsg = os.urandom(2843)
  518. encmsg = proto.encrypt(ptmsg)
  519. writer.write(enclenfun(encmsg))
  520. writer.write(encmsg)
  521. # XXX - how to sync?
  522. await asyncio.sleep(.1)
  523. # read the test message
  524. rptmsg = await ptreader.readexactly(len(ptmsg))
  525. self.assertEqual(rptmsg, ptmsg)
  526. # now try the other way
  527. ptmsg = os.urandom(912)
  528. ptwriter.write(ptmsg)
  529. # find out how much we need to read
  530. encmsg = await reader.readexactly(2 + 16)
  531. tlen = declenfun(encmsg)
  532. # read the rest of the message
  533. rencmsg = await reader.readexactly(tlen - 16)
  534. tmsg = encmsg[2:] + rencmsg
  535. rptmsg = proto.decrypt(tmsg)
  536. self.assertEqual(rptmsg, ptmsg)
  537. # shut down sending
  538. writer.write_eof()
  539. # so pt reader should be shut down
  540. self.assertEqual(b'', await ptreader.read(1))
  541. self.assertTrue(ptreader.at_eof())
  542. # shut down pt
  543. ptwriter.write_eof()
  544. # make sure the enc reader is eof
  545. self.assertEqual(b'', await reader.read(1))
  546. self.assertTrue(reader.at_eof())
  547. await event.wait()
  548. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  549. @async_test
  550. async def test_serverclient(self):
  551. # plumbing:
  552. #
  553. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  554. #
  555. ptcsockapair, ptcsockbpair = _asyncsockpair()
  556. ptcareader, ptcawriter = await ptcsockapair
  557. #ptcsockbpair passed directly
  558. clssockapair, clssockbpair = _asyncsockpair()
  559. #both passed directly
  560. ptssockapair, ptssockbpair = _asyncsockpair()
  561. #ptssockapair passed directly
  562. ptsbreader, ptsbwriter = await ptssockbpair
  563. clientnf = asyncio.create_task(NoiseForwarder('init',
  564. clssockapair, ptcsockbpair,
  565. priv_key=self.client_key_pair[1],
  566. pub_key=self.server_key_pair[0]))
  567. servnf = asyncio.create_task(NoiseForwarder('resp',
  568. clssockbpair, ptssockapair,
  569. priv_key=self.server_key_pair[1]))
  570. # send a message
  571. msga = os.urandom(183)
  572. ptcawriter.write(msga)
  573. # make sure we get the same message
  574. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  575. # send a second message
  576. msga = os.urandom(2834)
  577. ptcawriter.write(msga)
  578. # make sure we get the same message
  579. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  580. # send a message larger than the block size
  581. msga = os.urandom(103958)
  582. ptcawriter.write(msga)
  583. # make sure we get the same message
  584. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  585. # send a message the other direction
  586. msga = os.urandom(103958)
  587. ptsbwriter.write(msga)
  588. # make sure we get the same message
  589. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  590. # close down the pt writers, the rest should follow
  591. ptsbwriter.write_eof()
  592. ptcawriter.write_eof()
  593. # make sure they are closed, and there is no more data
  594. self.assertEqual(b'', await ptsbreader.read(1))
  595. self.assertTrue(ptsbreader.at_eof())
  596. self.assertEqual(b'', await ptcareader.read(1))
  597. self.assertTrue(ptcareader.at_eof())
  598. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  599. self.assertEqual([ 'dec', 'enc' ], await servnf)