diff --git a/solardash/__init__.py b/solardash/__init__.py index 93563e6..3bb5956 100644 --- a/solardash/__init__.py +++ b/solardash/__init__.py @@ -38,6 +38,31 @@ logging.basicConfig(level=logging.DEBUG) from aiohttp import web from RainEagle.parse import LogDir as RELogDir, _cmaiter +class StickyChannel(object): + def __init__(self): + self._queues = set() + self._lastvalue = None + + async def post(self, value): + self._lastvalue = value + + for i in self._queues: + await i.put(value) + + async def __aiter__(self): + q = asyncio.Queue() + + try: + self._queues.add(q) + + if self._lastvalue is not None: + yield self._lastvalue + + while True: + yield await q.get() + finally: + self._queues.remove(q) + class SolarDataWS(object): def __init__(self, reprefix): self._raineagle = RELogDir(reprefix) @@ -114,6 +139,103 @@ def async_test(f): return wrapper +class MiscTest(unittest.TestCase): + @async_test + async def test_stickychannel_waitfirst(self): + chan = StickyChannel() + + loop = asyncio.get_event_loop() + + order = [] + dataval = 'data' + async def postmsg(): + await asyncio.sleep(.02) + order.append('a') + await chan.post(dataval) + order.append('b') + + postmsgtask = loop.create_task(postmsg()) + + async with _cmaiter(chan.__aiter__()) as chaniter: + order.append('c') + val = await chaniter.__anext__() + order.append('d') + + self.assertEqual(val, dataval) + + self.assertEqual(order, [ 'c', 'a', 'b', 'd' ]) + + await postmsgtask + + @async_test + async def test_stickychannel_multiple(self): + chan = StickyChannel() + + loop = asyncio.get_event_loop() + + dataval = 'data' + dataval2 = 'data2' + async def postmsg(): + await asyncio.sleep(.02) + await chan.post(dataval) + await asyncio.sleep(.01) + await chan.post(dataval2) + + postmsgtask = loop.create_task(postmsg()) + + async with _cmaiter(chan.__aiter__()) as chaniter, \ + _cmaiter(chan.__aiter__()) as chaniter2: + val = await chaniter.__anext__() + val2 = await chaniter2.__anext__() + + await asyncio.sleep(.01) + + val3 = await chaniter.__anext__() + val4 = await chaniter2.__anext__() + + async with _cmaiter(chan.__aiter__()) as chaniter3: + val5 = await chaniter3.__anext__() + + self.assertEqual(val, dataval) + self.assertEqual(val2, dataval) + self.assertEqual(val3, dataval2) + self.assertEqual(val4, dataval2) + self.assertEqual(val5, dataval2) + + await postmsgtask + + @async_test + async def test_stickychannel_valuefirst(self): + chan = StickyChannel() + + loop = asyncio.get_event_loop() + + await chan.post(5) + + order = [] + dataval = 'data' + async def postmsg(): + await asyncio.sleep(.02) + order.append('a') + await chan.post(dataval) + order.append('b') + + postmsgtask = loop.create_task(postmsg()) + + async with _cmaiter(chan.__aiter__()) as chaniter: + order.append('c') + val = await chaniter.__anext__() + order.append('d') + val2 = await chaniter.__anext__() + order.append('e') + + self.assertEqual(val, 5) + self.assertEqual(val2, dataval) + + self.assertEqual(order, [ 'c', 'd', 'a', 'b', 'e' ]) + + await postmsgtask + class Test(unittest.TestCase): def setUp(self): # setup temporary directory