Commit c8c549ea authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add runner module

parent d0be55a6
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ project). The framework aims to provide Pythonic interfaces for implementing fi
including leveraging coroutines instead of libmilter's callback-style interface.
"""

from .runner import Runner
from .session import END
from .session import START
from .session import After
@@ -22,6 +23,7 @@ __all__ = [
	"Before",
	"END",
	"ResponseMessage",
	"Runner",
	"START",
	"Session",
]
+242 −0
Original line number Diff line number Diff line
# Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""
Coordinate receiving and sending raw messages with a filter and Session object

The primary class in this module (`Runner`) is intended to be used with an
`anyio.abc.Listener`, which can be obtained, for instance, from
`anyio.create_tcp_listener()`.
"""

from __future__ import annotations

from collections.abc import AsyncGenerator
from warnings import warn

import anyio.abc
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import FilterProtocol
from kilter.protocol.messages import ProtocolFlags

from .session import *
from .util import Broadcast

MessageChannel = anyio.abc.ObjectStream[Message]
Sender = AsyncGenerator[None, Message]

kiB = 2**10
MiB = 2**20


class NegotiationError(Exception):
	"""
	An error raised when MTAs are not compatible with the filter
	"""


class _Broadcast(Broadcast[EventMessage]):

	def __init__(self) -> None:
		super().__init__()
		self._ready = anyio.Condition()

	async def aclose(self) -> None:
		async with self._ready:
			self._ready.notify_all()

	async def pre_receive_hook(self) -> None:
		async with self._ready:
			self._ready.notify_all()

	async def post_send_hook(self) -> None:
		# Await notification of either a receiver waiting or the broadcaster closing
		# This is necessary to delay returning until a filter has had a chance to return
		# a result.
		async with self._ready:
			await self._ready.wait()


class Runner:
	"""
	A filter runner that coordinates passing data between a stream and multiple filters

	Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or
	used with any `anyio.abc.ByteStream`.
	"""

	def __init__(self, *filters: Filter):
		if len(filters) == 0:  # pragma: no-cover
			raise TypeError("Runner requires at least one filter to run")
		self.filters = filters
		self.use_skip = True

	async def __call__(self, client: anyio.abc.ByteStream) -> None:
		"""
		Return an awaitable that starts and coordinates filters
		"""
		buff = SimpleBuffer(1*MiB)
		proto = FilterProtocol()
		sender = _sender(client, proto)
		channels = list[MessageChannel]()

		await sender.asend(None)  # type: ignore # initialise

		async with anyio.create_task_group() as tasks, aclosing(sender), aclosing(client):
			while 1:
				try:
					buff[:] = await client.receive(buff.available)
				except (anyio.EndOfStream, anyio.ClosedResourceError):
					for channel in channels:
						await channel.aclose()
					return
				for message in proto.read_from(buff):
					match message:
						case Negotiate():
							await self._negotiate(message, sender)
						case Macro():
							# TODO: implement macro support
							...
						case Connect():
							channels[:] = await self._connect(message, sender, tasks)
						case Abort():
							for channel in channels:
								await channel.aclose()
						case Close():
							return
						case _:
							assert isinstance(
								message,
								(
									Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown,
									Header, EndOfHeaders, Body, EndOfMessage,
								),
							)
							skip = isinstance(message, Body)
							for channel in channels:
								await channel.send(message)
								match (await channel.receive()):
									case Skip():
										continue
									case Continue():
										skip = False
									case Accept():
										await channel.aclose()
										channels.remove(channel)
									case resp:
										await sender.asend(resp)
										break
							else:
								await sender.asend(
									Accept() if len(channels) == 0 else
									Skip() if skip else
									Continue(),
								)

	async def _negotiate(self, message: Negotiate, sender: Sender) -> None:
		# TODO: actually negotiate what the filter wants, not just "everything"
		actions = set(ActionFlags)  # All actions!
		if actions != ActionFlags.unpack(message.action_flags):
			raise NegotiationError("MTA does not accept all actions required by the filter")

		resp = Negotiate(6, 0, 0)
		resp.protocol_flags = message.protocol_flags
		resp.action_flags = ActionFlags.pack(actions)

		await sender.asend(resp)

		self.use_skip = bool(resp.protocol_flags & ProtocolFlags.SKIP)

	async def _connect(
		self,
		message: Connect,
		sender: Sender,
		tasks: anyio.abc.TaskGroup,
	) -> list[MessageChannel]:
		channels = list[MessageChannel]()
		for fltr in self.filters:
			lchannel, rchannel = _make_message_channel()
			channels.append(lchannel)
			session = Session(message, sender, _Broadcast())
			match await tasks.start(
				_runner, fltr, session, rchannel, self.use_skip,
			):
				case Continue():
					continue
				case Message() as resp:
					await sender.asend(resp)
					return []
				case _ as arg:  # pragma: no-cover
					raise TypeError(
						f"task_status.started called with bad type: "
						f"{arg!r}",
					)
		await sender.asend(Continue())
		return channels


def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	rsend, lrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)


async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender:
	buff = SimpleBuffer(1*kiB)
	while 1:
		proto.write_to(buff, (yield))
		await client.send(buff[:])
		del buff[:]


