Verified Commit 8ffc5d5a authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Reuse _TaskRunner.channels to map channels to filters

This will be of use in an upcoming change…
parent 533476be
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ class _TaskRunner:
	def __init__(self, tasks: anyio.abc.TaskGroup):
		self.tasks = tasks
		self.filters = list[tuple[Filter, Session]]()
		self.channels = list[MessageChannel]()
		self.channels = dict[MessageChannel, Filter]()

	async def __aenter__(self) -> Self:
		return self
@@ -205,10 +205,10 @@ class _TaskRunner:
		final: ResponseMessage = Accept()
		for flter, session in self.filters:
			lchannel, rchannel = _make_message_channel()
			self.channels.append(lchannel)
			self.channels[lchannel] = flter
			match await self.tasks.start(self._runner, flter, session, rchannel, use_skip):
				case Accept():
					self.channels.remove(lchannel)
					del self.channels[lchannel]
				case Continue():
					continue
				case TemporaryFailure() as final:  # replaces final
@@ -235,7 +235,7 @@ class _TaskRunner:

	async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage|Skip:
		skip = isinstance(message, Body)
		for channel in self.channels:
		for channel in list(self.channels):
			await channel.send(message)
			match (await channel.receive()):
				case Skip():
@@ -244,7 +244,7 @@ class _TaskRunner:
					skip = False
				case Accept():
					await channel.aclose()
					self.channels.remove(channel)
					del self.channels[channel]
				case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp:
					return resp
		return (
@@ -260,12 +260,12 @@ class _TaskRunner:
			await channel.send(abort)
			await channel.receive()
			await channel.aclose()
		del self.channels[:]
		self.channels.clear()

	async def aclose(self) -> None:
		_logger.info("Closing runners")
		self.tasks.cancel_scope.cancel()
		del self.channels[:]
		self.channels.clear()

	@staticmethod
	async def _runner(