Loading kilter/service/session.py +12 −1 Original line number Diff line number Diff line Loading @@ -32,7 +32,7 @@ from . import util EventMessage: TypeAlias = Union[ Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, Macro, Macro, Abort, ] """ Messages sent from an MTA to a filter Loading @@ -55,6 +55,12 @@ Messages send from a filter to an MTA after an `EndOfMessage` to modify a messag """ class Aborted(BaseException): """ An exception for aborting filters on receipt of an Abort message """ class Filter(Protocol): """ Filters are callables that accept a `Session` and return a response Loading Loading @@ -249,6 +255,11 @@ class Session: case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Abort(): async with self._broadcast: self.phase = Phase.CONNECT await self._broadcast.abort(Aborted) return Continue case Helo(): phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): Loading tests/test_session.py +26 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from unittest.mock import call import trio.testing from kilter.protocol import * from kilter.service.session import Aborted from kilter.service.session import Phase from kilter.service.session import Session Loading Loading @@ -410,3 +411,28 @@ class SessionTests(AsyncTestCase): await session.deliver(Helo("test.example.com")) await session.deliver(EnvelopeFrom(b"test@example.com")) async def test_abort_in_helo(self) -> None: """ Check that receipt of Abort while awaiting Helo raises Aborted and resets """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) @with_session(session) async def test_filter() -> None: assert session.phase == Phase.CONNECT assert await session.helo() == "test.example.org" assert session.phase == Phase.MAIL with self.assertRaises(Aborted): await session.extension("MAIL") assert session.phase == Phase.CONNECT assert await session.helo() == "test.example.com" async with trio.open_nursery() as tg: tg.start_soon(test_filter) await trio.testing.wait_all_tasks_blocked() await session.deliver(Helo("test.example.org")) await session.deliver(Abort()) await session.deliver(Helo("test.example.com")) Loading
kilter/service/session.py +12 −1 Original line number Diff line number Diff line Loading @@ -32,7 +32,7 @@ from . import util EventMessage: TypeAlias = Union[ Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header, EndOfHeaders, Body, EndOfMessage, Macro, Macro, Abort, ] """ Messages sent from an MTA to a filter Loading @@ -55,6 +55,12 @@ Messages send from a filter to an MTA after an `EndOfMessage` to modify a messag """ class Aborted(BaseException): """ An exception for aborting filters on receipt of an Abort message """ class Filter(Protocol): """ Filters are callables that accept a `Session` and return a response Loading Loading @@ -249,6 +255,11 @@ class Session: case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Abort(): async with self._broadcast: self.phase = Phase.CONNECT await self._broadcast.abort(Aborted) return Continue case Helo(): phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): Loading
tests/test_session.py +26 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from unittest.mock import call import trio.testing from kilter.protocol import * from kilter.service.session import Aborted from kilter.service.session import Phase from kilter.service.session import Session Loading Loading @@ -410,3 +411,28 @@ class SessionTests(AsyncTestCase): await session.deliver(Helo("test.example.com")) await session.deliver(EnvelopeFrom(b"test@example.com")) async def test_abort_in_helo(self) -> None: """ Check that receipt of Abort while awaiting Helo raises Aborted and resets """ sender = MockEditor() session = Session(Connect("example.com", LOCALHOST, 1025), sender) @with_session(session) async def test_filter() -> None: assert session.phase == Phase.CONNECT assert await session.helo() == "test.example.org" assert session.phase == Phase.MAIL with self.assertRaises(Aborted): await session.extension("MAIL") assert session.phase == Phase.CONNECT assert await session.helo() == "test.example.com" async with trio.open_nursery() as tg: tg.start_soon(test_filter) await trio.testing.wait_all_tasks_blocked() await session.deliver(Helo("test.example.org")) await session.deliver(Abort()) await session.deliver(Helo("test.example.com"))