Commit 9fdaa95a authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Allow Session to handle Abort messages

Closes #8
parent 4f666624
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ from . import util
EventMessage: TypeAlias = Union[
	Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown,
	Header, EndOfHeaders, Body, EndOfMessage,
	Macro,
	Macro, Abort,
]
"""
Messages sent from an MTA to a filter
@@ -55,6 +55,12 @@ Messages send from a filter to an MTA after an `EndOfMessage` to modify a messag
"""


class Aborted(BaseException):
	"""
	An exception for aborting filters on receipt of an Abort message
	"""


class Filter(Protocol):
	"""
	Filters are callables that accept a `Session` and return a response
@@ -249,6 +255,11 @@ class Session:
			case Macro():
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
			case Abort():
				async with self._broadcast:
					self.phase = Phase.CONNECT
				await self._broadcast.abort(Aborted)
				return Continue
			case Helo():
				phase = Phase.MAIL
			case EnvelopeFrom() | EnvelopeRecipient() | Unknown():
+26 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from unittest.mock import call
import trio.testing

from kilter.protocol import *
from kilter.service.session import Aborted
from kilter.service.session import Phase
from kilter.service.session import Session

@@ -410,3 +411,28 @@ class SessionTests(AsyncTestCase):

			await session.deliver(Helo("test.example.com"))
			await session.deliver(EnvelopeFrom(b"test@example.com"))

	async def test_abort_in_helo(self) -> None:
		"""
		Check that receipt of Abort while awaiting Helo raises Aborted and resets
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			assert await session.helo() == "test.example.org"
			assert session.phase == Phase.MAIL
			with self.assertRaises(Aborted):
				await session.extension("MAIL")
			assert session.phase == Phase.CONNECT
			assert await session.helo() == "test.example.com"

		async with trio.open_nursery() as tg:
			tg.start_soon(test_filter)
			await trio.testing.wait_all_tasks_blocked()

			await session.deliver(Helo("test.example.org"))
			await session.deliver(Abort())
			await session.deliver(Helo("test.example.com"))