Commit 1c77387c authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Refactor runner.py to fix connection re-use

parent d6cbb346
Loading
Loading
Loading
Loading
+167 −102
Original line number Diff line number Diff line
# Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -16,6 +16,9 @@ from __future__ import annotations

import logging
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from typing import TypeAlias
from typing import TypeVar
from warnings import warn

import anyio.abc
@@ -37,8 +40,8 @@ kiB = 2**10
MiB = 2**20

_VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode
_VALID_EVENT_MESSAGE = Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, \
	Header, EndOfHeaders, Body, EndOfMessage
_VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \
	Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort
_DISABLE_PROTOCOL_FLAGS = ProtocolFlags.NO_CONNECT | ProtocolFlags.NO_HELO | \
	ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | ProtocolFlags.NO_BODY | \
	ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_EOH | ProtocolFlags.NO_UNKNOWN | \
@@ -90,12 +93,15 @@ class Runner:
		buff = SimpleBuffer(1*MiB)
		proto = FilterProtocol()
		sender = _sender(client, proto)
		channels = list[MessageChannel]()
		macro: Macro|None = None

		await sender.asend(None)  # type: ignore # initialise

		async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client):
		async with (
			anyio.create_task_group() as tasks,
			aclosing(sender), aclosing(client),
			_TaskRunner(tasks) as runner,
		):
			while 1:
				try:
					buff[:] = await client.receive(buff.available)
@@ -104,49 +110,33 @@ class Runner:
					anyio.ClosedResourceError,
					anyio.BrokenResourceError,
				):
					for channel in channels:
						await channel.aclose()
					await runner.aclose()
					return
				for message in proto.read_from(buff):
					match message:
						case Negotiate():
							await self._negotiate(message, sender)
							await sender.asend(await self._negotiate(message, sender))
						case Macro() as macro:
							# Note that this Macro will hang around as "macro"; this is for
							# Connect messages.
							for channel in channels:
								await channel.send(macro)
							await runner.set_macros(macro)
						case Connect():
							channels[:] = await self._connect(message, sender, tasks, macro)
							await sender.asend(await self._connect(message, sender, runner, macro))
						case Abort():
							for channel in channels:
								await channel.aclose()
							await runner.abort(message)
							await runner.start(False, self.use_skip)
						case Close():
							await runner.aclose()
							return
						case _:
							assert isinstance(message, _VALID_EVENT_MESSAGE)
							skip = isinstance(message, Body)
							for channel in channels:
								await channel.send(message)
								match (await channel.receive()):
									case Skip():
										continue
									case Continue():
										skip = False
									case Accept():
										await channel.aclose()
										channels.remove(channel)
									case resp:
										await sender.asend(resp)
										break
							else:
								await sender.asend(
									Accept() if len(channels) == 0 else
									Skip() if skip else
									Continue(),
								)

	async def _negotiate(self, message: Negotiate, sender: Sender) -> None:
							# TODO: Upgrade and remove ignores once python/mypy#14242 is in
							# TODO: Should remove assert once kilter.protocol#5 is resolved
							# Type narrowing should do the job adequately
							# https://code.kodo.org.uk/kilter/kilter.protocol/-/issues/5
							assert isinstance(message, _VALID_EVENT_MESSAGE)  # type: ignore[misc,arg-type]
							await sender.asend(await runner.message_events(message))  # type: ignore[arg-type]

	async def _negotiate(self, message: Negotiate, sender: Sender) -> Negotiate:
		# TODO: actually negotiate what the filter wants, not just "everything"
		actions = set(ActionFlags)  # All actions!
		if actions != ActionFlags.unpack(message.action_flags):
@@ -156,59 +146,112 @@ class Runner:
		resp.protocol_flags = message.protocol_flags & ~_DISABLE_PROTOCOL_FLAGS
		resp.action_flags = ActionFlags.pack(actions)

		await sender.asend(resp)

		self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP)

		return resp

	async def _connect(
		self,
		message: Connect,
		sender: Sender,
		tasks: anyio.abc.TaskGroup,
		runner: _TaskRunner,
		macro: Macro|None,
	) -> list[MessageChannel]:
		channels = list[MessageChannel]()
	) -> ResponseMessage:
		for fltr in self.filters:
			lchannel, rchannel = _make_message_channel()
			channels.append(lchannel)
			session = Session(message, sender, _Broadcast())
			runner.add_filter(fltr, session)
		if macro:
				await session.deliver(macro)
			match await tasks.start(
				_runner, fltr, session, rchannel, self.use_skip,
			):
			await runner.set_macros(macro)
		return await runner.start(True, self.use_skip)


