Commit 130c4bc9 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Narrow allowed message classes of methods on FilterProtocol

Closes #5
parent cd395b55
Loading
Loading
Loading
Loading
+70 −9
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@ from __future__ import annotations

from typing import Iterable
from typing import Sequence
from typing import TypeAlias
from typing import Union
from warnings import warn

from . import messages
@@ -24,6 +26,51 @@ from .exceptions import InvalidMessage
from .exceptions import NeedsMore
from .exceptions import UnexpectedMessage
from .exceptions import UnimplementedWarning
from .messages import *

EventMessage: TypeAlias = Union[
	Connect,
	Helo,
	EnvelopeFrom,
	EnvelopeRecipient,
	Data,
	Unknown,
	Header,
	EndOfHeaders,
	Body,
	EndOfMessage,
	Macro,
	Abort,
]
"""
Messages sent from an MTA to a filter
"""

ResponseMessage: TypeAlias = Union[
	Continue,
	Reject,
	Discard,
	Accept,
	TemporaryFailure,
	ReplyCode,
]
"""
Messages send from a filter to an MTA in response to `EventMessages`
"""

EditMessage: TypeAlias = Union[
	AddHeader,
	ChangeHeader,
	InsertHeader,
	ChangeSender,
	AddRecipient,
	AddRecipientPar,
	RemoveRecipient,
	ReplaceBody,
]
"""
Messages send from a filter to an MTA after an `EndOfMessage` to modify a message
"""


class Unimplemented(messages.BytesMessage, ident=b"\x00"):
@@ -144,7 +191,10 @@ class FilterProtocol:
		self.nr = {m.ident for m in non_responders}
		self.state: tuple[messages.Message, set[bytes]]|None = None

	def read_from(self, buf: FixedSizeBuffer) -> Iterable[messages.Message]:
	def read_from(
		self,
		buf: FixedSizeBuffer,
	) -> Iterable[Negotiate|EventMessage|Close|Unimplemented]:
		"""
		Return an iterator yielding each complete message from a buffer

@@ -166,21 +216,24 @@ class FilterProtocol:
				yield Unimplemented(data)
				del buf[:len(data)]
			else:
				self._check_recv(message)
				yield message
				yield self._check_recv(message)
				message.release()
				del buf[:size]

	def write_to(self, buf: FixedSizeBuffer, message: messages.Message) -> None:
	def write_to(
		self,
		buf: FixedSizeBuffer,
		message: ResponseMessage|EditMessage|Skip,
	) -> None:
		"""
		Validate and pack response and modification messages into a buffer
		"""
		self._check_send(message)
		message.pack(buf)

	def _check_recv(self, message: messages.Message) -> None:
	def _check_recv(self, message: messages.Message) -> Negotiate|EventMessage|Close:
		if isinstance(message, messages.Macro):
			return
			return message
		if isinstance(message, messages.Negotiate):
			self._store_mta_flags(message)
		if self.state is not None:
@@ -189,9 +242,17 @@ class FilterProtocol:
			responses = MTA_EVENT_RESPONSES[message.ident]
		except KeyError:
			raise InvalidMessage(message)
		if responses is None or message.ident in self.nr:
			return
		else:
			assert isinstance(
				message,
				(
					Negotiate, Macro, Connect, Helo, EnvelopeFrom, EnvelopeRecipient, Data,
					Unknown, Header, EndOfHeaders, Body, EndOfMessage, Abort, Close,
				),
			)
		if responses is not None and message.ident not in self.nr:
			self.state = message, responses
		return message

	def _check_send(self, message: messages.Message) -> None:
		if self.state is None: