Commit 443552a3 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add session module with Session

parent dd9e9817
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -101,6 +101,6 @@ repos:
    args: [--follow-imports=silent]
    additional_dependencies:
    - anyio
    - kilter.protocol
    - kilter.protocol ~=0.1.3
    - sphinx
    - trio-typing
+12 −9
Original line number Diff line number Diff line
@@ -75,7 +75,8 @@ async def reject_black_knight(session: Session) -> Reject|Accept:
	if (await session.envelope_from()) == BLOCK:
		return Reject()

	async for header in session.headers:
	async with session.headers as headers:
		async for header in headers:
			if header.name == "From" and header.value == BLOCK:
				return Reject()

@@ -94,7 +95,8 @@ async def strip_x_headers(session: Session) -> Accept:
	remove = []

	# iterate over headers as they arrive and select ones for later removal
	async for header in session.headers:
	async with session.headers as headers:
		async for header in headers:
			if header.name.startswith("X-"):
				remove.append(header)

@@ -112,7 +114,8 @@ async def strip_x_headers(session: Session) -> Accept:
	await session.headers.collect()

	# iterate over collected headers during the post phase, removing the unwanted ones
	async for header in session.headers:
	async with session.headers as headers:
		async for header in headers:
			if header.name.startswith("X-"):
				await session.headers.remove(header)
