Commit 4f666624 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add Broadcast.abort()

Closes #7
parent 36471d3d
Loading
Loading
Loading
Loading
+19 −2
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ class Broadcast(anyio.Condition, Generic[T]):
	def __init__(self) -> None:
		super().__init__()
		self.obj: Optional[T] = None
		self.exc: Optional[BaseException|type[BaseException]] = None

	async def pre_receive_hook(self) -> None:
		"""
@@ -51,22 +52,36 @@ class Broadcast(anyio.Condition, Generic[T]):
		implements it.
		"""  # noqa: D401

	async def abort(self, exc: BaseException|type[BaseException]) -> None:
		"""
		Send a notification to all listeners to abort by raising an exception
		"""
		async with self:
			assert self.exc is None and self.obj is None
			self.exc = exc
			self.notify_all()
		await self._post()

	async def send(self, obj: T) -> None:
		"""
		Send a message object and block until all listeners have received it
		"""
		async with self:
			assert self.exc is None and self.obj is None
			self.obj = obj
			self.notify_all()
		await self._post()

	async def _post(self) -> None:
		await anyio.sleep(0.0)  # ensure listeners have opportunity to wait for locks
		await self.post_send_hook()

		# Ensure all listeners have had a chance to lock and process self.obj
		while 1:
			async with self:
				if self.statistics().lock_statistics.tasks_waiting:
				if self.statistics().lock_statistics.tasks_waiting:  # pragma: no-branch
					continue
				self.obj = None
				self.obj = self.exc = None
				break

	async def receive(self) -> T:
@@ -75,5 +90,7 @@ class Broadcast(anyio.Condition, Generic[T]):
		"""
		await self.pre_receive_hook()
		await self.wait()
		if self.exc is not None:
			raise self.exc
		assert self.obj is not None
		return self.obj
+18 −0
Original line number Diff line number Diff line
@@ -85,3 +85,21 @@ class BroadcastTests(AsyncTestCase):
				await broadcast.send(n)

		assert messages == [1, 1, 2, 2, 3, 3, 4, 4]

	async def test_abort(self) -> None:
		"""
		Check that aborting with multiple listeners works
		"""
		broadcast = Broadcast[int]()

		async def listener() -> None:
			async with broadcast:
				with self.assertRaises(ValueError):
					_ = await broadcast.receive()

		async with trio.open_nursery() as task_group:
			task_group.start_soon(listener)
			task_group.start_soon(listener)
			await trio.testing.wait_all_tasks_blocked()

			await broadcast.abort(ValueError)