Verified Commit 5f09425f authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Support no-response filters

parent 014b2bd6
Loading
Loading
Loading
Loading
+37 −5
Original line number Diff line number Diff line
@@ -57,6 +57,12 @@ class NegotiationError(Exception):
	"""


class _CloseFilter:

	def __init__(self, filtr: Filter):
		self.filter = filtr


class _Broadcast(Broadcast[EventMessage]):

	def __init__(self) -> None:
@@ -83,7 +89,7 @@ class Runner:
	def __init__(self, *filters: Filter):
		if len(filters) == 0:  # pragma: no-cover
			raise TypeError("Runner requires at least one filter to run")
		self.filters = filters
		self.filters = list(filters)
		self.use_skip = True

	async def __call__(self, client: anyio.abc.ByteStream) -> None:
@@ -138,7 +144,17 @@ class Runner:
							if aborted:
								aborted = False
								await runner.start(False, self.use_skip)
							await sender.asend(await runner.message_events(message))
							needs_response = proto.needs_response(message)
							match await runner.message_events(message, needs_response):
								case None:
									assert not needs_response
								case _CloseFilter() as notif:
									self.filters.remove(notif.filter)
								case resp if needs_response:
									assert resp is not None and not isinstance(resp, _CloseFilter)
									await sender.asend(resp)
								case resp:
									raise RuntimeError(f"unexpected response: {resp}")

	async def _negotiate(self, message: Negotiate) -> Negotiate:
		_logger.info("Negotiating with MTA")
@@ -233,7 +249,11 @@ class _TaskRunner:
			for _, session in self.filters:
				await session.deliver(message)

	async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage|Skip:
	async def message_events(
		self,
		message: _VALID_EVENT_MESSAGE,
		needs_response: bool,
	) -> ResponseMessage|Skip|_CloseFilter|None:
		skip = isinstance(message, Body)
		for channel in list(self.channels):
			await channel.send(message)
@@ -243,16 +263,28 @@ class _TaskRunner:
				case Continue():
					skip = False
				case Accept():
					await channel.aclose()
					del self.channels[channel]
					await self.close_channel(channel)
				case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp:
					await self.close_channel(channel)
					if not needs_response:
						filtr = self.channels[channel]
						_logger.warning(
							f"Unexpected response from filter {self.channels[channel]}",
						)
						return _CloseFilter(filtr)
					return resp
		if not needs_response:
			return None
		return (
			Accept() if len(self.channels) == 0 else
			Skip() if skip else
			Continue()
		)

	async def close_channel(self, channel: MessageChannel) -> None:
		await channel.aclose()
		del self.channels[channel]

	async def abort(self, abort: Abort) -> None:
		if self.channels:
			_logger.info("Aborting filters")