Loading kilter/service/__init__.py +2 −0 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ project). The framework aims to provide Pythonic interfaces for implementing fi including leveraging coroutines instead of libmilter's callback-style interface. """ from .runner import Runner from .session import END from .session import START from .session import After Loading @@ -22,6 +23,7 @@ __all__ = [ "Before", "END", "ResponseMessage", "Runner", "START", "Session", ] kilter/service/runner.py 0 → 100644 +242 −0 Original line number Diff line number Diff line # Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk> # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. """ Coordinate receiving and sending raw messages with a filter and Session object The primary class in this module (`Runner`) is intended to be used with an `anyio.abc.Listener`, which can be obtained, for instance, from `anyio.create_tcp_listener()`. """ from __future__ import annotations from collections.abc import AsyncGenerator from warnings import warn import anyio.abc from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from kilter.protocol.buffer import SimpleBuffer from kilter.protocol.core import FilterProtocol from kilter.protocol.messages import ProtocolFlags from .session import * from .util import Broadcast MessageChannel = anyio.abc.ObjectStream[Message] Sender = AsyncGenerator[None, Message] kiB = 2**10 MiB = 2**20 class NegotiationError(Exception): """ An error raised when MTAs are not compatible with the filter """ class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: super().__init__() self._ready = anyio.Condition() async def aclose(self) -> None: async with self._ready: self._ready.notify_all() async def pre_receive_hook(self) -> None: async with self._ready: self._ready.notify_all() async def post_send_hook(self) -> None: # Await notification of either a receiver waiting or the broadcaster closing # This is necessary to delay returning until a filter has had a chance to return # a result. async with self._ready: await self._ready.wait() class Runner: """ A filter runner that coordinates passing data between a stream and multiple filters Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or used with any `anyio.abc.ByteStream`. """ def __init__(self, *filters: Filter): if len(filters) == 0: # pragma: no-cover raise TypeError("Runner requires at least one filter to run") self.filters = filters self.use_skip = True async def __call__(self, client: anyio.abc.ByteStream) -> None: """ Return an awaitable that starts and coordinates filters """ buff = SimpleBuffer(1*MiB) proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() await sender.asend(None) # type: ignore # initialise async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client): while 1: try: buff[:] = await client.receive(buff.available) except (anyio.EndOfStream, anyio.ClosedResourceError): for channel in channels: await channel.aclose() return for message in proto.read_from(buff): match message: case Negotiate(): await self._negotiate(message, sender) case Macro(): # TODO: implement macro support ... case Connect(): channels[:] = await self._connect(message, sender, tasks) case Abort(): for channel in channels: await channel.aclose() case Close(): return case _: assert isinstance( message, ( Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, ), ) skip = isinstance(message, Body) for channel in channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() channels.remove(channel) case resp: await sender.asend(resp) break else: await sender.asend( Accept() if len(channels) == 0 else Skip() if skip else Continue(), ) async def _negotiate(self, message: Negotiate, sender: Sender) -> None: # TODO: actually negotiate what the filter wants, not just "everything" actions = set(ActionFlags) # All actions! if actions != ActionFlags.unpack(message.action_flags): raise NegotiationError("MTA does not accept all actions required by the filter") resp = Negotiate(6, 0, 0) resp.protocol_flags = message.protocol_flags resp.action_flags = ActionFlags.pack(actions) await sender.asend(resp) self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP) async def _connect( self, message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, ) -> list[MessageChannel]: channels = list[MessageChannel]() for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): case Continue(): continue case Message() as resp: await sender.asend(resp) return [] case _ as arg: # pragma: no-cover raise TypeError( f"task_status.started called with bad type: " f"{arg!r}", ) await sender.asend(Continue()) return channels def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:] _VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode async def _runner( fltr: Filter, session: Session, channel: MessageChannel, use_skip: bool, *, task_status: anyio.abc.TaskStatus, ) -> None: final_resp: ResponseMessage|None = None async def _filter_wrap( task_status: anyio.abc.TaskStatus, ) -> None: nonlocal final_resp async with session: task_status.started() final_resp = await fltr(session) if not isinstance(final_resp, _VALID_FINAL_RESPONSES): warn(f"expected a final response from {fltr}, got {final_resp}") final_resp = TemporaryFailure() async with anyio.create_task_group() as tasks: await tasks.start(_filter_wrap) task_status.started(final_resp or Continue()) while final_resp is None: try: message = await channel.receive() except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return assert isinstance( message, ( Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, ), ) resp = await session.deliver(message) if final_resp is not None: await channel.send(final_resp) # type: ignore elif use_skip and resp == Skip: await channel.send(Skip()) else: await channel.send(Continue()) tests/mock_stream.py 0 → 100644 +149 −0 Original line number Diff line number Diff line from __future__ import annotations import typing from collections.abc import AsyncGenerator from collections.abc import AsyncIterator from collections.abc import Callable from contextlib import asynccontextmanager from functools import wraps from types import TracebackType from typing import TYPE_CHECKING from typing import AsyncContextManager from typing import TypeVar import anyio from anyio.streams.buffered import BufferedByteReceiveStream from anyio.streams.stapled import StapledByteStream from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from kilter.protocol import * from kilter.protocol.buffer import SimpleBuffer from kilter.service import ResponseMessage P = typing.ParamSpec("P") SendT = typing.TypeVar("SendT") YieldT = typing.TypeVar("YieldT") def _make_aclosing( func: Callable[P, AsyncGenerator[YieldT, SendT]], ) -> Callable[P, AsyncContextManager[AsyncGenerator[YieldT, SendT]]]: @wraps(func) @asynccontextmanager async def wrap(*a: P.args, **k: P.kwargs) -> AsyncIterator[AsyncGenerator[YieldT, SendT]]: agen = func(*a, **k) async with aclosing(agen): yield agen return wrap class MockMessageStream: """ A mock of the right-side of an `anyio.abc.ByteStream` with test support on the left side """ if TYPE_CHECKING: Self = TypeVar("Self", bound="MockMessageStream") def __init__(self) -> None: self.buffer = SimpleBuffer(1024) self.closed = False async def __aenter__(self: Self) -> Self: send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes) send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes) self._stream = StapledObjectStream(send_obj, recv_obj) self.peer_stream = StapledByteStream( send_bytes, # type: ignore BufferedByteReceiveStream(recv_bytes), ) await self._stream.__aenter__() await self.peer_stream.__aenter__() return self async def __aexit__( self, et: type[BaseException]|None = None, ex: BaseException|None = None, tb: TracebackType|None = None, ) -> None: if not self.closed: if et is not None: await self.abort() else: await self.close() await self._stream.__aexit__(et, ex, tb) await self.peer_stream.__aexit__(et, ex, tb) async def abort(self) -> None: """ Send Abort and close the stream """ try: resp = await self.send_msg(Abort()) except anyio.BrokenResourceError: return assert len(resp) == 0, resp await self.close() async def close(self) -> None: """ Send Close and close the stream """ if self.closed: return resp = await self.send_msg(Close()) assert len(resp) == 0, resp await self._stream.aclose() self.closed = True async def send_msg(self, msg: Message) -> list[Message]: """ Send a message and return the messages sent in response """ responses = [] async with self._send_msg(msg) as aiter: async for resp in aiter: responses.append(resp) return responses @_make_aclosing async def _send_msg(self, msg: Message) -> AsyncGenerator[Message, None]: buff = self.buffer msg.pack(buff) await self._stream.send(buff[:].tobytes()) del buff[:] if isinstance(msg, (Abort, Close)): return while 1: try: buff[:] = chunk = await self._stream.receive() except anyio.EndOfStream: break if len(chunk) == 0: break try: msg, size = Message.unpack(buff) except NeedsMore: continue del buff[:size] yield msg if isinstance(msg, typing.get_args(ResponseMessage) + (Negotiate, Skip)): break assert buff.filled == 0, buff[:].tobytes() async def send_and_expect(self, msg: Message, *exp: type[Message]|Message) -> None: """ Send a message and check the responses by type or equality """ resp = await self.send_msg(msg) assert len(resp) == len(exp), resp for r, e in zip(resp, exp): if isinstance(e, type): assert isinstance(r, e), f"expected {e}, got {type(r)}" else: assert r == e, r tests/test_runner.py 0 → 100644 +306 −0 Original line number Diff line number Diff line import trio.testing from kilter.protocol import * from kilter.service import Runner from kilter.service import Session from . import AsyncTestCase from .mock_stream import MockMessageStream class RunnerTests(AsyncTestCase): """ Tests for the Runner class """ async def test_helo(self) -> None: """ Check that awaiting Session.helo() responds to Connect with Continue """ hostname = "" @Runner async def test_filter(session: Session) -> Accept: nonlocal hostname hostname = await session.helo() return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) assert hostname == "" await stream_mock.send_and_expect(Helo("test.example.com"), Accept) assert hostname == "test.example.com" async def test_respond_to_peer(self) -> None: """ Check that returning before engaging with async session features works """ @Runner async def test_filter(session: Session) -> Reject: return Reject() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Reject) async def test_post_header(self) -> None: """ Check that delaying return until a phase later than CONNECT sends Continue """ @Runner async def test_filter(session: Session) -> Accept: assert "test@example.com" == await session.envelope_from() return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Helo("test.example.com"), Continue) await stream_mock.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept) async def test_body_all(self) -> None: """ Check that the whole body is processes when Continue is passed """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: await trio.sleep(0) contents += chunk return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Accept) assert contents == b"This is a message sent in multiple chunks. Bye", contents async def test_body_skip(self) -> None: """ Check that Skip is returned once a body loop is broken """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: contents += chunk if b"message" in chunk.tobytes(): break # Move phase onto POST await session.change_sender("test@example.com") return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, ProtocolFlags.SKIP), Negotiate(6, 0x1ff, ProtocolFlags.SKIP), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Skip) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Skip) await stream_mock.send_and_expect( EndOfMessage(b"Bye"), ChangeSender("test@example.com"), Accept, ) assert contents == b"This is a message sent ", contents async def test_body_fake_skip(self) -> None: """ Check that Skip is NOT returned if not accepted by an MTA """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: contents += chunk if b"message" in chunk.tobytes(): break # Move phase onto POST await session.change_sender("test@example.com") return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect( EndOfMessage(b"Bye"), ChangeSender("test@example.com"), Accept, ) assert contents == b"This is a message sent ", contents async def test_multiple(self) -> None: """ Check that multiple filters receive the messages they expect """ hostname = "" contents1 = b"" contents2 = b"" async def test_filter1(session: Session) -> Reject: nonlocal hostname nonlocal contents1 hostname = await session.helo() async with session.body as body: async for chunk in body: await trio.sleep(0) contents1 += chunk return Reject() async def test_filter2(session: Session) -> Accept: nonlocal contents2 async with session.body as body: async for chunk in body: await trio.sleep(0) contents2 += chunk if b"message" in chunk.tobytes(): break return Accept() runner = Runner(test_filter1, test_filter2) async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(runner, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, ProtocolFlags.SKIP), Negotiate(6, 0x1ff, ProtocolFlags.SKIP), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Helo("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Reject) assert hostname == "test.example.com", hostname assert contents1 == b"This is a message sent in multiple chunks. Bye", contents1 assert contents2 == b"This is a message sent ", contents2 async def test_abort(self) -> None: """ Check that a runner closes cleanly when it receives an Abort """ cancelled = False @Runner async def test_filter(session: Session) -> Accept: nonlocal cancelled try: await session.helo() except trio.Cancelled: cancelled = True raise return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.abort() assert cancelled async def test_bad_response(self) -> None: """ Check that a runner closes cleanly when it receives an Abort """ @Runner async def test_filter(session: Session) -> Skip: await session.helo() return Skip() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) with self.assertWarns(UserWarning) as wcm: await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure) assert "expected a final response" in str(wcm.warning) Loading
kilter/service/__init__.py +2 −0 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ project). The framework aims to provide Pythonic interfaces for implementing fi including leveraging coroutines instead of libmilter's callback-style interface. """ from .runner import Runner from .session import END from .session import START from .session import After Loading @@ -22,6 +23,7 @@ __all__ = [ "Before", "END", "ResponseMessage", "Runner", "START", "Session", ]
kilter/service/runner.py 0 → 100644 +242 −0 Original line number Diff line number Diff line # Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk> # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. """ Coordinate receiving and sending raw messages with a filter and Session object The primary class in this module (`Runner`) is intended to be used with an `anyio.abc.Listener`, which can be obtained, for instance, from `anyio.create_tcp_listener()`. """ from __future__ import annotations from collections.abc import AsyncGenerator from warnings import warn import anyio.abc from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from kilter.protocol.buffer import SimpleBuffer from kilter.protocol.core import FilterProtocol from kilter.protocol.messages import ProtocolFlags from .session import * from .util import Broadcast MessageChannel = anyio.abc.ObjectStream[Message] Sender = AsyncGenerator[None, Message] kiB = 2**10 MiB = 2**20 class NegotiationError(Exception): """ An error raised when MTAs are not compatible with the filter """ class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: super().__init__() self._ready = anyio.Condition() async def aclose(self) -> None: async with self._ready: self._ready.notify_all() async def pre_receive_hook(self) -> None: async with self._ready: self._ready.notify_all() async def post_send_hook(self) -> None: # Await notification of either a receiver waiting or the broadcaster closing # This is necessary to delay returning until a filter has had a chance to return # a result. async with self._ready: await self._ready.wait() class Runner: """ A filter runner that coordinates passing data between a stream and multiple filters Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or used with any `anyio.abc.ByteStream`. """ def __init__(self, *filters: Filter): if len(filters) == 0: # pragma: no-cover raise TypeError("Runner requires at least one filter to run") self.filters = filters self.use_skip = True async def __call__(self, client: anyio.abc.ByteStream) -> None: """ Return an awaitable that starts and coordinates filters """ buff = SimpleBuffer(1*MiB) proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() await sender.asend(None) # type: ignore # initialise async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client): while 1: try: buff[:] = await client.receive(buff.available) except (anyio.EndOfStream, anyio.ClosedResourceError): for channel in channels: await channel.aclose() return for message in proto.read_from(buff): match message: case Negotiate(): await self._negotiate(message, sender) case Macro(): # TODO: implement macro support ... case Connect(): channels[:] = await self._connect(message, sender, tasks) case Abort(): for channel in channels: await channel.aclose() case Close(): return case _: assert isinstance( message, ( Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, ), ) skip = isinstance(message, Body) for channel in channels: await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept(): await channel.aclose() channels.remove(channel) case resp: await sender.asend(resp) break else: await sender.asend( Accept() if len(channels) == 0 else Skip() if skip else Continue(), ) async def _negotiate(self, message: Negotiate, sender: Sender) -> None: # TODO: actually negotiate what the filter wants, not just "everything" actions = set(ActionFlags) # All actions! if actions != ActionFlags.unpack(message.action_flags): raise NegotiationError("MTA does not accept all actions required by the filter") resp = Negotiate(6, 0, 0) resp.protocol_flags = message.protocol_flags resp.action_flags = ActionFlags.pack(actions) await sender.asend(resp) self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP) async def _connect( self, message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, ) -> list[MessageChannel]: channels = list[MessageChannel]() for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): case Continue(): continue case Message() as resp: await sender.asend(resp) return [] case _ as arg: # pragma: no-cover raise TypeError( f"task_status.started called with bad type: " f"{arg!r}", ) await sender.asend(Continue()) return channels def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (yield)) await client.send(buff[:]) del buff[:] _VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode async def _runner( fltr: Filter, session: Session, channel: MessageChannel, use_skip: bool, *, task_status: anyio.abc.TaskStatus, ) -> None: final_resp: ResponseMessage|None = None async def _filter_wrap( task_status: anyio.abc.TaskStatus, ) -> None: nonlocal final_resp async with session: task_status.started() final_resp = await fltr(session) if not isinstance(final_resp, _VALID_FINAL_RESPONSES): warn(f"expected a final response from {fltr}, got {final_resp}") final_resp = TemporaryFailure() async with anyio.create_task_group() as tasks: await tasks.start(_filter_wrap) task_status.started(final_resp or Continue()) while final_resp is None: try: message = await channel.receive() except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return assert isinstance( message, ( Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, ), ) resp = await session.deliver(message) if final_resp is not None: await channel.send(final_resp) # type: ignore elif use_skip and resp == Skip: await channel.send(Skip()) else: await channel.send(Continue())
tests/mock_stream.py 0 → 100644 +149 −0 Original line number Diff line number Diff line from __future__ import annotations import typing from collections.abc import AsyncGenerator from collections.abc import AsyncIterator from collections.abc import Callable from contextlib import asynccontextmanager from functools import wraps from types import TracebackType from typing import TYPE_CHECKING from typing import AsyncContextManager from typing import TypeVar import anyio from anyio.streams.buffered import BufferedByteReceiveStream from anyio.streams.stapled import StapledByteStream from anyio.streams.stapled import StapledObjectStream from async_generator import aclosing from kilter.protocol import * from kilter.protocol.buffer import SimpleBuffer from kilter.service import ResponseMessage P = typing.ParamSpec("P") SendT = typing.TypeVar("SendT") YieldT = typing.TypeVar("YieldT") def _make_aclosing( func: Callable[P, AsyncGenerator[YieldT, SendT]], ) -> Callable[P, AsyncContextManager[AsyncGenerator[YieldT, SendT]]]: @wraps(func) @asynccontextmanager async def wrap(*a: P.args, **k: P.kwargs) -> AsyncIterator[AsyncGenerator[YieldT, SendT]]: agen = func(*a, **k) async with aclosing(agen): yield agen return wrap class MockMessageStream: """ A mock of the right-side of an `anyio.abc.ByteStream` with test support on the left side """ if TYPE_CHECKING: Self = TypeVar("Self", bound="MockMessageStream") def __init__(self) -> None: self.buffer = SimpleBuffer(1024) self.closed = False async def __aenter__(self: Self) -> Self: send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes) send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes) self._stream = StapledObjectStream(send_obj, recv_obj) self.peer_stream = StapledByteStream( send_bytes, # type: ignore BufferedByteReceiveStream(recv_bytes), ) await self._stream.__aenter__() await self.peer_stream.__aenter__() return self async def __aexit__( self, et: type[BaseException]|None = None, ex: BaseException|None = None, tb: TracebackType|None = None, ) -> None: if not self.closed: if et is not None: await self.abort() else: await self.close() await self._stream.__aexit__(et, ex, tb) await self.peer_stream.__aexit__(et, ex, tb) async def abort(self) -> None: """ Send Abort and close the stream """ try: resp = await self.send_msg(Abort()) except anyio.BrokenResourceError: return assert len(resp) == 0, resp await self.close() async def close(self) -> None: """ Send Close and close the stream """ if self.closed: return resp = await self.send_msg(Close()) assert len(resp) == 0, resp await self._stream.aclose() self.closed = True async def send_msg(self, msg: Message) -> list[Message]: """ Send a message and return the messages sent in response """ responses = [] async with self._send_msg(msg) as aiter: async for resp in aiter: responses.append(resp) return responses @_make_aclosing async def _send_msg(self, msg: Message) -> AsyncGenerator[Message, None]: buff = self.buffer msg.pack(buff) await self._stream.send(buff[:].tobytes()) del buff[:] if isinstance(msg, (Abort, Close)): return while 1: try: buff[:] = chunk = await self._stream.receive() except anyio.EndOfStream: break if len(chunk) == 0: break try: msg, size = Message.unpack(buff) except NeedsMore: continue del buff[:size] yield msg if isinstance(msg, typing.get_args(ResponseMessage) + (Negotiate, Skip)): break assert buff.filled == 0, buff[:].tobytes() async def send_and_expect(self, msg: Message, *exp: type[Message]|Message) -> None: """ Send a message and check the responses by type or equality """ resp = await self.send_msg(msg) assert len(resp) == len(exp), resp for r, e in zip(resp, exp): if isinstance(e, type): assert isinstance(r, e), f"expected {e}, got {type(r)}" else: assert r == e, r
tests/test_runner.py 0 → 100644 +306 −0 Original line number Diff line number Diff line import trio.testing from kilter.protocol import * from kilter.service import Runner from kilter.service import Session from . import AsyncTestCase from .mock_stream import MockMessageStream class RunnerTests(AsyncTestCase): """ Tests for the Runner class """ async def test_helo(self) -> None: """ Check that awaiting Session.helo() responds to Connect with Continue """ hostname = "" @Runner async def test_filter(session: Session) -> Accept: nonlocal hostname hostname = await session.helo() return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) assert hostname == "" await stream_mock.send_and_expect(Helo("test.example.com"), Accept) assert hostname == "test.example.com" async def test_respond_to_peer(self) -> None: """ Check that returning before engaging with async session features works """ @Runner async def test_filter(session: Session) -> Reject: return Reject() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Reject) async def test_post_header(self) -> None: """ Check that delaying return until a phase later than CONNECT sends Continue """ @Runner async def test_filter(session: Session) -> Accept: assert "test@example.com" == await session.envelope_from() return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Helo("test.example.com"), Continue) await stream_mock.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept) async def test_body_all(self) -> None: """ Check that the whole body is processes when Continue is passed """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: await trio.sleep(0) contents += chunk return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Accept) assert contents == b"This is a message sent in multiple chunks. Bye", contents async def test_body_skip(self) -> None: """ Check that Skip is returned once a body loop is broken """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: contents += chunk if b"message" in chunk.tobytes(): break # Move phase onto POST await session.change_sender("test@example.com") return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, ProtocolFlags.SKIP), Negotiate(6, 0x1ff, ProtocolFlags.SKIP), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Skip) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Skip) await stream_mock.send_and_expect( EndOfMessage(b"Bye"), ChangeSender("test@example.com"), Accept, ) assert contents == b"This is a message sent ", contents async def test_body_fake_skip(self) -> None: """ Check that Skip is NOT returned if not accepted by an MTA """ contents = b"" @Runner async def test_filter(session: Session) -> Accept: nonlocal contents async with session.body as body: async for chunk in body: contents += chunk if b"message" in chunk.tobytes(): break # Move phase onto POST await session.change_sender("test@example.com") return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect( EndOfMessage(b"Bye"), ChangeSender("test@example.com"), Accept, ) assert contents == b"This is a message sent ", contents async def test_multiple(self) -> None: """ Check that multiple filters receive the messages they expect """ hostname = "" contents1 = b"" contents2 = b"" async def test_filter1(session: Session) -> Reject: nonlocal hostname nonlocal contents1 hostname = await session.helo() async with session.body as body: async for chunk in body: await trio.sleep(0) contents1 += chunk return Reject() async def test_filter2(session: Session) -> Accept: nonlocal contents2 async with session.body as body: async for chunk in body: await trio.sleep(0) contents2 += chunk if b"message" in chunk.tobytes(): break return Accept() runner = Runner(test_filter1, test_filter2) async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(runner, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, ProtocolFlags.SKIP), Negotiate(6, 0x1ff, ProtocolFlags.SKIP), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect(Helo("test.example.com"), Continue) await stream_mock.send_and_expect(Body(b"This is a "), Continue) await stream_mock.send_and_expect(Body(b"message sent "), Continue) await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue) await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Reject) assert hostname == "test.example.com", hostname assert contents1 == b"This is a message sent in multiple chunks. Bye", contents1 assert contents2 == b"This is a message sent ", contents2 async def test_abort(self) -> None: """ Check that a runner closes cleanly when it receives an Abort """ cancelled = False @Runner async def test_filter(session: Session) -> Accept: nonlocal cancelled try: await session.helo() except trio.Cancelled: cancelled = True raise return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.abort() assert cancelled async def test_bad_response(self) -> None: """ Check that a runner closes cleanly when it receives an Abort """ @Runner async def test_filter(session: Session) -> Skip: await session.helo() return Skip() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) with self.assertWarns(UserWarning) as wcm: await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure) assert "expected a final response" in str(wcm.warning)