Loading kilter/service/runner.py +167 −102 Original line number Diff line number Diff line # Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk> # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this Loading @@ -16,6 +16,9 @@ from __future__ import annotations import logging from collections.abc import AsyncGenerator from typing import TYPE_CHECKING from typing import TypeAlias from typing import TypeVar from warnings import warn import anyio.abc Loading @@ -37,8 +40,8 @@ kiB = 2**10 MiB = 2**20 _VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode _VALID_EVENT_MESSAGE = Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, \ Header, EndOfHeaders, Body, EndOfMessage _VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \ Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort _DISABLE_PROTOCOL_FLAGS = ProtocolFlags.NO_CONNECT | ProtocolFlags.NO_HELO | \ ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | ProtocolFlags.NO_BODY | \ ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_EOH | ProtocolFlags.NO_UNKNOWN | \ Loading Loading @@ -90,12 +93,15 @@ class Runner: buff = SimpleBuffer(1*MiB) proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() macro: Macro|None = None await sender.asend(None) # type: ignore # initialise async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client): async with ( anyio.create_task_group() as tasks, aclosing(sender), aclosing(client), _TaskRunner(tasks) as runner, ): while 1: try: buff[:] = await client.receive(buff.available) Loading @@ -104,49 +110,33 @@ class Runner: anyio.ClosedResourceError, anyio.BrokenResourceError, ): for channel in channels: await channel.aclose() await runner.aclose() return for message in proto.read_from(buff): match message: case Negotiate(): await self._negotiate(message, sender) await sender.asend(await self._negotiate(message, sender)) case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. for channel in channels: await channel.send(macro) await runner.set_macros(macro) case Connect(): channels[:] = await self._connect(message, sender, tasks, macro) await sender.asend(await self._connect(message, sender, runner, macro)) case Abort(): for channel in channels: await channel.aclose() await runner.abort(message) await runner.start(False, self.use_skip) case Close(): await runner.aclose() return case _: assert isinstance(message, _VALID_EVENT_MESSAGE) skip = isinstance(message, Body) for channel in channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() channels.remove(channel) case resp: await sender.asend(resp) break else: await sender.asend( Accept() if len(channels) == 0 else Skip() if skip else Continue(), ) async def _negotiate(self, message: Negotiate, sender: Sender) -> None: # TODO: Upgrade and remove ignores once python/mypy#14242 is in # TODO: Should remove assert once kilter.protocol#5 is resolved # Type narrowing should do the job adequately # https://code.kodo.org.uk/kilter/kilter.protocol/-/issues/5 assert isinstance(message, _VALID_EVENT_MESSAGE) # type: ignore[misc,arg-type] await sender.asend(await runner.message_events(message)) # type: ignore[arg-type] async def _negotiate(self, message: Negotiate, sender: Sender) -> Negotiate: # TODO: actually negotiate what the filter wants, not just "everything" actions = set(ActionFlags) # All actions! if actions != ActionFlags.unpack(message.action_flags): Loading @@ -156,59 +146,112 @@ class Runner: resp.protocol_flags = message.protocol_flags & ~_DISABLE_PROTOCOL_FLAGS resp.action_flags = ActionFlags.pack(actions) await sender.asend(resp) self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP) return resp async def _connect( self, message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, runner: _TaskRunner, macro: Macro|None, ) -> list[MessageChannel]: channels = list[MessageChannel]() ) -> ResponseMessage: for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) runner.add_filter(fltr, session) if macro: await session.deliver(macro) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): await runner.set_macros(macro) return await runner.start(True, self.use_skip) class _TaskRunner: if TYPE_CHECKING: Self = TypeVar("Self", bound="_TaskRunner") def __init__(self, tasks: anyio.abc.TaskGroup): self.tasks = tasks self.filters = list[tuple[Filter, Session]]() self.channels = list[MessageChannel]() async def __aenter__(self: Self) -> Self: return self async def __aexit__(self, *_: object) -> None: await self.aclose() def add_filter(self, flter: Filter, session: Session, /) -> None: self.filters.append((flter, session)) async def start(self, first_connect: bool, use_skip: bool) -> ResponseMessage: if self.channels: raise RuntimeError(f"{self} is already running tasks") final: ResponseMessage = Accept() for flter, session in self.filters: lchannel, rchannel = _make_message_channel() self.channels.append(lchannel) match await self.tasks.start(self._runner, flter, session, rchannel, first_connect, use_skip): case Accept(): self.channels.remove(lchannel) case Continue(): continue case Message() as resp: await sender.asend(resp) return [] case TemporaryFailure() as final: # replaces final pass case Reject()|Discard()|ReplyCode() as resp: if not first_connect: logging.warning("Unexpected response from filter after restart") continue return resp case _ as arg: # pragma: no-cover raise TypeError( f"task_status.started called with bad type: " f"{arg!r}", ) await sender.asend(Continue()) return channels raise TypeError(f"task_status.started called with bad type: {arg!r}") return final if len(self.channels) == 0 else Continue() async def set_macros(self, message: Macro) -> None: if self.channels: for channel in self.channels: await channel.send(message) else: for _, session in self.filters: await session.deliver(message) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage: skip = isinstance(message, Body) for channel in self.channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() self.channels.remove(channel) case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: return resp return ( Accept() if len(self.channels) == 0 else Skip() if skip else Continue() ) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:] async def abort(self, abort: Abort) -> None: for channel in self.channels: await channel.send(abort) await channel.receive() await channel.aclose() del self.channels[:] async def aclose(self) -> None: self.tasks.cancel_scope.cancel() del self.channels[:] @staticmethod async def _runner( fltr: Filter, session: Session, channel: MessageChannel, first_connect: bool, use_skip: bool, *, task_status: anyio.abc.TaskStatus, ) -> None: Loading @@ -223,6 +266,9 @@ async def _runner( session.broadcast.task_status = task_status try: final_resp = await fltr(session) except Aborted: logging.info(f"aborted filter {qualname(fltr)}") return except Exception: logging.exception(f"error in filter {qualname(fltr)}") final_resp = TemporaryFailure() Loading @@ -242,9 +288,28 @@ async def _runner( if isinstance(message, Macro): await session.deliver(message) continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) # TODO: Upgrade and remove ignores once python/mypy#14242 is in assert isinstance(message, _VALID_EVENT_MESSAGE) # type: ignore[misc,arg-type] resp = await session.deliver(message) # type: ignore[arg-type] if final_resp is not None: break # type: ignore if isinstance(message, Abort): await channel.send(Continue()) await channel.aclose() return await channel.send(Skip() if use_skip and resp == Skip else Continue()) await channel.send(final_resp) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:] Loading
kilter/service/runner.py +167 −102 Original line number Diff line number Diff line # Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk> # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this Loading @@ -16,6 +16,9 @@ from __future__ import annotations import logging from collections.abc import AsyncGenerator from typing import TYPE_CHECKING from typing import TypeAlias from typing import TypeVar from warnings import warn import anyio.abc Loading @@ -37,8 +40,8 @@ kiB = 2**10 MiB = 2**20 _VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode _VALID_EVENT_MESSAGE = Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, \ Header, EndOfHeaders, Body, EndOfMessage _VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \ Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort _DISABLE_PROTOCOL_FLAGS = ProtocolFlags.NO_CONNECT | ProtocolFlags.NO_HELO | \ ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | ProtocolFlags.NO_BODY | \ ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_EOH | ProtocolFlags.NO_UNKNOWN | \ Loading Loading @@ -90,12 +93,15 @@ class Runner: buff = SimpleBuffer(1*MiB) proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() macro: Macro|None = None await sender.asend(None) # type: ignore # initialise async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client): async with ( anyio.create_task_group() as tasks, aclosing(sender), aclosing(client), _TaskRunner(tasks) as runner, ): while 1: try: buff[:] = await client.receive(buff.available) Loading @@ -104,49 +110,33 @@ class Runner: anyio.ClosedResourceError, anyio.BrokenResourceError, ): for channel in channels: await channel.aclose() await runner.aclose() return for message in proto.read_from(buff): match message: case Negotiate(): await self._negotiate(message, sender) await sender.asend(await self._negotiate(message, sender)) case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. for channel in channels: await channel.send(macro) await runner.set_macros(macro) case Connect(): channels[:] = await self._connect(message, sender, tasks, macro) await sender.asend(await self._connect(message, sender, runner, macro)) case Abort(): for channel in channels: await channel.aclose() await runner.abort(message) await runner.start(False, self.use_skip) case Close(): await runner.aclose() return case _: assert isinstance(message, _VALID_EVENT_MESSAGE) skip = isinstance(message, Body) for channel in channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() channels.remove(channel) case resp: await sender.asend(resp) break else: await sender.asend( Accept() if len(channels) == 0 else Skip() if skip else Continue(), ) async def _negotiate(self, message: Negotiate, sender: Sender) -> None: # TODO: Upgrade and remove ignores once python/mypy#14242 is in # TODO: Should remove assert once kilter.protocol#5 is resolved # Type narrowing should do the job adequately # https://code.kodo.org.uk/kilter/kilter.protocol/-/issues/5 assert isinstance(message, _VALID_EVENT_MESSAGE) # type: ignore[misc,arg-type] await sender.asend(await runner.message_events(message)) # type: ignore[arg-type] async def _negotiate(self, message: Negotiate, sender: Sender) -> Negotiate: # TODO: actually negotiate what the filter wants, not just "everything" actions = set(ActionFlags) # All actions! if actions != ActionFlags.unpack(message.action_flags): Loading @@ -156,59 +146,112 @@ class Runner: resp.protocol_flags = message.protocol_flags & ~_DISABLE_PROTOCOL_FLAGS resp.action_flags = ActionFlags.pack(actions) await sender.asend(resp) self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP) return resp async def _connect( self, message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, runner: _TaskRunner, macro: Macro|None, ) -> list[MessageChannel]: channels = list[MessageChannel]() ) -> ResponseMessage: for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) runner.add_filter(fltr, session) if macro: await session.deliver(macro) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): await runner.set_macros(macro) return await runner.start(True, self.use_skip) class _TaskRunner: if TYPE_CHECKING: Self = TypeVar("Self", bound="_TaskRunner") def __init__(self, tasks: anyio.abc.TaskGroup): self.tasks = tasks self.filters = list[tuple[Filter, Session]]() self.channels = list[MessageChannel]() async def __aenter__(self: Self) -> Self: return self async def __aexit__(self, *_: object) -> None: await self.aclose() def add_filter(self, flter: Filter, session: Session, /) -> None: self.filters.append((flter, session)) async def start(self, first_connect: bool, use_skip: bool) -> ResponseMessage: if self.channels: raise RuntimeError(f"{self} is already running tasks") final: ResponseMessage = Accept() for flter, session in self.filters: lchannel, rchannel = _make_message_channel() self.channels.append(lchannel) match await self.tasks.start(self._runner, flter, session, rchannel, first_connect, use_skip): case Accept(): self.channels.remove(lchannel) case Continue(): continue case Message() as resp: await sender.asend(resp) return [] case TemporaryFailure() as final: # replaces final pass case Reject()|Discard()|ReplyCode() as resp: if not first_connect: logging.warning("Unexpected response from filter after restart") continue return resp case _ as arg: # pragma: no-cover raise TypeError( f"task_status.started called with bad type: " f"{arg!r}", ) await sender.asend(Continue()) return channels raise TypeError(f"task_status.started called with bad type: {arg!r}") return final if len(self.channels) == 0 else Continue() async def set_macros(self, message: Macro) -> None: if self.channels: for channel in self.channels: await channel.send(message) else: for _, session in self.filters: await session.deliver(message) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage: skip = isinstance(message, Body) for channel in self.channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() self.channels.remove(channel) case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: return resp return ( Accept() if len(self.channels) == 0 else Skip() if skip else Continue() ) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:] async def abort(self, abort: Abort) -> None: for channel in self.channels: await channel.send(abort) await channel.receive() await channel.aclose() del self.channels[:] async def aclose(self) -> None: self.tasks.cancel_scope.cancel() del self.channels[:] @staticmethod async def _runner( fltr: Filter, session: Session, channel: MessageChannel, first_connect: bool, use_skip: bool, *, task_status: anyio.abc.TaskStatus, ) -> None: Loading @@ -223,6 +266,9 @@ async def _runner( session.broadcast.task_status = task_status try: final_resp = await fltr(session) except Aborted: logging.info(f"aborted filter {qualname(fltr)}") return except Exception: logging.exception(f"error in filter {qualname(fltr)}") final_resp = TemporaryFailure() Loading @@ -242,9 +288,28 @@ async def _runner( if isinstance(message, Macro): await session.deliver(message) continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) # TODO: Upgrade and remove ignores once python/mypy#14242 is in assert isinstance(message, _VALID_EVENT_MESSAGE) # type: ignore[misc,arg-type] resp = await session.deliver(message) # type: ignore[arg-type] if final_resp is not None: break # type: ignore if isinstance(message, Abort): await channel.send(Continue()) await channel.aclose() return await channel.send(Skip() if use_skip and resp == Skip else Continue()) await channel.send(final_resp) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:]