Commit 1c98176e authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add hooks and aclose to Broadcast

This is preparatory work for runners, which need to inject
syncronisation into the Broadcast flow to prevent race conditions when
collecting responses returned by filters.
parent 832d708f
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
@@ -134,17 +134,21 @@ class Session:
	The kernel of a filter, providing an API for filters to access messages from and MTA
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Session")

	def __init__(
		self,
		connmsg: Connect,
		sender: AsyncGenerator[None, EditMessage],
		broadcast: util.Broadcast[EventMessage]|None = None,
	):
		self.host = connmsg.hostname
		self.address = connmsg.address
		self.port = connmsg.port

		self._editor = sender
		self._broadcast = util.Broadcast[EventMessage]()
		self._broadcast = broadcast or util.Broadcast[EventMessage]()

		self.headers = HeadersAccessor(self, sender)
		self.body = BodyAccessor(self, sender)
@@ -155,6 +159,13 @@ class Session:
		# so some phases will be skipped; checks should not try to exactly match a phase.
		self.phase = Phase.CONNECT

	async def __aenter__(self: Self) -> Self:
		return self

	async def __aexit__(self, *_: object) -> None:
		# on session close, wake up any remaining deliver() awaitables
		await self._broadcast.aclose()

	async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]:
		"""
		Deliver a message (or its contents) to a task waiting for it
+22 −0
Original line number Diff line number Diff line
@@ -33,6 +33,24 @@ class Broadcast(anyio.Condition, Generic[T]):
		super().__init__()
		self.obj: Optional[T] = None

	async def pre_receive_hook(self) -> None:
		"""
		A hook for subclasses to inject synchronisation instructions before awaiting objects
		"""  # noqa: D401

	async def post_send_hook(self) -> None:
		"""
		A hook for subclasses to inject synchronisation instructions after sending objects
		"""  # noqa: D401

	async def aclose(self) -> None:
		"""
		A hook for subclasses to inject cleanup or synchronisation instructions on close

		Users must ensure this method is called, especially if using a subclass which
		implements it.
		"""  # noqa: D401

	async def send(self, obj: T) -> None:
		"""
		Send a message object and block until all listeners have received it
@@ -41,6 +59,9 @@ class Broadcast(anyio.Condition, Generic[T]):
			self.obj = obj
			self.notify_all()
		await anyio.sleep(0.0)  # ensure listeners have opportunity to wait for locks
		await self.post_send_hook()

		# Ensure all listeners have had a chance to lock and process self.obj
		while 1:
			async with self:
				if self.statistics().lock_statistics.tasks_waiting:
@@ -52,6 +73,7 @@ class Broadcast(anyio.Condition, Generic[T]):
		"""
		Listen for a single message and return it once it arrives
		"""
		await self.pre_receive_hook()
		await self.wait()
		assert self.obj is not None
		return self.obj