Loading kilter/service/runner.py +13 −4 Original line number Diff line number Diff line Loading @@ -89,6 +89,7 @@ class Runner: proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() macro: Macro|None = None await sender.asend(None) # type: ignore # initialise Loading @@ -104,11 +105,13 @@ class Runner: match message: case Negotiate(): await self._negotiate(message, sender) case Macro(): # TODO: implement macro support ... case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. for channel in channels: await channel.send(macro) case Connect(): channels[:] = await self._connect(message, sender, tasks) channels[:] = await self._connect(message, sender, tasks, macro) case Abort(): for channel in channels: await channel.aclose() Loading Loading @@ -156,12 +159,15 @@ class Runner: message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, macro: Macro|None, ) -> list[MessageChannel]: channels = list[MessageChannel]() for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) if macro: await session.deliver(macro) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): Loading Loading @@ -222,6 +228,9 @@ async def _runner( except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return if isinstance(message, Macro): await session.deliver(message) continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) if final_resp is not None: Loading kilter/service/session.py +5 −0 Original line number Diff line number Diff line Loading @@ -29,6 +29,7 @@ from . import util EventMessage: TypeAlias = Union[ Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, Macro, ] """ Messages sent from an MTA to a filter Loading Loading @@ -172,6 +173,7 @@ class Session: self._editor = sender self._broadcast = broadcast or util.Broadcast[EventMessage]() self.macros = dict[str, str]() self.headers = HeadersAccessor(self, sender) self.body = BodyAccessor(self, sender) Loading @@ -195,6 +197,9 @@ class Session: match message: case Body() if self.skip: return Skip case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Helo(): self.phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): Loading tests/mock_stream.py +1 −1 Original line number Diff line number Diff line Loading @@ -117,7 +117,7 @@ class MockMessageStream: msg.pack(buff) await self._stream.send(buff[:].tobytes()) del buff[:] if isinstance(msg, (Abort, Close)): if isinstance(msg, (Macro, Abort, Close)): return while 1: try: Loading tests/test_runner.py +29 −0 Original line number Diff line number Diff line Loading @@ -304,3 +304,32 @@ class RunnerTests(AsyncTestCase): await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure) assert "expected a final response" in str(wcm.warning) async def test_macros(self) -> None: """ Check that delivered macros are available """ @Runner async def test_filter(session: Session) -> Accept: self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"}) await session.helo() self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"}) return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect( Macro(Connect.ident, {"{spam}": "yes", "{eggs}": "yes"}), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect( Macro(Connect.ident, {"{spam}": "no", "{ham}": "maybe"}), ) await stream_mock.send_and_expect(Helo("test.example.com"), Accept) await stream_mock.close() tests/test_session.py +24 −0 Original line number Diff line number Diff line Loading @@ -365,3 +365,27 @@ class SessionTests(AsyncTestCase): sender._asend.assert_has_awaits([ call(RemoveRecipient("test@example.com")), ]) async def test_load_macros(self) -> None: """ Check that `deliver(Macro())` updates the macros dict """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) async def test_filter(session: Session) -> Accept: self.assertDictEqual(session.macros, {}) await session.helo() self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"}) await session.envelope_from() self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"}) return Accept() async with trio.open_nursery() as tg: tg.start_soon(test_filter, session) await trio.testing.wait_all_tasks_blocked() await session.deliver(Macro(Helo.ident, {"{spam}": "yes", "{eggs}": "yes"})) await session.deliver(Helo("test.example.com")) await session.deliver(Macro(Helo.ident, {"{spam}": "no", "{ham}": "maybe"})) await session.deliver(EnvelopeFrom(b"test@example.com")) Loading
kilter/service/runner.py +13 −4 Original line number Diff line number Diff line Loading @@ -89,6 +89,7 @@ class Runner: proto = FilterProtocol() sender = _sender(client, proto) channels = list[MessageChannel]() macro: Macro|None = None await sender.asend(None) # type: ignore # initialise Loading @@ -104,11 +105,13 @@ class Runner: match message: case Negotiate(): await self._negotiate(message, sender) case Macro(): # TODO: implement macro support ... case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. for channel in channels: await channel.send(macro) case Connect(): channels[:] = await self._connect(message, sender, tasks) channels[:] = await self._connect(message, sender, tasks, macro) case Abort(): for channel in channels: await channel.aclose() Loading Loading @@ -156,12 +159,15 @@ class Runner: message: Connect, sender: Sender, tasks: anyio.abc.TaskGroup, macro: Macro|None, ) -> list[MessageChannel]: channels = list[MessageChannel]() for fltr in self.filters: lchannel, rchannel = _make_message_channel() channels.append(lchannel) session = Session(message, sender, _Broadcast()) if macro: await session.deliver(macro) match await tasks.start( _runner, fltr, session, rchannel, self.use_skip, ): Loading Loading @@ -222,6 +228,9 @@ async def _runner( except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return if isinstance(message, Macro): await session.deliver(message) continue assert isinstance(message, _VALID_EVENT_MESSAGE) resp = await session.deliver(message) if final_resp is not None: Loading
kilter/service/session.py +5 −0 Original line number Diff line number Diff line Loading @@ -29,6 +29,7 @@ from . import util EventMessage: TypeAlias = Union[ Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, Macro, ] """ Messages sent from an MTA to a filter Loading Loading @@ -172,6 +173,7 @@ class Session: self._editor = sender self._broadcast = broadcast or util.Broadcast[EventMessage]() self.macros = dict[str, str]() self.headers = HeadersAccessor(self, sender) self.body = BodyAccessor(self, sender) Loading @@ -195,6 +197,9 @@ class Session: match message: case Body() if self.skip: return Skip case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Helo(): self.phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): Loading
tests/mock_stream.py +1 −1 Original line number Diff line number Diff line Loading @@ -117,7 +117,7 @@ class MockMessageStream: msg.pack(buff) await self._stream.send(buff[:].tobytes()) del buff[:] if isinstance(msg, (Abort, Close)): if isinstance(msg, (Macro, Abort, Close)): return while 1: try: Loading
tests/test_runner.py +29 −0 Original line number Diff line number Diff line Loading @@ -304,3 +304,32 @@ class RunnerTests(AsyncTestCase): await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure) assert "expected a final response" in str(wcm.warning) async def test_macros(self) -> None: """ Check that delivered macros are available """ @Runner async def test_filter(session: Session) -> Accept: self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"}) await session.helo() self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"}) return Accept() async with trio.open_nursery() as tg, MockMessageStream() as stream_mock: tg.start_soon(test_filter, stream_mock.peer_stream) await trio.testing.wait_all_tasks_blocked() await stream_mock.send_and_expect( Negotiate(6, 0x1ff, 0), Negotiate(6, 0x1ff, 0), ) await stream_mock.send_and_expect( Macro(Connect.ident, {"{spam}": "yes", "{eggs}": "yes"}), ) await stream_mock.send_and_expect(Connect("test.example.com"), Continue) await stream_mock.send_and_expect( Macro(Connect.ident, {"{spam}": "no", "{ham}": "maybe"}), ) await stream_mock.send_and_expect(Helo("test.example.com"), Accept) await stream_mock.close()
tests/test_session.py +24 −0 Original line number Diff line number Diff line Loading @@ -365,3 +365,27 @@ class SessionTests(AsyncTestCase): sender._asend.assert_has_awaits([ call(RemoveRecipient("test@example.com")), ]) async def test_load_macros(self) -> None: """ Check that `deliver(Macro())` updates the macros dict """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) async def test_filter(session: Session) -> Accept: self.assertDictEqual(session.macros, {}) await session.helo() self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"}) await session.envelope_from() self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"}) return Accept() async with trio.open_nursery() as tg: tg.start_soon(test_filter, session) await trio.testing.wait_all_tasks_blocked() await session.deliver(Macro(Helo.ident, {"{spam}": "yes", "{eggs}": "yes"})) await session.deliver(Helo("test.example.com")) await session.deliver(Macro(Helo.ident, {"{spam}": "no", "{ham}": "maybe"})) await session.deliver(EnvelopeFrom(b"test@example.com"))