Verified Commit 6aa9a09d authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Rewrite much of the runner module

Combines a refactor of the runner module and a bit of the session
module for clarity and maintainability, and a fix for #21.

Fixes #21
parent 717d4f63
Loading
Loading
Loading
Loading
+132 −227
Original line number Diff line number Diff line
@@ -14,16 +14,16 @@ The primary class in this module (`Runner`) is intended to be used with an

from __future__ import annotations

import enum
import logging
from collections import defaultdict
from collections.abc import Iterable
from typing import Final
from typing import TypeAlias
from warnings import warn

import anyio.abc
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing
from typing_extensions import Self

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import EventMessage
@@ -41,14 +41,16 @@ from .session import Session
from .util import Broadcast
from .util import qualname

MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message]
__all__ = [
	"Runner",
	"NegotiationError",
]

FinalResponse: TypeAlias = FilterResponse | TemporaryFailure

kiB: Final = 2**10
MiB: Final = 2**20

_VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \
	Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort

_logger = logging.getLogger(__package__)


@@ -58,25 +60,27 @@ class NegotiationError(Exception):
	"""


class _CloseFilter:
class State(enum.Enum):

	def __init__(self, filtr: Filter):
		self.filter = filtr
	CONNECTED = enum.auto()
	SESSION = enum.auto()
	SESSION_ABORTED = enum.auto()
	MESSAGE = enum.auto()
	MESSAGE_ABORTED = enum.auto()


class _Broadcast(Broadcast[EventMessage]):

	def __init__(self) -> None:
		super().__init__()
		self.task_status: anyio.abc.TaskStatus[None]|None = None
		self.task_status = list[anyio.abc.TaskStatus[None]]()

	async def shutdown_hook(self) -> None:
		await self.pre_receive_hook()

	async def pre_receive_hook(self) -> None:
		if self.task_status is not None:
			self.task_status.started()
			self.task_status = None
		while self.task_status:
			self.task_status.pop().started()


class Sender:
@@ -122,13 +126,13 @@ class Runner:
		buff = SimpleBuffer(1*MiB)
		proto = FilterProtocol(abort_on_unknown=True)
		sender = Sender(client, proto)
		macro: Macro|None = None
		aborted = False
		session = Session(sender, _Broadcast())
		runner = SessionRunner(session)
		state = State.CONNECTED

		async with (
			aclosing(client),
			anyio.create_task_group() as tasks,
			_TaskRunner(tasks) as runner,
		):
			while 1:
				try:
@@ -138,54 +142,57 @@ class Runner:
					anyio.ClosedResourceError,
					anyio.BrokenResourceError,
				):
					await runner.aclose()
					return
				for message in proto.read_from(buff):
					if __debug__:
						_logger.debug(f"received: {message}")

					# If previous message was Abort, restart filters for any non-Abort/Close
					# message
					if state in (State.SESSION_ABORTED, State.MESSAGE_ABORTED):
						if not isinstance(message, Abort|Close):
							await runner.start(self.filters, tasks)
						state = (
							State.CONNECTED if state == State.SESSION_ABORTED else
							State.SESSION
						)

					match message:
						case Negotiate():
							await sender.send(await self._negotiate(message))
						case Macro() as macro:
							# Note that this Macro will hang around as "macro"; this is for
							# Connect messages.
							await runner.set_macros(macro)
							continue
						case Connect():
							await self._prepare_filters(message, sender, runner)
							if macro:
								await runner.set_macros(macro)
							needs_response = proto.needs_response(message)
							match await runner.start(needs_response, True, self.use_skip):
								case None:
									assert not needs_response
								case _CloseFilter() as notif:
									self.filters.remove(notif.filter)
								case c_resp if needs_response:
									assert c_resp is not None and not isinstance(c_resp, _CloseFilter)
									await sender.send(c_resp)
								case c_resp:
									raise RuntimeError(f"unexpected response: {c_resp}")
							_logger.info(f"Client connected from {message.hostname}")
							await session.deliver(message)
							await runner.start(self.filters, tasks)
							if proto.needs_response(message):
								await sender.send(await runner.check_response() or Continue())
							continue
						case Helo():
							state = State.SESSION
						case EnvelopeFrom():
							state = State.MESSAGE
						case Abort() if state in (State.SESSION, State.MESSAGE):
							state = (
								State.SESSION_ABORTED if state == State.SESSION else
								State.MESSAGE_ABORTED
							)
						case Abort():
							aborted = True
							await runner.abort(message)
							_logger.warning("Unexpected Abort received")
							state = State.CONNECTED
						case Close():
							await runner.aclose()
							tasks.cancel_scope.cancel()
							return
						case _:
							if aborted:
								aborted = False
								await runner.start(True, False, self.use_skip)
							needs_response = proto.needs_response(message)
							match await runner.message_events(message, needs_response):
								case None:
									assert not needs_response
								case _CloseFilter() as notif:
									self.filters.remove(notif.filter)
								case resp if needs_response:
									assert resp is not None and not isinstance(resp, _CloseFilter)

					skip_or_cont = await session.deliver(message)
					if not proto.needs_response(message):
						continue
					if (resp := await runner.check_response()):
						await sender.send(resp)
								case resp:
									raise RuntimeError(f"unexpected response: {resp}")
					elif self.use_skip:
						await sender.send(skip_or_cont())
					else:
						await sender.send(Continue())

	async def _negotiate(self, message: Negotiate) -> Negotiate:
		_logger.info("Negotiating with MTA")
@@ -227,185 +234,83 @@ class Runner:

		return Negotiate(6, actions, options, dict(macros))

	async def _prepare_filters(
		self,
		message: Connect,
		sender: Sender,
		runner: _TaskRunner,
	) -> None:
		_logger.info(f"Client connected from {message.hostname}")
		for fltr in self.filters:
			session = Session(message, sender, _Broadcast())
			runner.add_filter(fltr, session)


class _TaskRunner:

	def __init__(self, tasks: anyio.abc.TaskGroup):
		self.tasks = tasks
		self.filters = list[tuple[Filter, Session]]()
		self.channels = dict[MessageChannel, Filter]()

	async def __aenter__(self) -> Self:
		return self
class SessionRunner:

	async def __aexit__(self, *_: object) -> None:
		await self.aclose()
	def __init__(self, session: Session):
		self.session = session
		self.filters = dict[Filter, FinalResponse|None]()

	def add_filter(self, flter: Filter, session: Session, /) -> None:
		self.filters.append((flter, session))
	async def start(self, filters: Iterable[Filter], task_group: anyio.abc.TaskGroup) -> None:
		"""
		Run all the given filters in a task group

	async def start(
		self,
		needs_response: bool,
		first_connect: bool,
		use_skip: bool,
	) -> ResponseMessage|_CloseFilter|None:
		if self.channels:
			raise RuntimeError(f"{self} is already running tasks")
		final: ResponseMessage = Accept()
		for flter, session in self.filters:
			lchannel, rchannel = _make_message_channel()
			self.channels[lchannel] = flter
			match await self.tasks.start(self._runner, flter, session, rchannel, use_skip):
				case Accept():
					del self.channels[lchannel]
				case Continue():
					continue
				case TemporaryFailure() as final:  # replaces final
					pass
				case Reject()|Discard()|ReplyCode() as resp:
					if not first_connect:
						_logger.warning(
							f"Ignoring unexpected response from filter after restart: "
							f"{qualname(flter)} -> {resp}",
						)
						continue
					if not needs_response:
						_logger.warning(
							f"Unexpected response from filter {qualname(flter)}",
						)
						return _CloseFilter(flter)
					return resp
				case _ as arg:  # pragma: no-cover
					raise TypeError(f"task_status.started called with bad type: {arg!r}")
		if not needs_response:
			return None
		return final if len(self.channels) == 0 else Continue()

	async def set_macros(self, message: Macro) -> None:
		if self.channels:
			for channel in self.channels:
				await channel.send(message)
		else:
			for _, session in self.filters:
				await session.deliver(message)
		The session MUST have been primed by the delivery of a Connect message beforehand or
		filters will be unable to access the connection details.
		"""
		_logger.debug("Starting filters")
		for flter in filters:
			await task_group.start(self.run_filter, flter)

	async def message_events(
	async def run_filter(
		self,
		message: _VALID_EVENT_MESSAGE,
		needs_response: bool,
	) -> ResponseMessage|Skip|_CloseFilter|None:
		skip = isinstance(message, Body)
		for channel in list(self.channels):
			await channel.send(message)
			match (await channel.receive()):
				case Skip():
					continue
				case Continue():
					skip = False
				case Accept() as resp:
					flter = await self.close_channel(channel)
					if len(self.channels) == 0:
						_logger.info(f"Returning response Accept from {qualname(flter)}")
						return resp
					_logger.info(f"Holding response Accept from {qualname(flter)}")
				case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp:
					flter = await self.close_channel(channel)
					if not needs_response:
						_logger.warning(f"Unexpected response from filter {qualname(flter)}")
						return _CloseFilter(flter)
					_logger.info(f"Returning response {type(resp).__name__} from {qualname(flter)}")
					return resp
		assert len(self.channels) > 0, "Running filters reached zero without a response?!"
		if not needs_response:
			return None
		return Skip() if skip else Continue()

	async def close_channel(self, channel: MessageChannel) -> Filter:
		await channel.aclose()
		return self.channels.pop(channel)

	async def abort(self, abort: Abort) -> None:
		if not self.channels:
			return
		_logger.info("Aborting filters")
		for channel in self.channels:
			await channel.send(abort)
			await channel.receive()
			await channel.aclose()
		self.channels.clear()

	async def aclose(self) -> None:
		if self.channels:
			_logger.info("Closing filters")
		self.tasks.cancel_scope.cancel()
		self.channels.clear()

	@staticmethod
	async def _runner(
		fltr: Filter,
		session: Session,
		channel: MessageChannel,
		use_skip: bool, *,
		task_status: anyio.abc.TaskStatus[ResponseMessage],
	) -> None:
		final_resp: ResponseMessage|None = None

		async def _filter_wrap(
		flter: Filter,
		task_status: anyio.abc.TaskStatus[None],
	) -> None:
			nonlocal final_resp
			async with session:
				assert isinstance(session.broadcast, _Broadcast)
				session.broadcast.task_status = task_status
		"""
		Run a filter as a subtask in a task group

		A `Future` for returning the filter's response is added to the
		`SessionRunner.filter` dict.
		"""
		if flter in self.filters:
			raise RuntimeError
		self.filters[flter] = None

		async with self.session:
			assert isinstance(self.session.broadcast, _Broadcast)
			status_notifiers = self.session.broadcast.task_status
			status_notifiers.append(task_status)

			try:
					final_resp = await fltr(session)
				resp: FinalResponse = await flter(self.session)
			except Aborted:
					_logger.debug(f"Aborted filter {qualname(fltr)}")
				_logger.debug(f"Aborted filter {qualname(flter)}")
				del self.filters[flter]
				return
			except Exception:
					_logger.exception(f"Error in filter {qualname(fltr)}")
					final_resp = TemporaryFailure()
				if not isinstance(final_resp, FilterResponse):
					warn(f"expected a valid response from {qualname(fltr)}, got {final_resp}")
					final_resp = TemporaryFailure()

		async with anyio.create_task_group() as tasks:
			await tasks.start(_filter_wrap)
			task_status.started(final_resp or Continue())
			while final_resp is None:
				try:
					message = await channel.receive()
				except (anyio.EndOfStream, anyio.ClosedResourceError):
					tasks.cancel_scope.cancel()
					return
				if isinstance(message, Macro):
					await session.deliver(message)
				_logger.exception(f"Error in filter {qualname(flter)}")
				resp = TemporaryFailure()
			if not isinstance(resp, FinalResponse):
				warn(f"expected a valid response from {qualname(flter)}, got {resp}")  # type: ignore # Don't fully trust users…
				resp = TemporaryFailure()
			self.filters[flter] = resp
			if task_status in status_notifiers:
				status_notifiers.remove(task_status)
				task_status.started()

	async def check_response(self) -> ResponseMessage|None:
		assert self.filters, "no filters when checking for a response"
		response: ResponseMessage|None = None
		complete = list[Filter]()
		for flter, result in self.filters.items():
			# If a filter has not finished or no response is expected, continue without
			# removing from filter container; remove failed filters and filters that have
			# accepted; return a response for rejections;
			match result:
				case None:
					continue
				assert isinstance(message, _VALID_EVENT_MESSAGE)
				resp = await session.deliver(message)
				if isinstance(message, Abort):
					await channel.send(Continue())
					await channel.aclose()
					return
				if final_resp is not None:
					break  # type: ignore[unreachable]
				await channel.send(Skip() if use_skip and resp == Skip else Continue())
			await channel.send(final_resp)


def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream[Message](1)
	rsend, lrecv = anyio.create_memory_object_stream[Message](1)
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)
				case Accept():
					_logger.info("Accept from %s, waiting for remaining", qualname(flter))
				case TemporaryFailure() as response:
					_logger.warning("Filter failed: %s", flter)
				case Reject()|Discard()|ReplyCode() as response:
					_logger.info("Returning response %s from %s", type(response).__name__, qualname(flter))
					complete[:] = self.filters
					break
				case msg:
					raise AssertionError(f"unexpected filter result: {msg}")
			complete.append(flter)
		for flter in complete:
			del self.filters[flter]
		return response if response else None if self.filters else Accept()
+27 −5
Original line number Diff line number Diff line
@@ -67,6 +67,12 @@ class Phase(int, Enum):
	raised by `Session` methods.
	"""

	INIT = 0
	"""
	This phase is the pre-connected phase of a session; this phase will be completed before
	users see the session object.
	"""

	CONNECT = 1
	"""
	This phase is the starting phase of a session, during which a HELO/EHLO message may be
@@ -201,13 +207,12 @@ class Session:

	def __init__(
		self,
		connmsg: Connect,
		sender: Sender,
		broadcast: util.Broadcast[EventMessage]|None = None,
	):
		self.host = connmsg.hostname
		self.address = connmsg.address
		self.port = connmsg.port
		self.host = ""
		self.address = None
		self.port = 0

		self.sender = sender
		self.broadcast = broadcast or util.Broadcast[EventMessage]()
@@ -218,7 +223,9 @@ class Session:

		# Phase checking is a bit fuzzy as a filter may not request every message,
		# so some phases will be skipped; checks should not try to exactly match a phase.
		self.phase = Phase.CONNECT
		self.phase = Phase.INIT

		self._helo: Helo|None = None

	async def __aenter__(self) -> Self:
		await self.broadcast.__aenter__()
@@ -229,11 +236,22 @@ class Session:
		# on session close, wake up any remaining deliver() awaitables
		await self.broadcast.shutdown_hook()

	def _reset(self) -> None:
		self.headers = HeadersAccessor(self, self.sender)
		self.body = BodyAccessor(self, self.sender)

	async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]:
		"""
		Deliver a message (or its contents) to a task waiting for it
		"""
		match message:
			case Connect():
				self.host = message.hostname
				self.address = message.address
				self.port = message.port
				async with self.broadcast:
					self.phase = Phase.CONNECT
				return Continue
			case Macro():
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
@@ -241,6 +259,7 @@ class Session:
				async with self.broadcast:
					self.phase = Phase.CONNECT
				await self.broadcast.abort(Aborted)
				self._reset()
				return Continue
			case Helo():
				phase = Phase.MAIL
@@ -266,9 +285,12 @@ class Session:
				"Session.helo() must be awaited before any other async features of a "
				"Session",
			)
		if self._helo:
			return self._helo.hostname
		while self.phase <= Phase.CONNECT:
			message = await self.broadcast.receive()
			if isinstance(message, Helo):
				self._helo = message
				return message.hostname
		raise RuntimeError("HELO/EHLO event not received")

+4 −4
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class BodyAccessorTests(AsyncTestCase):
		"""
		Check that the body iterator works as expected
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result = b""

		@with_session(session)