_VALID_FINAL_RESPONSES = Reject, Discard, Accept, TemporaryFailure, ReplyCode

async def _runner(
	fltr: Filter,
	session: Session,
	channel: MessageChannel,
	use_skip: bool, *,
	task_status: anyio.abc.TaskStatus,
) -> None:
	final_resp: ResponseMessage|None = None

	async def _filter_wrap(
		task_status: anyio.abc.TaskStatus,
	) -> None:
		nonlocal final_resp
		async with session:
			task_status.started()
			final_resp = await fltr(session)
		if not isinstance(final_resp, _VALID_FINAL_RESPONSES):
			warn(f"expected a final response from {fltr}, got {final_resp}")
			final_resp = TemporaryFailure()

	async with anyio.create_task_group() as tasks:
		await tasks.start(_filter_wrap)
		task_status.started(final_resp or Continue())
		while final_resp is None:
			try:
				message = await channel.receive()
			except (anyio.EndOfStream, anyio.ClosedResourceError):
				tasks.cancel_scope.cancel()
				return
			assert isinstance(
				message,
				(
					Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown, Header,
					EndOfHeaders, Body, EndOfMessage,
				),
			)
			resp = await session.deliver(message)
			if final_resp is not None:
				await channel.send(final_resp)  # type: ignore
			elif use_skip and resp == Skip:
				await channel.send(Skip())
			else:
				await channel.send(Continue())

tests/mock_stream.py

0 → 100644
+149 −0
Original line number Diff line number Diff line
from __future__ import annotations

import typing
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
from types import TracebackType
from typing import TYPE_CHECKING
from typing import AsyncContextManager
from typing import TypeVar

import anyio
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.stapled import StapledByteStream
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing

from kilter.protocol import *
from kilter.protocol.buffer import SimpleBuffer
from kilter.service import ResponseMessage

P = typing.ParamSpec("P")
SendT = typing.TypeVar("SendT")
YieldT = typing.TypeVar("YieldT")


def _make_aclosing(
	func: Callable[P, AsyncGenerator[YieldT, SendT]],
) -> Callable[P, AsyncContextManager[AsyncGenerator[YieldT, SendT]]]:

	@wraps(func)
	@asynccontextmanager
	async def wrap(*a: P.args, **k: P.kwargs) -> AsyncIterator[AsyncGenerator[YieldT, SendT]]:
		agen = func(*a, **k)
		async with aclosing(agen):
			yield agen

	return wrap