```
+7 −0
Original line number Diff line number Diff line
@@ -8,4 +8,11 @@ project). The framework aims to provide Pythonic interfaces for implementing fi
including leveraging coroutines instead of libmilter's callback-style interface.
"""

from .session import ResponseMessage as ResponseMessage
from .session import Session as Session
from .session import Before as Before
from .session import After as After
from .session import START as START
from .session import END as END

__version__ = "0.1"
+430 −0
Original line number Diff line number Diff line
# Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""
Sessions are the kernel of a filter, providing it with an async API to access messages
"""

from __future__ import annotations

from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from dataclasses import dataclass
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING
from typing import AsyncContextManager
from typing import Literal
from typing import Protocol
from typing import TypeVar
from typing import Union

from ..protocol.messages import *
from . import util

# Messages received from an MTA that `Session.deliver()` accepts and broadcasts to
# waiting tasks
EventMessage = Union[
	Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data, Unknown,
	Header, EndOfHeaders, Body, EndOfMessage,
]
# Messages a `Filter` may return as its response
ResponseMessage = Union[
	Continue, Reject, Discard, Accept, TemporaryFailure, Skip,
	ReplyCode, Abort,
]
# Messages sent through a session's editor generator to request changes to a message
EditMessage = Union[
	AddHeader, ChangeHeader, InsertHeader, ChangeSender, AddRecipient, AddRecipientPar,
	RemoveRecipient, ReplaceBody,
]


class Filter(Protocol):
	"""
	Filters are callables that accept a `Session` and return a response

	Any asynchronous callable with a matching signature satisfies this protocol; the
	value it returns must be one of the `ResponseMessage` types.
	"""

	async def __call__(self, session: Session) -> ResponseMessage: ...  # noqa: D102


class Phase(int, Enum):
	"""
	Session phases indicate what messages to expect and are impacted by received messages

	Members are integers so that phases can be compared by order (e.g.
	``phase <= Phase.HEADERS``); `Session.deliver()` moves a session to a phase when it
	receives one of the messages noted against each member.
	"""

	# initial phase of a new session
	CONNECT = 1
	# entered on Helo
	MAIL = 2
	# entered on EnvelopeFrom, EnvelopeRecipient or Unknown
	ENVELOPE = 3
	# entered on Data or Header
	HEADERS = 4
	# entered on EndOfHeaders or Body
	BODY = 5
	# entered on EndOfMessage
	POST = 6


@dataclass
class Position:
	"""
	Common base of `Before` and `After`; it is not intended to be instantiated by users

	The subject is either the `Header` that a position is relative to, or one of the
	literal strings "start" and "end" (see the `START` and `END` constants).
	"""

	subject: Header | Literal["start", "end"]


class Before(Position):
	"""
	Indicates a relative position preceding a subject `Header` in a header list

	Passing `Before(header)` to `HeadersAccessor.insert()` places the new header at the
	index currently occupied by *header*, i.e. immediately ahead of it.
	"""

	# narrows the inherited `subject` annotation: the subject must be a `Header`
	subject: Header


class After(Position):
	"""
	Indicates a relative position following a subject `Header` in a header list

	Passing `After(header)` to `HeadersAccessor.insert()` places the new header at the
	index one past *header*, i.e. immediately behind it.
	"""

	# narrows the inherited `subject` annotation: the subject must be a `Header`
	subject: Header


# Absolute positions for `HeadersAccessor.insert()`: the very beginning and the very end
# of the header list respectively
START = Position("start")
END = Position("end")


class Session:
	"""
	The kernel of a filter, providing an API for filters to access messages from and MTA
	"""

	def __init__(self, connmsg: Connect, sender: AsyncGenerator[None, EditMessage|Skip]):
		self.host = connmsg.hostname
		self.address = connmsg.address
		self.port = connmsg.port

		self._editor = sender
		self._broadcast = util.Broadcast[EventMessage]()

		self.headers = HeadersAccessor(self, sender)
		self.body = BodyAccessor(self, sender)

		self.skip = False

		# Phase checking is a bit fuzzy as a filter may not request every message,
		# so some phases will be skipped; checks should not try to exactly match a phase.
		self.phase = Phase.CONNECT

	async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]:
		"""
		Deliver a message (or its contents) to a task waiting for it
		"""
		match message:
			case Body() if self.skip:
				return Skip
			case Helo():
				self.phase = Phase.MAIL
			case EnvelopeFrom() | EnvelopeRecipient() | Unknown():
				self.phase = Phase.ENVELOPE
			case Data() | Header():
				self.phase = Phase.HEADERS
			case EndOfHeaders() | Body():
				self.phase = Phase.BODY
			case EndOfMessage():  # pragma: no-branch
				self.phase = Phase.POST
		await self._broadcast.send(message)
		return Skip if self.phase == Phase.BODY and self.skip else Continue

	async def helo(self) -> str:
		"""
		Wait for a HELO/EHLO message and return the client's claimed hostname
		"""
		if self.phase > Phase.CONNECT:
			raise RuntimeError(
				"Session.helo() must be awaited before any other async features of a "
				"Session",
			)
		async with self._broadcast:
			while self.phase <= Phase.CONNECT:
				message = await self._broadcast.receive()
				if isinstance(message, Helo):
					return message.hostname
		raise RuntimeError("HELO/EHLO event not received")

	async def envelope_from(self) -> str:
		"""
		Wait for a MAIL command message and return the sender identity

		Note that if extensions arguments are wanted, users should use `Session.extension()`
		instead with a name of `MAIL`.
		"""
		if self.phase > Phase.MAIL:
			raise RuntimeError(
				"Session.envelope_from() may only be awaited before the ENVELOPE phase",
			)
		async with self._broadcast:
			while self.phase <= Phase.MAIL:
				message = await self._broadcast.receive()
				if isinstance(message, EnvelopeFrom):
					return message.sender.decode()
		raise RuntimeError("MAIL event not received")

	async def envelope_recipients(self) -> AsyncIterator[str]:
		"""
		Wait for RCPT command messages and iteratively yield the recipients' identities

		Note that if extensions arguments are wanted, users should use `Session.extension()`
		instead with a name of `RCPT`.
		"""
		if self.phase > Phase.ENVELOPE:
			raise RuntimeError(
				"Session.envelope_from() may only be awaited before the HEADERS phase",
			)
		async with self._broadcast:
			while self.phase <= Phase.ENVELOPE:
				message = await self._broadcast.receive()
				if isinstance(message, EnvelopeRecipient):
					yield message.recipient.decode()

	async def extension(self, name: str) -> memoryview:
		"""
		Wait for the named command extension and return the raw command for processing
		"""
		if self.phase > Phase.ENVELOPE:
			raise RuntimeError(
				"Session.extension() may only be awaited before the HEADERS phase",
			)
		async with self._broadcast:
			while self.phase <= Phase.ENVELOPE:
				message = await self._broadcast.receive()
				match message:
					case Unknown():
						bname = name.encode("utf-8")
						if message.content[:len(bname)] == bname:
							return message.content
					# fake buffers for MAIL and RCPT commands
					case EnvelopeFrom() if name == "MAIL":
						vals = [b"MAIL FROM", message.sender, *message.arguments]
						return memoryview(b" ".join(vals))
					case EnvelopeRecipient() if name == "RCPT":
						vals = [b"RCPT TO", message.recipient, *message.arguments]
						return memoryview(b" ".join(vals))
		raise RuntimeError(f"{name} event not received")

	async def change_sender(self, sender: str, args: str = "") -> None:
		"""
		Move onto the `Phase.POST` phase and instruct the MTA to change the sender address
		"""
		await _until_editable(self)
		await self._editor.asend(ChangeSender(sender, args or None))

	async def add_recipient(self, recipient: str, args: str = "") -> None:
		"""
		Move onto the `Phase.POST` phase and instruct the MTA to add a new recipient address
		"""
		await _until_editable(self)
		await self._editor.asend(
			AddRecipientPar(recipient, args) if args else AddRecipient(recipient),
		)

	async def remove_recipient(self, recipient: str) -> None:
		"""
		Move onto the `Phase.POST` phase and instruct the MTA to remove a recipient address
		"""
		await _until_editable(self)
		await self._editor.asend(RemoveRecipient(recipient))


class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
	"""
	A class that allows access and modification of the message headers sent from an MTA

	To access headers (which are only available iteratively), use an instance as an
	asynchronous context manager; a `HeaderIterator` is returned when the context is
	entered.
	"""

	def __init__(self, session: Session, sender: AsyncGenerator[None, EditMessage]):
		self.session = session
		self._editor = sender
		self._table = list[Header]()
		self._aiter: AsyncGenerator[Header, None]

	async def __aenter__(self) -> HeaderIterator:
		self._aiter = HeaderIterator(self.__aiter())
		return self._aiter

	async def __aexit__(self, *_: object) -> None:
		await self._aiter.aclose()

	async def __aiter(self) -> AsyncGenerator[Header, None]:
		async with self.session._broadcast:
			# yield from cached headers first; allows multiple tasks to access the headers
			# in an uncoordinated manner; note the broadcaster is locked at this point
			for header in self._table:
				yield header
			while self.session.phase <= Phase.HEADERS:
				match (await self.session._broadcast.receive()):
					case Header() as header:
						self._table.append(header)
						try:
							yield header
						except GeneratorExit:
							await self._collect()
							raise
					case EndOfHeaders():
						return

	async def collect(self) -> None:
		"""
		Collect all headers without producing an iterator

		Calling this method before the `Phase.BODY` phase allows later processing of headers
		(after the HEADER phase) without the need for an empty loop.
		"""
		async with self.session._broadcast:
			await self._collect()

	async def _collect(self) -> None:
		# note the similarities between this and __aiter; the difference is no mutex or
		# yields
		while self.session.phase <= Phase.HEADERS:
			match (await self.session._broadcast.receive()):
				case Header() as header:
					self._table.append(header)
				case EndOfHeaders():
					return

	async def delete(self, header: Header) -> None:
		"""
		Move onto the `Phase.POST` phase and Instruct the MTA to delete the given header
		"""
		await self.collect()
		await _until_editable(self.session)
		index = self._table.index(header)
		await self._editor.asend(ChangeHeader(index, header.name, b""))
		del self._table[index]

	async def update(self, header: Header, value: bytes) -> None:
		"""
		Move onto the `Phase.POST` phase and Instruct the MTA to modify the value of a header
		"""
		await self.collect()
		await _until_editable(self.session)
		index = self._table.index(header)
		await self._editor.asend(ChangeHeader(index, header.name, value))
		self._table[index].value = value

	async def insert(self, header: Header, position: Position) -> None:
		"""
		Move onto the `Phase.POST` phase and instruct the MTA to insert a new header

		The header is inserted at `START`, `END`, or a relative position with `Before` and
		`After`; for example `Before(Header("To", "test@example.com"))`.
		"""
		await self.collect()
		await _until_editable(self.session)
		match position:
			case Position(subject="start"):
				index = 0
			case Position(subject="end"):
				index = len(self._table)
			case Before():
				index = self._table.index(position.subject)
			case After():  # pragma: no-branch
				index = self._table.index(position.subject) + 1
		if index >= len(self._table):
			await self._editor.asend(AddHeader(header.name, header.value))
			self._table.append(header)
		else:
			await self._editor.asend(InsertHeader(index, header.name, header.value))
			self._table.insert(index, header)


class HeaderIterator(AsyncGenerator[Header, None]):
	"""
	Iterator for headers obtained by using a `HeadersAccessor` as a context manager

	Every generator operation is forwarded to the asynchronous generator that was passed
	to the constructor; `restrict()` additionally offers iteration filtered by name.
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="HeaderIterator")

	def __init__(self, aiter: AsyncGenerator[Header, None]):
		self._aiter = aiter

	def __aiter__(self: Self) -> Self:
		return self

	async def __anext__(self) -> Header:  # noqa: D102
		header = await self._aiter.__anext__()
		return header

	async def asend(self, value: None = None) -> Header:  # noqa: D102
		# there is never a value to pass on, so sending is the same as advancing
		header = await self._aiter.__anext__()
		return header

	async def athrow(  # noqa: D102
		self,
		e: type[BaseException]|BaseException,
		m: object = None,
		t: TracebackType|None = None, /,
	) -> Header:
		# NOTE(review): both branches make the same call; the split presumably exists
		# so a type checker can match each `athrow` overload -- confirm before merging
		if isinstance(e, type):
			return await self._aiter.athrow(e, m, t)
		assert m is None
		return await self._aiter.athrow(e, m, t)

	async def aclose(self) -> None:  # noqa: D102
		await self._aiter.aclose()

	async def restrict(self, *names: str) -> AsyncIterator[Header]:
		"""
		Return an asynchronous generator yielding only the headers named in *names*
		"""
		async for header in self._aiter:
			if header.name not in names:
				continue
			yield header


class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
	"""
	A class that allows access and modification of the message body sent from an MTA

	To access chunks of a body (which are only available iteratively), use an instance as
	an asynchronous context manager; an asynchronous iterator of `memoryview` chunks is
	returned when the context is entered.
	"""

	def __init__(self, session: Session, sender: AsyncGenerator[None, EditMessage|Skip]):
		self._editor = sender
		self.session = session

	async def __aenter__(self) -> AsyncIterator[memoryview]:
		chunks = self._aiter = self.__aiter()
		return chunks

	async def __aexit__(self, *_: object) -> None:
		await self._aiter.aclose()

	async def __aiter(self) -> AsyncGenerator[memoryview, None]:
		session = self.session
		async with session._broadcast:
			while session.phase <= Phase.BODY:
				message = await session._broadcast.receive()
				if isinstance(message, Body):
					try:
						yield message.content
					except GeneratorExit:
						# the consumer stopped before the body finished; flag the
						# session so that the remaining body chunks are skipped
						session.skip = True
						raise
				elif isinstance(message, EndOfMessage) and not session.skip:
					# content carried by the end-of-message is yielded as a chunk too
					yield message.content

	async def write(self, chunk: bytes) -> None:
		"""
		Wait until the message can be edited, then send a replacement body chunk to the MTA
		"""
		await _until_editable(self.session)
		replacement = ReplaceBody(chunk)
		await self._editor.asend(replacement)


async def _until_editable(session: Session) -> None:
	# Block until the session reaches the POST phase, consuming (and discarding) any
	# messages broadcast in the meantime; returns immediately when already there.
	if session.phase != Phase.POST:
		broadcast = session._broadcast
		async with broadcast:
			while session.phase < Phase.POST:
				await broadcast.receive()
+1 −1
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ requires-python = ">=3.10,<4"

dependencies = [
	"anyio",
	"kilter.protocol",
	"kilter.protocol ~=0.1.3",
]
classifiers = [
	"Development Status :: 1 - Planning",
Loading