class _TaskRunner:

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="_TaskRunner")

	def __init__(self, tasks: anyio.abc.TaskGroup):
		self.tasks = tasks
		self.filters = list[tuple[Filter, Session]]()
		self.channels = list[MessageChannel]()

	async def __aenter__(self: Self) -> Self:
		return self

	async def __aexit__(self, *_: object) -> None:
		await self.aclose()

	def add_filter(self, flter: Filter, session: Session, /) -> None:
		self.filters.append((flter, session))

	async def start(self, first_connect: bool, use_skip: bool) -> ResponseMessage:
		if self.channels:
			raise RuntimeError(f"{self} is already running tasks")
		final: ResponseMessage = Accept()
		for flter, session in self.filters:
			lchannel, rchannel = _make_message_channel()
			self.channels.append(lchannel)
			match await self.tasks.start(self._runner, flter, session, rchannel, first_connect, use_skip):
				case Accept():
					self.channels.remove(lchannel)
				case Continue():
					continue
				case Message() as resp:
					await sender.asend(resp)
					return []
				case TemporaryFailure() as final:  # replaces final
					pass
				case Reject()|Discard()|ReplyCode() as resp:
					if not first_connect:
						logging.warning("Unexpected response from filter after restart")
						continue
					return resp
				case _ as arg:  # pragma: no-cover
					raise TypeError(
						f"task_status.started called with bad type: "
						f"{arg!r}",
					)
		await sender.asend(Continue())
		return channels
					raise TypeError(f"task_status.started called with bad type: {arg!r}")
		return final if len(self.channels) == 0 else Continue()

	async def set_macros(self, message: Macro) -> None:
		if self.channels:
			for channel in self.channels:
				await channel.send(message)
		else:
			for _, session in self.filters:
				await session.deliver(message)

def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	rsend, lrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)

	async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage:
		skip = isinstance(message, Body)
		for channel in self.channels:
			await channel.send(message)
			match (await channel.receive()):
				case Skip():
					continue
				case Continue():
					skip = False
				case Accept():
					await channel.aclose()
					self.channels.remove(channel)
				case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp:
					return resp
		return (
			Accept() if len(self.channels) == 0 else
			Skip() if skip else
			Continue()
		)

async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender:
	buff = SimpleBuffer(1*kiB)
	while 1:
		proto.write_to(buff, (yield))
		await client.send(buff[:])
		del buff[:]
	async def abort(self, abort: Abort) -> None:
		for channel in self.channels:
			await channel.send(abort)
			await channel.receive()
			await channel.aclose()
		del self.channels[:]

	async def aclose(self) -> None:
		self.tasks.cancel_scope.cancel()
		del self.channels[:]

	@staticmethod
	async def _runner(
		fltr: Filter,
		session: Session,
		channel: MessageChannel,
		first_connect: bool,
		use_skip: bool, *,
		task_status: anyio.abc.TaskStatus,
	) -> None:
@@ -223,6 +266,9 @@ async def _runner(
				session.broadcast.task_status = task_status
				try:
					final_resp = await fltr(session)
				except Aborted:
					logging.info(f"aborted filter {qualname(fltr)}")
					return
				except Exception:
					logging.exception(f"error in filter {qualname(fltr)}")
					final_resp = TemporaryFailure()
@@ -242,9 +288,28 @@ async def _runner(
				if isinstance(message, Macro):
					await session.deliver(message)
					continue
			assert isinstance(message, _VALID_EVENT_MESSAGE)
			resp = await session.deliver(message)
				# TODO: Upgrade and remove ignores once python/mypy#14242 is in
				assert isinstance(message, _VALID_EVENT_MESSAGE)  # type: ignore[misc,arg-type]
				resp = await session.deliver(message)  # type: ignore[arg-type]
				if final_resp is not None:
					break  # type: ignore
				if isinstance(message, Abort):
					await channel.send(Continue())
					await channel.aclose()
					return
				await channel.send(Skip() if use_skip and resp == Skip else Continue())
			await channel.send(final_resp)


def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	rsend, lrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)


async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender:
	buff = SimpleBuffer(1*kiB)
	while 1:
		proto.write_to(buff, (yield))
		await client.send(buff[:])
		del buff[:]