class MockMessageStream:
	"""
	A mock of the right-side of an `anyio.abc.ByteStream` with test support on the left side
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="MockMessageStream")

	def __init__(self) -> None:
		self.buffer = SimpleBuffer(1024)
		self.closed = False

	async def __aenter__(self: Self) -> Self:
		send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes)
		send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes)

		self._stream = StapledObjectStream(send_obj, recv_obj)
		self.peer_stream = StapledByteStream(
			send_bytes,  # type: ignore
			BufferedByteReceiveStream(recv_bytes),
		)
		await self._stream.__aenter__()
		await self.peer_stream.__aenter__()
		return self

	async def __aexit__(
		self,
		et: type[BaseException]|None = None,
		ex: BaseException|None = None,
		tb: TracebackType|None = None,
	) -> None:
		if not self.closed:
			if et is not None:
				await self.abort()
			else:
				await self.close()
		await self._stream.__aexit__(et, ex, tb)
		await self.peer_stream.__aexit__(et, ex, tb)

	async def abort(self) -> None:
		"""
		Send Abort and close the stream
		"""
		try:
			resp = await self.send_msg(Abort())
		except anyio.BrokenResourceError:
			return
		assert len(resp) == 0, resp
		await self.close()

	async def close(self) -> None:
		"""
		Send Close and close the stream
		"""
		if self.closed:
			return
		resp = await self.send_msg(Close())
		assert len(resp) == 0, resp
		await self._stream.aclose()
		self.closed = True

	async def send_msg(self, msg: Message) -> list[Message]:
		"""
		Send a message and return the messages sent in response
		"""
		responses = []
		async with self._send_msg(msg) as aiter:
			async for resp in aiter:
				responses.append(resp)
		return responses

	@_make_aclosing
	async def _send_msg(self, msg: Message) -> AsyncGenerator[Message, None]:
		buff = self.buffer
		msg.pack(buff)
		await self._stream.send(buff[:].tobytes())
		del buff[:]
		if isinstance(msg, (Abort, Close)):
			return
		while 1:
			try:
				buff[:] = chunk = await self._stream.receive()
			except anyio.EndOfStream:
				break
			if len(chunk) == 0:
				break
			try:
				msg, size = Message.unpack(buff)
			except NeedsMore:
				continue
			del buff[:size]
			yield msg
			if isinstance(msg, typing.get_args(ResponseMessage) + (Negotiate, Skip)):
				break
		assert buff.filled == 0, buff[:].tobytes()

	async def send_and_expect(self, msg: Message, *exp: type[Message]|Message) -> None:
		"""
		Send a message and check the responses by type or equality
		"""
		resp = await self.send_msg(msg)
		assert len(resp) == len(exp), resp
		for r, e in zip(resp, exp):
			if isinstance(e, type):
				assert isinstance(r, e), f"expected {e}, got {type(r)}"
			else:
				assert r == e, r

tests/test_runner.py

0 → 100644
+306 −0
Original line number Diff line number Diff line
import trio.testing

from kilter.protocol import *
from kilter.service import Runner
from kilter.service import Session

from . import AsyncTestCase
from .mock_stream import MockMessageStream


class RunnerTests(AsyncTestCase):
	"""
	Tests for the Runner class
	"""

	async def test_helo(self) -> None:
		"""
		Check that awaiting Session.helo() responds to Connect with Continue
		"""
		hostname = ""

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal hostname
			hostname = await session.helo()
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			assert hostname == ""
			await stream_mock.send_and_expect(Helo("test.example.com"), Accept)
			assert hostname == "test.example.com"

	async def test_respond_to_peer(self) -> None:
		"""
		Check that returning before engaging with async session features works
		"""

		@Runner
		async def test_filter(session: Session) -> Reject:
			return Reject()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Reject)

	async def test_post_header(self) -> None:
		"""
		Check that delaying return until a phase later than CONNECT sends Continue
		"""
		@Runner
		async def test_filter(session: Session) -> Accept:
			assert "test@example.com" == await session.envelope_from()
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Helo("test.example.com"), Continue)

			await stream_mock.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept)

	async def test_body_all(self) -> None:
		"""
		Check that the whole body is processes when Continue is passed
		"""
		contents = b""

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal contents
			async with session.body as body:
				async for chunk in body:
					await trio.sleep(0)
					contents += chunk
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Continue)
			await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue)
			await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Accept)

		assert contents == b"This is a message sent in multiple chunks. Bye", contents

	async def test_body_skip(self) -> None:
		"""
		Check that Skip is returned once a body loop is broken
		"""
		contents = b""

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal contents
			async with session.body as body:
				async for chunk in body:
					contents += chunk
					if b"message" in chunk.tobytes():
						break

			# Move phase onto POST
			await session.change_sender("test@example.com")

			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, ProtocolFlags.SKIP),
				Negotiate(6, 0x1ff, ProtocolFlags.SKIP),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Skip)
			await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Skip)
			await stream_mock.send_and_expect(
				EndOfMessage(b"Bye"),
				ChangeSender("test@example.com"), Accept,
			)

		assert contents == b"This is a message sent ", contents

	async def test_body_fake_skip(self) -> None:
		"""
		Check that Skip is NOT returned if not accepted by an MTA
		"""
		contents = b""

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal contents
			async with session.body as body:
				async for chunk in body:
					contents += chunk
					if b"message" in chunk.tobytes():
						break

			# Move phase onto POST
			await session.change_sender("test@example.com")

			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Continue)
			await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue)
			await stream_mock.send_and_expect(
				EndOfMessage(b"Bye"),
				ChangeSender("test@example.com"), Accept,
			)

		assert contents == b"This is a message sent ", contents

	async def test_multiple(self) -> None:
		"""
		Check that multiple filters receive the messages they expect
		"""
		hostname = ""
		contents1 = b""
		contents2 = b""

		async def test_filter1(session: Session) -> Reject:
			nonlocal hostname
			nonlocal contents1

			hostname = await session.helo()

			async with session.body as body:
				async for chunk in body:
					await trio.sleep(0)
					contents1 += chunk

			return Reject()

		async def test_filter2(session: Session) -> Accept:
			nonlocal contents2

			async with session.body as body:
				async for chunk in body:
					await trio.sleep(0)
					contents2 += chunk
					if b"message" in chunk.tobytes():
						break

			return Accept()

		runner = Runner(test_filter1, test_filter2)

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(runner, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, ProtocolFlags.SKIP),
				Negotiate(6, 0x1ff, ProtocolFlags.SKIP),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Helo("test.example.com"), Continue)

			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Continue)
			await stream_mock.send_and_expect(Body(b"in multiple chunks. "), Continue)
			await stream_mock.send_and_expect(EndOfMessage(b"Bye"), Reject)

		assert hostname == "test.example.com", hostname
		assert contents1 == b"This is a message sent in multiple chunks. Bye", contents1
		assert contents2 == b"This is a message sent ", contents2

	async def test_abort(self) -> None:
		"""
		Check that a runner closes cleanly when it receives an Abort
		"""
		cancelled = False

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal cancelled
			try:
				await session.helo()
			except trio.Cancelled:
				cancelled = True
				raise
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.abort()

		assert cancelled

	async def test_bad_response(self) -> None:
		"""
		Check that a runner closes cleanly when it receives an Abort
		"""
		@Runner
		async def test_filter(session: Session) -> Skip:
			await session.helo()
			return Skip()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0),
				Negotiate(6, 0x1ff, 0),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			with self.assertWarns(UserWarning) as wcm:
				await stream_mock.send_and_expect(Helo("test.example.com"), TemporaryFailure)

			assert "expected a final response" in str(wcm.warning)