Loading kilter/service/util.py +19 −2 Original line number Diff line number Diff line Loading @@ -32,6 +32,7 @@ class Broadcast(anyio.Condition, Generic[T]): def __init__(self) -> None: super().__init__() self.obj: Optional[T] = None self.exc: Optional[BaseException|type[BaseException]] = None async def pre_receive_hook(self) -> None: """ Loading @@ -51,22 +52,36 @@ class Broadcast(anyio.Condition, Generic[T]): implements it. """ # noqa: D401 async def abort(self, exc: BaseException|type[BaseException]) -> None: """ Send a notification to all listeners to abort by raising an exception """ async with self: assert self.exc is None and self.obj is None self.exc = exc self.notify_all() await self._post() async def send(self, obj: T) -> None: """ Send a message object and block until all listeners have received it """ async with self: assert self.exc is None and self.obj is None self.obj = obj self.notify_all() await self._post() async def _post(self) -> None: await anyio.sleep(0.0) # ensure listeners have opportunity to wait for locks await self.post_send_hook() # Ensure all listeners have had a chance to lock and process self.obj while 1: async with self: if self.statistics().lock_statistics.tasks_waiting: if self.statistics().lock_statistics.tasks_waiting: # pragma: no-branch continue self.obj = None self.obj = self.exc = None break async def receive(self) -> T: Loading @@ -75,5 +90,7 @@ class Broadcast(anyio.Condition, Generic[T]): """ await self.pre_receive_hook() await self.wait() if self.exc is not None: raise self.exc assert self.obj is not None return self.obj tests/test_broadcast.py +18 −0 Original line number Diff line number Diff line Loading @@ -85,3 +85,21 @@ class BroadcastTests(AsyncTestCase): await broadcast.send(n) assert messages == [1, 1, 2, 2, 3, 3, 4, 4] async def test_abort(self) -> None: """ Check that aborting with multiple listeners works """ broadcast = Broadcast[int]() async def listener() -> None: async with broadcast: with self.assertRaises(ValueError): _ = await broadcast.receive() async with trio.open_nursery() as task_group: task_group.start_soon(listener) task_group.start_soon(listener) await trio.testing.wait_all_tasks_blocked() await broadcast.abort(ValueError) Loading
kilter/service/util.py +19 −2 Original line number Diff line number Diff line Loading @@ -32,6 +32,7 @@ class Broadcast(anyio.Condition, Generic[T]): def __init__(self) -> None: super().__init__() self.obj: Optional[T] = None self.exc: Optional[BaseException|type[BaseException]] = None async def pre_receive_hook(self) -> None: """ Loading @@ -51,22 +52,36 @@ class Broadcast(anyio.Condition, Generic[T]): implements it. """ # noqa: D401 async def abort(self, exc: BaseException|type[BaseException]) -> None: """ Send a notification to all listeners to abort by raising an exception """ async with self: assert self.exc is None and self.obj is None self.exc = exc self.notify_all() await self._post() async def send(self, obj: T) -> None: """ Send a message object and block until all listeners have received it """ async with self: assert self.exc is None and self.obj is None self.obj = obj self.notify_all() await self._post() async def _post(self) -> None: await anyio.sleep(0.0) # ensure listeners have opportunity to wait for locks await self.post_send_hook() # Ensure all listeners have had a chance to lock and process self.obj while 1: async with self: if self.statistics().lock_statistics.tasks_waiting: if self.statistics().lock_statistics.tasks_waiting: # pragma: no-branch continue self.obj = None self.obj = self.exc = None break async def receive(self) -> T: Loading @@ -75,5 +90,7 @@ class Broadcast(anyio.Condition, Generic[T]): """ await self.pre_receive_hook() await self.wait() if self.exc is not None: raise self.exc assert self.obj is not None return self.obj
tests/test_broadcast.py +18 −0 Original line number Diff line number Diff line Loading @@ -85,3 +85,21 @@ class BroadcastTests(AsyncTestCase): await broadcast.send(n) assert messages == [1, 1, 2, 2, 3, 3, 4, 4] async def test_abort(self) -> None: """ Check that aborting with multiple listeners works """ broadcast = Broadcast[int]() async def listener() -> None: async with broadcast: with self.assertRaises(ValueError): _ = await broadcast.receive() async with trio.open_nursery() as task_group: task_group.start_soon(listener) task_group.start_soon(listener) await trio.testing.wait_all_tasks_blocked() await broadcast.abort(ValueError)