import asyncio
import random
import socket
import struct
import unittest

from util import *

# This function is based upon code from:
# https://gist.github.com/petrdvor/e802bec72e78ace061ab9d4469418fae#file-async-multicast-receiver-server-py-L54-L72
def make_multisock(maddr):
	"""Create a UDP socket bound to maddr and joined to its multicast group.

	maddr is a (host, port) tuple naming the multicast group.  The
	socket has SO_REUSEPORT set before it is bound, so that more than
	one receiver can bind the same address.

	The group is joined with the IPv4 IP_ADD_MEMBERSHIP option on
	INADDR_ANY, so only IPv4 groups are handled here.

	Returns the bound socket.  If any step after the socket is created
	fails, the socket is closed before the exception propagates.
	"""

	# each getaddrinfo entry is (family, type, proto, canonname, sockaddr)
	family, socktype, _proto, _canonname, sockaddr = socket.getaddrinfo(
	    *maddr, type=socket.SOCK_DGRAM)[0]

	sock = socket.socket(family, socktype)
	try:
		sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

		sock.bind(maddr)

		# The membership request is the packed group address followed
		# by the local interface address; INADDR_ANY leaves the choice
		# of interface to the system.
		group_bin = socket.inet_pton(family, sockaddr[0])
		mreq = group_bin + struct.pack('=I', socket.INADDR_ANY)
		sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
	except BaseException:
		# Don't leak the descriptor when setup fails part way through.
		sock.close()
		raise

	return sock

class StupidProtocol:
	"""Minimal asyncio protocol that only keeps track of its transport.

	The transport attribute is None until connection_made is called,
	and is reset to None by connection_lost.
	"""

	def __init__(self):
		self.transport = None

	def close(self):
		"""Close the transport, if one is attached.

		Returns whatever the transport's close() returns, or None when
		there is no transport (never connected, or the connection has
		already been lost) instead of failing with AttributeError.
		"""

		transport = self.transport
		if transport is None:
			return None

		return transport.close()

	def connection_lost(self, exc):
		"""Forget the transport; exc is ignored."""

		self.transport = None

	def connection_made(self, transport):
		"""Remember the transport so that close() can reach it."""

		# Note: the connection_made call seems to be sync.  This
		# isn't documented, and I don't know how to force a test
		# if it isn't.
		self.transport = transport

class ReceiverProtocol(StupidProtocol):
	"""Datagram protocol that queues every packet it receives.

	Packets are handed out in arrival order by awaiting recv().
	"""

	def __init__(self):
		super().__init__()

		# (data, addr) tuples received but not yet returned by recv()
		self._q = asyncio.Queue()

	def datagram_received(self, data, addr):
		"""Queue the payload together with the address it came from."""

		packet = (data, addr)
		self._q.put_nowait(packet)

	async def recv(self):
		"""Wait for the next queued packet and return it as (data, addr)."""

		packet = await self._q.get()
		return packet

async def create_multicast_receiver(maddr):
	"""Start receiving datagrams sent to the multicast address maddr.

	Returns the ReceiverProtocol attached to the new datagram endpoint.
	"""

	# asyncio is handed a socket that make_multisock already set up,
	# rather than an address to bind itself.
	joined = make_multisock(maddr)

	endpoint = asyncio.get_running_loop().create_datagram_endpoint(
	    ReceiverProtocol, sock=joined)
	_, receiver = await endpoint

	return receiver

class TransmitterProtocol(StupidProtocol):
	"""Datagram protocol that sends packets through its transport."""

	async def send(self, msg):
		"""Pass msg to the transport's sendto.

		No address is given, so the destination is whatever remote
		address the transport was created with.
		"""

		transport = self.transport
		transport.sendto(msg)

async def create_multicast_transmitter(maddr):
	"""Create a datagram endpoint whose remote address is maddr.

	Returns the TransmitterProtocol attached to the new endpoint.
	"""

	endpoint = asyncio.get_running_loop().create_datagram_endpoint(
	    TransmitterProtocol, remote_addr=maddr)
	_, transmitter = await endpoint

	return transmitter

class TestMulticast(unittest.IsolatedAsyncioTestCase):
	"""Check that packets sent to a multicast group reach every receiver."""

	@timeout(2)
	async def test_multicast(self):
		"""Two receivers on one group must each get both packets sent."""

		# see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
		maddr = ('224.0.0.199', 3485)

		# Register each close as soon as its endpoint exists, so the
		# endpoints are released even when a later step or assertion
		# fails.
		l1 = await create_multicast_receiver(maddr)
		self.addCleanup(l1.close)
		l2 = await create_multicast_receiver(maddr)
		self.addCleanup(l2.close)

		t1 = await create_multicast_transmitter(maddr)
		self.addCleanup(t1.close)

		msg = b'test message'

		await t1.send(msg)
		await t1.send(msg)

		# Each receiver must see the first packet, then the second.
		self.assertEqual((await l1.recv())[0], msg)
		self.assertEqual((await l2.recv())[0], msg)

		self.assertEqual((await l1.recv())[0], msg)
		self.assertEqual((await l2.recv())[0], msg)