Loading kilter/service/runner.py +132 −227 Original line number Diff line number Diff line Loading @@ -14,16 +14,16 @@ The primary class in this module (`Runner`) is intended to be used with an from __future__ import annotations import enum import logging from collections import defaultdict from collections.abc import Iterable from typing import Final from typing import TypeAlias from warnings import warn import anyio.abc from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from typing_extensions import Self from kilter.protocol.buffer import SimpleBuffer from kilter.protocol.core import EventMessage Loading @@ -41,14 +41,16 @@ from .session import Session from .util import Broadcast from .util import qualname MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message] __all__ = [ "Runner", "NegotiationError", ] FinalResponse: TypeAlias = FilterResponse | TemporaryFailure kiB: Final = 2**10 MiB: Final = 2**20 _VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \ Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort _logger = logging.getLogger(__package__) Loading @@ -58,25 +60,27 @@ class NegotiationError(Exception): """ class _CloseFilter: class State(enum.Enum): def __init__(self, filtr: Filter): self.filter = filtr CONNECTED = enum.auto() SESSION = enum.auto() SESSION_ABORTED = enum.auto() MESSAGE = enum.auto() MESSAGE_ABORTED = enum.auto() class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: super().__init__() self.task_status: anyio.abc.TaskStatus[None]|None = None self.task_status = list[anyio.abc.TaskStatus[None]]() async def shutdown_hook(self) -> None: await self.pre_receive_hook() async def pre_receive_hook(self) -> None: if self.task_status is not None: self.task_status.started() self.task_status = None while self.task_status: self.task_status.pop().started() class Sender: Loading Loading @@ -122,13 +126,13 @@ class Runner: buff = SimpleBuffer(1*MiB) proto = FilterProtocol(abort_on_unknown=True) sender = Sender(client, proto) macro: Macro|None = None aborted = False session = Session(sender, _Broadcast()) runner = SessionRunner(session) state = State.CONNECTED async with ( aclosing(client), anyio.create_task_group() as tasks, _TaskRunner(tasks) as runner, ): while 1: try: Loading @@ -138,54 +142,57 @@ class Runner: anyio.ClosedResourceError, anyio.BrokenResourceError, ): await runner.aclose() return for message in proto.read_from(buff): if __debug__: _logger.debug(f"received: {message}") # If previous message was Abort, restart filters for any non-Abort/Close # message if state in (State.SESSION_ABORTED, State.MESSAGE_ABORTED): if not isinstance(message, Abort|Close): await runner.start(self.filters, tasks) state = ( State.CONNECTED if state == State.SESSION_ABORTED else State.SESSION ) match message: case Negotiate(): await sender.send(await self._negotiate(message)) case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. await runner.set_macros(macro) continue case Connect(): await self._prepare_filters(message, sender, runner) if macro: await runner.set_macros(macro) needs_response = proto.needs_response(message) match await runner.start(needs_response, True, self.use_skip): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case c_resp if needs_response: assert c_resp is not None and not isinstance(c_resp, _CloseFilter) await sender.send(c_resp) case c_resp: raise RuntimeError(f"unexpected response: {c_resp}") _logger.info(f"Client connected from {message.hostname}") await session.deliver(message) await runner.start(self.filters, tasks) if proto.needs_response(message): await sender.send(await runner.check_response() or Continue()) continue case Helo(): state = State.SESSION case EnvelopeFrom(): state = State.MESSAGE case Abort() if state in (State.SESSION, State.MESSAGE): state = ( State.SESSION_ABORTED if state == State.SESSION else State.MESSAGE_ABORTED ) case Abort(): aborted = True await runner.abort(message) _logger.warning("Unexpected Abort received") state = State.CONNECTED case Close(): await runner.aclose() tasks.cancel_scope.cancel() return case _: if aborted: aborted = False await runner.start(True, False, self.use_skip) needs_response = proto.needs_response(message) match await runner.message_events(message, needs_response): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case resp if needs_response: assert resp is not None and not isinstance(resp, _CloseFilter) skip_or_cont = await session.deliver(message) if not proto.needs_response(message): continue if (resp := await runner.check_response()): await sender.send(resp) case resp: raise RuntimeError(f"unexpected response: {resp}") elif self.use_skip: await sender.send(skip_or_cont()) else: await sender.send(Continue()) async def _negotiate(self, message: Negotiate) -> Negotiate: _logger.info("Negotiating with MTA") Loading Loading @@ -227,185 +234,83 @@ class Runner: return Negotiate(6, actions, options, dict(macros)) async def _prepare_filters( self, message: Connect, sender: Sender, runner: _TaskRunner, ) -> None: _logger.info(f"Client connected from {message.hostname}") for fltr in self.filters: session = Session(message, sender, _Broadcast()) runner.add_filter(fltr, session) class _TaskRunner: def __init__(self, tasks: anyio.abc.TaskGroup): self.tasks = tasks self.filters = list[tuple[Filter, Session]]() self.channels = dict[MessageChannel, Filter]() async def __aenter__(self) -> Self: return self class SessionRunner: async def __aexit__(self, *_: object) -> None: await self.aclose() def __init__(self, session: Session): self.session = session self.filters = dict[Filter, FinalResponse|None]() def add_filter(self, flter: Filter, session: Session, /) -> None: self.filters.append((flter, session)) async def start(self, filters: Iterable[Filter], task_group: anyio.abc.TaskGroup) -> None: """ Run all the given filters in a task group async def start( self, needs_response: bool, first_connect: bool, use_skip: bool, ) -> ResponseMessage|_CloseFilter|None: if self.channels: raise RuntimeError(f"{self} is already running tasks") final: ResponseMessage = Accept() for flter, session in self.filters: lchannel, rchannel = _make_message_channel() self.channels[lchannel] = flter match await self.tasks.start(self._runner, flter, session, rchannel, use_skip): case Accept(): del self.channels[lchannel] case Continue(): continue case TemporaryFailure() as final: # replaces final pass case Reject()|Discard()|ReplyCode() as resp: if not first_connect: _logger.warning( f"Ignoring unexpected response from filter after restart: " f"{qualname(flter)} -> {resp}", ) continue if not needs_response: _logger.warning( f"Unexpected response from filter {qualname(flter)}", ) return _CloseFilter(flter) return resp case _ as arg: # pragma: no-cover raise TypeError(f"task_status.started called with bad type: {arg!r}") if not needs_response: return None return final if len(self.channels) == 0 else Continue() async def set_macros(self, message: Macro) -> None: if self.channels: for channel in self.channels: await channel.send(message) else: for _, session in self.filters: await session.deliver(message) The session MUST have been primed by the delivery of a Connect message beforehand or filters will be unable to access the connection details. """ _logger.debug("Starting filters") for flter in filters: await task_group.start(self.run_filter, flter) async def message_events( async def run_filter( self, message: _VALID_EVENT_MESSAGE, needs_response: bool, ) -> ResponseMessage|Skip|_CloseFilter|None: skip = isinstance(message, Body) for channel in list(self.channels): await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept() as resp: flter = await self.close_channel(channel) if len(self.channels) == 0: _logger.info(f"Returning response Accept from {qualname(flter)}") return resp _logger.info(f"Holding response Accept from {qualname(flter)}") case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: flter = await self.close_channel(channel) if not needs_response: _logger.warning(f"Unexpected response from filter {qualname(flter)}") return _CloseFilter(flter) _logger.info(f"Returning response {type(resp).__name__} from {qualname(flter)}") return resp assert len(self.channels) > 0, "Running filters reached zero without a response?!" if not needs_response: return None return Skip() if skip else Continue() async def close_channel(self, channel: MessageChannel) -> Filter: await channel.aclose() return self.channels.pop(channel) async def abort(self, abort: Abort) -> None: if not self.channels: return _logger.info("Aborting filters") for channel in self.channels: await channel.send(abort) await channel.receive() await channel.aclose() self.channels.clear() async def aclose(self) -> None: if self.channels: _logger.info("Closing filters") self.tasks.cancel_scope.cancel() self.channels.clear() @staticmethod async def _runner( fltr: Filter, session: Session, channel: MessageChannel, use_skip: bool, *, task_status: anyio.abc.TaskStatus[ResponseMessage], ) -> None: final_resp: ResponseMessage|None = None async def _filter_wrap( flter: Filter, task_status: anyio.abc.TaskStatus[None], ) -> None: nonlocal final_resp async with session: assert isinstance(session.broadcast, _Broadcast) session.broadcast.task_status = task_status """ Run a filter as a subtask in a task group A `Future` for returning the filter's response is added to the `SessionRunner.filter` dict. """ if flter in self.filters: raise RuntimeError self.filters[flter] = None async with self.session: assert isinstance(self.session.broadcast, _Broadcast) status_notifiers = self.session.broadcast.task_status status_notifiers.append(task_status) try: final_resp = await fltr(session) resp: FinalResponse = await flter(self.session) except Aborted: _logger.debug(f"Aborted filter {qualname(fltr)}") _logger.debug(f"Aborted filter {qualname(flter)}") del self.filters[flter] return except Exception: _logger.exception(f"Error in filter {qualname(fltr)}") final_resp = TemporaryFailure() if not isinstance(final_resp, FilterResponse): warn(f"expected a valid response from {qualname(fltr)}, got {final_resp}") final_resp = TemporaryFailure() async with anyio.create_task_group() as tasks: await tasks.start(_filter_wrap) task_status.started(final_resp or Continue()) while final_resp is None: try: message = await channel.receive() except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return if isinstance(message, Macro): await session.deliver(message) _logger.exception(f"Error in filter {qualname(flter)}") resp = TemporaryFailure() if not isinstance(resp, FinalResponse): warn(f"expected a valid response from {qualname(flter)}, got {resp}") # type: ignore # Don't fully trust users… resp = TemporaryFailure() self.filters[flter] = resp if task_status in status_notifiers: status_notifiers.remove(task_status) task_status.started() async def check_response(self) -> ResponseMessage|None: assert self.filters, "no filters when checking for a response" response: ResponseMessage|None = None complete = list[Filter]() for flter, result in self.filters.items(): # If a filter has not finished or no response is expected, continue without # removing from filter container; remove failed filters and filters that have # accepted; return a response for rejections; match result: case None: continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) if isinstance(message, Abort): await channel.send(Continue()) await channel.aclose() return if final_resp is not None: break # type: ignore[unreachable] await channel.send(Skip() if use_skip and resp == Skip else Continue()) await channel.send(final_resp) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream[Message](1) rsend, lrecv = anyio.create_memory_object_stream[Message](1) return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) case Accept(): _logger.info("Accept from %s, waiting for remaining", qualname(flter)) case TemporaryFailure() as response: _logger.warning("Filter failed: %s", flter) case Reject()|Discard()|ReplyCode() as response: _logger.info("Returning response %s from %s", type(response).__name__, qualname(flter)) complete[:] = self.filters break case msg: raise AssertionError(f"unexpected filter result: {msg}") complete.append(flter) for flter in complete: del self.filters[flter] return response if response else None if self.filters else Accept() kilter/service/session.py +27 −5 Original line number Diff line number Diff line Loading @@ -67,6 +67,12 @@ class Phase(int, Enum): raised by `Session` methods. """ INIT = 0 """ This phase is the pre-connected phase of a session; this phase will be completed before users see the session object. """ CONNECT = 1 """ This phase is the starting phase of a session, during which a HELO/EHLO message may be Loading Loading @@ -201,13 +207,12 @@ class Session: def __init__( self, connmsg: Connect, sender: Sender, broadcast: util.Broadcast[EventMessage]|None = None, ): self.host = connmsg.hostname self.address = connmsg.address self.port = connmsg.port self.host = "" self.address = None self.port = 0 self.sender = sender self.broadcast = broadcast or util.Broadcast[EventMessage]() Loading @@ -218,7 +223,9 @@ class Session: # Phase checking is a bit fuzzy as a filter may not request every message, # so some phases will be skipped; checks should not try to exactly match a phase. self.phase = Phase.CONNECT self.phase = Phase.INIT self._helo: Helo|None = None async def __aenter__(self) -> Self: await self.broadcast.__aenter__() Loading @@ -229,11 +236,22 @@ class Session: # on session close, wake up any remaining deliver() awaitables await self.broadcast.shutdown_hook() def _reset(self) -> None: self.headers = HeadersAccessor(self, self.sender) self.body = BodyAccessor(self, self.sender) async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]: """ Deliver a message (or its contents) to a task waiting for it """ match message: case Connect(): self.host = message.hostname self.address = message.address self.port = message.port async with self.broadcast: self.phase = Phase.CONNECT return Continue case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something Loading @@ -241,6 +259,7 @@ class Session: async with self.broadcast: self.phase = Phase.CONNECT await self.broadcast.abort(Aborted) self._reset() return Continue case Helo(): phase = Phase.MAIL Loading @@ -266,9 +285,12 @@ class Session: "Session.helo() must be awaited before any other async features of a " "Session", ) if self._helo: return self._helo.hostname while self.phase <= Phase.CONNECT: message = await self.broadcast.receive() if isinstance(message, Helo): self._helo = message return message.hostname raise RuntimeError("HELO/EHLO event not received") Loading tests/test_body_accessor.py +4 −4 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ class BodyAccessorTests(AsyncTestCase): """ Check that the body iterator works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = b"" @with_session(session) Loading @@ -47,7 +47,7 @@ class BodyAccessorTests(AsyncTestCase): """ Check that Body (and EOM) messages are skipped after breaking out of a loop """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result1 = b"" result2 = b"" Loading Loading @@ -78,7 +78,7 @@ class BodyAccessorTests(AsyncTestCase): Check that `write()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) @with_session(session) async def test_filter() -> None: Loading @@ -96,7 +96,7 @@ class BodyAccessorTests(AsyncTestCase): Check that `write()` in an async with context issues a warning """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) @with_session(session) async def test_filter() -> None: Loading tests/test_header_accessor.py +17 −17 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that header iterator works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -50,7 +50,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when breaking out of a loop """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result1 = [] result2 = [] Loading Loading @@ -88,7 +88,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when awaiting `collect()` """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -114,7 +114,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when awaiting `collect()` if EOH is missed """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -140,7 +140,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that `restrict()` works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -165,7 +165,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `delete()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -198,7 +198,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `update()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -231,7 +231,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., START)` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -261,7 +261,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., END)` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -291,7 +291,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., Before(...))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -324,7 +324,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., After(...))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -357,7 +357,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., After(<last header>))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -390,7 +390,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that multiple edits in a filter work as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -430,7 +430,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `asend()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -449,7 +449,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -471,7 +471,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -494,7 +494,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading tests/test_session.py +29 −52 File changed.Preview size limit exceeded, changes collapsed. Show changes Loading
kilter/service/runner.py +132 −227 Original line number Diff line number Diff line Loading @@ -14,16 +14,16 @@ The primary class in this module (`Runner`) is intended to be used with an from __future__ import annotations import enum import logging from collections import defaultdict from collections.abc import Iterable from typing import Final from typing import TypeAlias from warnings import warn import anyio.abc from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from typing_extensions import Self from kilter.protocol.buffer import SimpleBuffer from kilter.protocol.core import EventMessage Loading @@ -41,14 +41,16 @@ from .session import Session from .util import Broadcast from .util import qualname MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message] __all__ = [ "Runner", "NegotiationError", ] FinalResponse: TypeAlias = FilterResponse | TemporaryFailure kiB: Final = 2**10 MiB: Final = 2**20 _VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \ Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort _logger = logging.getLogger(__package__) Loading @@ -58,25 +60,27 @@ class NegotiationError(Exception): """ class _CloseFilter: class State(enum.Enum): def __init__(self, filtr: Filter): self.filter = filtr CONNECTED = enum.auto() SESSION = enum.auto() SESSION_ABORTED = enum.auto() MESSAGE = enum.auto() MESSAGE_ABORTED = enum.auto() class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: super().__init__() self.task_status: anyio.abc.TaskStatus[None]|None = None self.task_status = list[anyio.abc.TaskStatus[None]]() async def shutdown_hook(self) -> None: await self.pre_receive_hook() async def pre_receive_hook(self) -> None: if self.task_status is not None: self.task_status.started() self.task_status = None while self.task_status: self.task_status.pop().started() class Sender: Loading Loading @@ -122,13 +126,13 @@ class Runner: buff = SimpleBuffer(1*MiB) proto = FilterProtocol(abort_on_unknown=True) sender = Sender(client, proto) macro: Macro|None = None aborted = False session = Session(sender, _Broadcast()) runner = SessionRunner(session) state = State.CONNECTED async with ( aclosing(client), anyio.create_task_group() as tasks, _TaskRunner(tasks) as runner, ): while 1: try: Loading @@ -138,54 +142,57 @@ class Runner: anyio.ClosedResourceError, anyio.BrokenResourceError, ): await runner.aclose() return for message in proto.read_from(buff): if __debug__: _logger.debug(f"received: {message}") # If previous message was Abort, restart filters for any non-Abort/Close # message if state in (State.SESSION_ABORTED, State.MESSAGE_ABORTED): if not isinstance(message, Abort|Close): await runner.start(self.filters, tasks) state = ( State.CONNECTED if state == State.SESSION_ABORTED else State.SESSION ) match message: case Negotiate(): await sender.send(await self._negotiate(message)) case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. await runner.set_macros(macro) continue case Connect(): await self._prepare_filters(message, sender, runner) if macro: await runner.set_macros(macro) needs_response = proto.needs_response(message) match await runner.start(needs_response, True, self.use_skip): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case c_resp if needs_response: assert c_resp is not None and not isinstance(c_resp, _CloseFilter) await sender.send(c_resp) case c_resp: raise RuntimeError(f"unexpected response: {c_resp}") _logger.info(f"Client connected from {message.hostname}") await session.deliver(message) await runner.start(self.filters, tasks) if proto.needs_response(message): await sender.send(await runner.check_response() or Continue()) continue case Helo(): state = State.SESSION case EnvelopeFrom(): state = State.MESSAGE case Abort() if state in (State.SESSION, State.MESSAGE): state = ( State.SESSION_ABORTED if state == State.SESSION else State.MESSAGE_ABORTED ) case Abort(): aborted = True await runner.abort(message) _logger.warning("Unexpected Abort received") state = State.CONNECTED case Close(): await runner.aclose() tasks.cancel_scope.cancel() return case _: if aborted: aborted = False await runner.start(True, False, self.use_skip) needs_response = proto.needs_response(message) match await runner.message_events(message, needs_response): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case resp if needs_response: assert resp is not None and not isinstance(resp, _CloseFilter) skip_or_cont = await session.deliver(message) if not proto.needs_response(message): continue if (resp := await runner.check_response()): await sender.send(resp) case resp: raise RuntimeError(f"unexpected response: {resp}") elif self.use_skip: await sender.send(skip_or_cont()) else: await sender.send(Continue()) async def _negotiate(self, message: Negotiate) -> Negotiate: _logger.info("Negotiating with MTA") Loading Loading @@ -227,185 +234,83 @@ class Runner: return Negotiate(6, actions, options, dict(macros)) async def _prepare_filters( self, message: Connect, sender: Sender, runner: _TaskRunner, ) -> None: _logger.info(f"Client connected from {message.hostname}") for fltr in self.filters: session = Session(message, sender, _Broadcast()) runner.add_filter(fltr, session) class _TaskRunner: def __init__(self, tasks: anyio.abc.TaskGroup): self.tasks = tasks self.filters = list[tuple[Filter, Session]]() self.channels = dict[MessageChannel, Filter]() async def __aenter__(self) -> Self: return self class SessionRunner: async def __aexit__(self, *_: object) -> None: await self.aclose() def __init__(self, session: Session): self.session = session self.filters = dict[Filter, FinalResponse|None]() def add_filter(self, flter: Filter, session: Session, /) -> None: self.filters.append((flter, session)) async def start(self, filters: Iterable[Filter], task_group: anyio.abc.TaskGroup) -> None: """ Run all the given filters in a task group async def start( self, needs_response: bool, first_connect: bool, use_skip: bool, ) -> ResponseMessage|_CloseFilter|None: if self.channels: raise RuntimeError(f"{self} is already running tasks") final: ResponseMessage = Accept() for flter, session in self.filters: lchannel, rchannel = _make_message_channel() self.channels[lchannel] = flter match await self.tasks.start(self._runner, flter, session, rchannel, use_skip): case Accept(): del self.channels[lchannel] case Continue(): continue case TemporaryFailure() as final: # replaces final pass case Reject()|Discard()|ReplyCode() as resp: if not first_connect: _logger.warning( f"Ignoring unexpected response from filter after restart: " f"{qualname(flter)} -> {resp}", ) continue if not needs_response: _logger.warning( f"Unexpected response from filter {qualname(flter)}", ) return _CloseFilter(flter) return resp case _ as arg: # pragma: no-cover raise TypeError(f"task_status.started called with bad type: {arg!r}") if not needs_response: return None return final if len(self.channels) == 0 else Continue() async def set_macros(self, message: Macro) -> None: if self.channels: for channel in self.channels: await channel.send(message) else: for _, session in self.filters: await session.deliver(message) The session MUST have been primed by the delivery of a Connect message beforehand or filters will be unable to access the connection details. """ _logger.debug("Starting filters") for flter in filters: await task_group.start(self.run_filter, flter) async def message_events( async def run_filter( self, message: _VALID_EVENT_MESSAGE, needs_response: bool, ) -> ResponseMessage|Skip|_CloseFilter|None: skip = isinstance(message, Body) for channel in list(self.channels): await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept() as resp: flter = await self.close_channel(channel) if len(self.channels) == 0: _logger.info(f"Returning response Accept from {qualname(flter)}") return resp _logger.info(f"Holding response Accept from {qualname(flter)}") case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: flter = await self.close_channel(channel) if not needs_response: _logger.warning(f"Unexpected response from filter {qualname(flter)}") return _CloseFilter(flter) _logger.info(f"Returning response {type(resp).__name__} from {qualname(flter)}") return resp assert len(self.channels) > 0, "Running filters reached zero without a response?!" if not needs_response: return None return Skip() if skip else Continue() async def close_channel(self, channel: MessageChannel) -> Filter: await channel.aclose() return self.channels.pop(channel) async def abort(self, abort: Abort) -> None: if not self.channels: return _logger.info("Aborting filters") for channel in self.channels: await channel.send(abort) await channel.receive() await channel.aclose() self.channels.clear() async def aclose(self) -> None: if self.channels: _logger.info("Closing filters") self.tasks.cancel_scope.cancel() self.channels.clear() @staticmethod async def _runner( fltr: Filter, session: Session, channel: MessageChannel, use_skip: bool, *, task_status: anyio.abc.TaskStatus[ResponseMessage], ) -> None: final_resp: ResponseMessage|None = None async def _filter_wrap( flter: Filter, task_status: anyio.abc.TaskStatus[None], ) -> None: nonlocal final_resp async with session: assert isinstance(session.broadcast, _Broadcast) session.broadcast.task_status = task_status """ Run a filter as a subtask in a task group A `Future` for returning the filter's response is added to the `SessionRunner.filter` dict. """ if flter in self.filters: raise RuntimeError self.filters[flter] = None async with self.session: assert isinstance(self.session.broadcast, _Broadcast) status_notifiers = self.session.broadcast.task_status status_notifiers.append(task_status) try: final_resp = await fltr(session) resp: FinalResponse = await flter(self.session) except Aborted: _logger.debug(f"Aborted filter {qualname(fltr)}") _logger.debug(f"Aborted filter {qualname(flter)}") del self.filters[flter] return except Exception: _logger.exception(f"Error in filter {qualname(fltr)}") final_resp = TemporaryFailure() if not isinstance(final_resp, FilterResponse): warn(f"expected a valid response from {qualname(fltr)}, got {final_resp}") final_resp = TemporaryFailure() async with anyio.create_task_group() as tasks: await tasks.start(_filter_wrap) task_status.started(final_resp or Continue()) while final_resp is None: try: message = await channel.receive() except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return if isinstance(message, Macro): await session.deliver(message) _logger.exception(f"Error in filter {qualname(flter)}") resp = TemporaryFailure() if not isinstance(resp, FinalResponse): warn(f"expected a valid response from {qualname(flter)}, got {resp}") # type: ignore # Don't fully trust users… resp = TemporaryFailure() self.filters[flter] = resp if task_status in status_notifiers: status_notifiers.remove(task_status) task_status.started() async def check_response(self) -> ResponseMessage|None: assert self.filters, "no filters when checking for a response" response: ResponseMessage|None = None complete = list[Filter]() for flter, result in self.filters.items(): # If a filter has not finished or no response is expected, continue without # removing from filter container; remove failed filters and filters that have # accepted; return a response for rejections; match result: case None: continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) if isinstance(message, Abort): await channel.send(Continue()) await channel.aclose() return if final_resp is not None: break # type: ignore[unreachable] await channel.send(Skip() if use_skip and resp == Skip else Continue()) await channel.send(final_resp) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream[Message](1) rsend, lrecv = anyio.create_memory_object_stream[Message](1) return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) case Accept(): _logger.info("Accept from %s, waiting for remaining", qualname(flter)) case TemporaryFailure() as response: _logger.warning("Filter failed: %s", flter) case Reject()|Discard()|ReplyCode() as response: _logger.info("Returning response %s from %s", type(response).__name__, qualname(flter)) complete[:] = self.filters break case msg: raise AssertionError(f"unexpected filter result: {msg}") complete.append(flter) for flter in complete: del self.filters[flter] return response if response else None if self.filters else Accept()
kilter/service/session.py +27 −5 Original line number Diff line number Diff line Loading @@ -67,6 +67,12 @@ class Phase(int, Enum): raised by `Session` methods. """ INIT = 0 """ This phase is the pre-connected phase of a session; this phase will be completed before users see the session object. """ CONNECT = 1 """ This phase is the starting phase of a session, during which a HELO/EHLO message may be Loading Loading @@ -201,13 +207,12 @@ class Session: def __init__( self, connmsg: Connect, sender: Sender, broadcast: util.Broadcast[EventMessage]|None = None, ): self.host = connmsg.hostname self.address = connmsg.address self.port = connmsg.port self.host = "" self.address = None self.port = 0 self.sender = sender self.broadcast = broadcast or util.Broadcast[EventMessage]() Loading @@ -218,7 +223,9 @@ class Session: # Phase checking is a bit fuzzy as a filter may not request every message, # so some phases will be skipped; checks should not try to exactly match a phase. self.phase = Phase.CONNECT self.phase = Phase.INIT self._helo: Helo|None = None async def __aenter__(self) -> Self: await self.broadcast.__aenter__() Loading @@ -229,11 +236,22 @@ class Session: # on session close, wake up any remaining deliver() awaitables await self.broadcast.shutdown_hook() def _reset(self) -> None: self.headers = HeadersAccessor(self, self.sender) self.body = BodyAccessor(self, self.sender) async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]: """ Deliver a message (or its contents) to a task waiting for it """ match message: case Connect(): self.host = message.hostname self.address = message.address self.port = message.port async with self.broadcast: self.phase = Phase.CONNECT return Continue case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something Loading @@ -241,6 +259,7 @@ class Session: async with self.broadcast: self.phase = Phase.CONNECT await self.broadcast.abort(Aborted) self._reset() return Continue case Helo(): phase = Phase.MAIL Loading @@ -266,9 +285,12 @@ class Session: "Session.helo() must be awaited before any other async features of a " "Session", ) if self._helo: return self._helo.hostname while self.phase <= Phase.CONNECT: message = await self.broadcast.receive() if isinstance(message, Helo): self._helo = message return message.hostname raise RuntimeError("HELO/EHLO event not received") Loading
tests/test_body_accessor.py +4 −4 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ class BodyAccessorTests(AsyncTestCase): """ Check that the body iterator works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = b"" @with_session(session) Loading @@ -47,7 +47,7 @@ class BodyAccessorTests(AsyncTestCase): """ Check that Body (and EOM) messages are skipped after breaking out of a loop """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result1 = b"" result2 = b"" Loading Loading @@ -78,7 +78,7 @@ class BodyAccessorTests(AsyncTestCase): Check that `write()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) @with_session(session) async def test_filter() -> None: Loading @@ -96,7 +96,7 @@ class BodyAccessorTests(AsyncTestCase): Check that `write()` in an async with context issues a warning """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) @with_session(session) async def test_filter() -> None: Loading
tests/test_header_accessor.py +17 −17 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that header iterator works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -50,7 +50,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when breaking out of a loop """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result1 = [] result2 = [] Loading Loading @@ -88,7 +88,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when awaiting `collect()` """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -114,7 +114,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that all headers are collected when awaiting `collect()` if EOH is missed """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -140,7 +140,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that `restrict()` works as expected """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) result = [] @with_session(session) Loading @@ -165,7 +165,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `delete()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -198,7 +198,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `update()` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -231,7 +231,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., START)` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -261,7 +261,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., END)` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -291,7 +291,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., Before(...))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -324,7 +324,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., After(...))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -357,7 +357,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that `insert(..., After(<last header>))` works as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -390,7 +390,7 @@ class HeaderAccessorTests(AsyncTestCase): Check that multiple edits in a filter work as expected """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) session = Session(sender) result = [] @with_session(session) Loading Loading @@ -430,7 +430,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `asend()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -449,7 +449,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -471,7 +471,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading @@ -494,7 +494,7 @@ class HeaderAccessorTests(AsyncTestCase): """ Check that the AsyncGenerator-required method `athrow()` works """ session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor()) session = Session(MockEditor()) @with_session(session) async def test_filter() -> None: Loading
tests/test_session.py +29 −52 File changed.Preview size limit exceeded, changes collapsed. Show changes