Commit 62cdc973 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Inject TaskStatus sync into Session via Broadcast hooks

This allows synchronisation of any early response, before broadcast
messages are awaited.
parent 0478f3b8
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -57,12 +57,15 @@ class _Broadcast(Broadcast[EventMessage]):
	def __init__(self) -> None:
		super().__init__()
		self._ready = anyio.Condition()
		self.task_status: anyio.abc.TaskStatus|None = None

	async def shutdown_hook(self) -> None:
		async with self._ready:
			self._ready.notify_all()
		await self.pre_receive_hook()

	async def pre_receive_hook(self) -> None:
		if self.task_status is not None:
			self.task_status.started()
			self.task_status = None
		async with self._ready:
			self._ready.notify_all()

@@ -224,7 +227,8 @@ async def _runner(
	) -> None:
		nonlocal final_resp
		async with session:
			task_status.started()
			assert isinstance(session.broadcast, _Broadcast)
			session.broadcast.task_status = task_status
			final_resp = await fltr(session)
		if not isinstance(final_resp, _VALID_FINAL_RESPONSES):
			warn(f"expected a final response from {fltr}, got {final_resp}")
@@ -245,8 +249,6 @@ async def _runner(
			assert isinstance(message, _VALID_EVENT_MESSAGE)
			resp = await session.deliver(message)
			if final_resp is not None:
				await channel.send(final_resp)  # type: ignore
			elif use_skip and resp == Skip:
				await channel.send(Skip())
			else:
				await channel.send(Continue())
				break  # type: ignore
			await channel.send(Skip() if use_skip and resp == Skip else Continue())
		await channel.send(final_resp)