@@ -47,7 +47,7 @@ class BodyAccessorTests(AsyncTestCase):
		"""
		Check that Body (and EOM) messages are skipped after breaking out of a loop
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result1 = b""
		result2 = b""

@@ -78,7 +78,7 @@ class BodyAccessorTests(AsyncTestCase):
		Check that `write()` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)

		@with_session(session)
		async def test_filter() -> None:
@@ -96,7 +96,7 @@ class BodyAccessorTests(AsyncTestCase):
		Check that `write()` in an async with context issues a warning
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)

		@with_session(session)
		async def test_filter() -> None:
+17 −17
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that header iterator works as expected
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result = []

		@with_session(session)
@@ -50,7 +50,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that all headers are collected when breaking out of a loop
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result1 = []
		result2 = []

@@ -88,7 +88,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that all headers are collected when awaiting `collect()`
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result = []

		@with_session(session)
@@ -114,7 +114,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that all headers are collected when awaiting `collect()` if EOH is missed
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result = []

		@with_session(session)
@@ -140,7 +140,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that `restrict()` works as expected
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())
		result = []

		@with_session(session)
@@ -165,7 +165,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `delete()` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -198,7 +198,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `update()` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -231,7 +231,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `insert(..., START)` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -261,7 +261,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `insert(..., END)` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -291,7 +291,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `insert(..., Before(...))` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -324,7 +324,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `insert(..., After(...))` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -357,7 +357,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that `insert(..., After(<last header>))` works as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -390,7 +390,7 @@ class HeaderAccessorTests(AsyncTestCase):
		Check that multiple edits in a filter work as expected
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		session = Session(sender)
		result = []

		@with_session(session)
@@ -430,7 +430,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that the AsyncGenerator-required method `asend()` works
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())

		@with_session(session)
		async def test_filter() -> None:
@@ -449,7 +449,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that the AsyncGenerator-required method `athrow()` works
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())

		@with_session(session)
		async def test_filter() -> None:
@@ -471,7 +471,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that the AsyncGenerator-required method `athrow()` works
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())

		@with_session(session)
		async def test_filter() -> None:
@@ -494,7 +494,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		Check that the AsyncGenerator-required method `athrow()` works
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		session = Session(MockEditor())

		@with_session(session)
		async def test_filter() -> None:
+29 −52

File changed.

Preview size limit exceeded, changes collapsed.