|
- import asyncio
- import random
- import socket
- import struct
- import unittest
-
- from util import *
-
# This function is based upon code from:
# https://gist.github.com/petrdvor/e802bec72e78ace061ab9d4469418fae#file-async-multicast-receiver-server-py-L54-L72
def make_multisock(maddr):
    """Create a UDP socket bound to *maddr* that has joined its multicast group.

    *maddr* is a ``(group, port)`` pair such as ``('224.0.0.199', 3485)``.
    ``SO_REUSEPORT`` is set so that several sockets can bind the same group
    and port at once (the test below opens two receivers on one address).
    The group is joined on ``INADDR_ANY``, i.e. the interface is left to the
    OS.

    Only IPv4 groups are supported: membership is requested through
    ``IPPROTO_IP`` / ``IP_ADD_MEMBERSHIP`` with an ``ip_mreq`` payload.

    Returns the configured socket; the caller owns it. If any setup step
    fails, the partially configured socket is closed before the error
    propagates.
    """
    # getaddrinfo yields (family, type, proto, canonname, sockaddr) tuples.
    family, socktype, _proto, _canonname, sockaddr = socket.getaddrinfo(
        *maddr, type=socket.SOCK_DGRAM)[0]

    sock = socket.socket(family, socktype)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(maddr)

        # struct ip_mreq: 4-byte group address followed by the 4-byte local
        # interface address (INADDR_ANY == 0, so byte order is irrelevant).
        group_bin = socket.inet_pton(family, sockaddr[0])
        mreq = group_bin + struct.pack('=I', socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
    except BaseException:
        # Don't leak the descriptor when binding or joining the group fails.
        sock.close()
        raise

    return sock
-
class StupidProtocol(asyncio.DatagramProtocol):
    """Minimal datagram protocol that just remembers its transport.

    Subclassing :class:`asyncio.DatagramProtocol` provides no-op defaults for
    the other callbacks asyncio may invoke on a datagram protocol
    (``datagram_received``, ``error_received``, ``pause_writing``,
    ``resume_writing``). Without them, a socket error would make the event
    loop call a missing method and log an ``AttributeError``.
    """

    def __init__(self):
        super().__init__()
        # Transport handed over by connection_made(); None before the
        # endpoint is connected and again after the connection is lost.
        self.transport = None

    def close(self):
        """Close the underlying transport, if there is one.

        Safe to call before connection_made(), after connection_lost(), or
        more than once: without a transport it does nothing.
        """
        transport = self.transport
        if transport is not None:
            transport.close()

    def connection_lost(self, exc):
        self.transport = None

    def connection_made(self, transport):
        # NOTE: in practice this runs before create_datagram_endpoint()
        # returns, so callers can use .transport right away. That ordering
        # isn't spelled out in the asyncio docs, and I don't know how to
        # force a test for the case where it doesn't hold.
        self.transport = transport
-
class ReceiverProtocol(StupidProtocol):
    """Datagram protocol that buffers every received packet for later reading.

    Each datagram is stored as a ``(data, addr)`` tuple in an
    :class:`asyncio.Queue` created with no size limit; :meth:`recv` hands
    the tuples out in arrival order.
    """

    def __init__(self):
        super().__init__()
        # FIFO of (data, addr) tuples filled by datagram_received().
        self._q = asyncio.Queue()

    def datagram_received(self, data, addr):
        """Enqueue the datagram immediately, without awaiting."""
        packet = (data, addr)
        self._q.put_nowait(packet)

    async def recv(self):
        """Wait for and return the next ``(data, addr)`` tuple."""
        packet = await self._q.get()
        return packet
-
async def create_multicast_receiver(maddr):
    """Open a datagram endpoint that receives traffic sent to multicast *maddr*.

    The socket comes from make_multisock(), which sets SO_REUSEPORT, so
    several receivers may share the same group and port. Returns the
    ReceiverProtocol instance; use its recv() to read datagrams and close()
    to shut it down.
    """
    # Fetch the loop first so a missing loop can't strand an open socket.
    loop = asyncio.get_running_loop()
    sock = make_multisock(maddr)

    try:
        _transport, protocol = await loop.create_datagram_endpoint(
            ReceiverProtocol, sock=sock)
    except BaseException:
        # The transport only takes ownership of sock on success; close the
        # bound, group-joined socket if endpoint creation fails or is
        # cancelled. socket.close() is idempotent, so a transport that
        # already closed it is harmless.
        sock.close()
        raise

    return protocol
-
class TransmitterProtocol(StupidProtocol):
    """Datagram protocol used only to send to the endpoint's remote address."""

    async def send(self, msg):
        """Send *msg* as a single datagram to the connected remote address.

        Declared ``async`` (presumably to mirror ReceiverProtocol.recv()),
        but nothing is awaited: the transport's sendto() does not block.
        """
        transport = self.transport
        transport.sendto(msg)
-
async def create_multicast_transmitter(maddr):
    """Open a datagram endpoint whose sends go to multicast address *maddr*.

    Returns the TransmitterProtocol instance; call its send() to transmit
    and close() when finished.
    """
    # create_datagram_endpoint returns (transport, protocol); keep the latter.
    endpoint = await asyncio.get_running_loop().create_datagram_endpoint(
        TransmitterProtocol, remote_addr=maddr)
    return endpoint[1]
-
class TestMulticast(unittest.IsolatedAsyncioTestCase):
    """End-to-end check that multicast datagrams reach every joined receiver."""

    @timeout(2)
    async def test_multicast(self):
        """Two receivers on one group each get both datagrams one sender emits."""
        # 224.0.0.199 lies in 224.0.0.0/24, the IANA Local Network Control
        # Block; see:
        # https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
        maddr = ('224.0.0.199', 3485)

        # Register a cleanup as soon as each endpoint exists so its socket is
        # closed even if a later setup step, an assertion, or the timeout
        # fails. Each endpoint is closed exactly once, by its cleanup.
        l1 = await create_multicast_receiver(maddr)
        self.addCleanup(l1.close)
        l2 = await create_multicast_receiver(maddr)
        self.addCleanup(l2.close)

        t1 = await create_multicast_transmitter(maddr)
        self.addCleanup(t1.close)

        msg = b'test message'

        await t1.send(msg)
        await t1.send(msg)

        # Every receiver must see both copies.
        for _ in range(2):
            self.assertEqual((await l1.recv())[0], msg)
            self.assertEqual((await l2.recv())[0], msg)
|