Commit 6ce627a0 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add basic macro support to Session and Runner

parent 4010b5f2
Loading
Loading
Loading
Loading
+13 −4
Original line number Diff line number Diff line
@@ -89,6 +89,7 @@ class Runner:
		proto = FilterProtocol()
		sender = _sender(client, proto)
		channels = list[MessageChannel]()
		macro: Macro|None = None

		await sender.asend(None)  # type: ignore # initialise

@@ -104,11 +105,13 @@ class Runner:
					match message:
						case Negotiate():
							await self._negotiate(message, sender)
						case Macro():
							# TODO: implement macro support
							...
						case Macro() as macro:
							# Note that this Macro will hang around as "macro"; this is for
							# Connect messages.
							for channel in channels:
								await channel.send(macro)
						case Connect():
							channels[:] = await self._connect(message, sender, tasks)
							channels[:] = await self._connect(message, sender, tasks, macro)
						case Abort():
							for channel in channels:
								await channel.aclose()
@@ -156,12 +159,15 @@ class Runner:
		message: Connect,
		sender: Sender,
		tasks: anyio.abc.TaskGroup,
		macro: Macro|None,
	) -> list[MessageChannel]:
		channels = list[MessageChannel]()
		for fltr in self.filters:
			lchannel, rchannel = _make_message_channel()
			channels.append(lchannel)
			session = Session(message, sender, _Broadcast())
			if macro:
				await session.deliver(macro)
			match await tasks.start(
				_runner, fltr, session, rchannel, self.use_skip,
			):
@@ -222,6 +228,9 @@ async def _runner(
			except (anyio.EndOfStream, anyio.ClosedResourceError):
				tasks.cancel_scope.cancel()
				return
			if isinstance(message, Macro):
				await session.deliver(message)
				continue
			assert isinstance(message, _VALID_EVENT_MESSAGE)
			resp = await session.deliver(message)
			if final_resp is not None:
+5 −0
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ from . import util
EventMessage: TypeAlias = Union[
	Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown,
	Header, EndOfHeaders, Body, EndOfMessage,
	Macro,
]
"""
Messages sent from an MTA to a filter
@@ -172,6 +173,7 @@ class Session:
		self._editor = sender
		self._broadcast = broadcast or util.Broadcast[EventMessage]()

		self.macros = dict[str, str]()
		self.headers = HeadersAccessor(self, sender)
		self.body = BodyAccessor(self, sender)

@@ -195,6 +197,9 @@ class Session:
		match message:
			case Body() if self.skip:
				return Skip
			case Macro():
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
			case Helo():
				self.phase = Phase.MAIL
			case EnvelopeFrom() | EnvelopeRecipient() | Unknown():
+1 −1
Original line number Diff line number Diff line
@@ -117,7 +117,7 @@ class MockMessageStream:
		msg.pack(buff)
		await self._stream.send(buff[:].tobytes())
		del buff[:]
		if isinstance(msg, (Abort, Close)):
		if isinstance(msg, (Macro, Abort, Close)):
			return
		while 1:
			try:
+29 −0
Original line number Diff line number Diff line
@@ -304,3 +304,32 @@ class RunnerTests(AsyncTestCase):
				await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure)

			assert "expected a final response" in str(wcm.warning)

	async def test_macros(self) -> None:
		"""
		Check that delivered macros are available
		"""
		@Runner
		async def test_filter(session: Session) -> Accept:
			self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"})
			await session.helo()
			self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"})
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)
			await stream_mock.send_and_expect(
				Macro(Connect.ident, {"{spam}": "yes", "{eggs}": "yes"}),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(
				Macro(Connect.ident, {"{spam}": "no", "{ham}": "maybe"}),
			)
			await stream_mock.send_and_expect(Helo("test.example.com"), Accept)
			await stream_mock.close()
+24 −0
Original line number Diff line number Diff line
@@ -365,3 +365,27 @@ class SessionTests(AsyncTestCase):
		sender._asend.assert_has_awaits([
			call(RemoveRecipient("test@example.com")),
		])

	async def test_load_macros(self) -> None:
		"""
		Check that `deliver(Macro())` updates the macros dict
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		async def test_filter(session: Session) -> Accept:
			self.assertDictEqual(session.macros, {})
			await session.helo()
			self.assertDictEqual(session.macros, {"{spam}": "yes", "{eggs}": "yes"})
			await session.envelope_from()
			self.assertDictEqual(session.macros, {"{spam}": "no", "{ham}": "maybe", "{eggs}": "yes"})
			return Accept()

		async with trio.open_nursery() as tg:
			tg.start_soon(test_filter, session)
			await trio.testing.wait_all_tasks_blocked()

			await session.deliver(Macro(Helo.ident, {"{spam}": "yes", "{eggs}": "yes"}))
			await session.deliver(Helo("test.example.com"))
			await session.deliver(Macro(Helo.ident, {"{spam}": "no", "{ham}": "maybe"}))
			await session.deliver(EnvelopeFrom(b"test@example.com"))