Verified Commit d29e7964 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add macro requests to filter options

parent 61f5e77b
Loading
Loading
Loading
Loading
+44 −7
Original line number Diff line number Diff line
@@ -10,12 +10,14 @@ Filter decorators for marking the requested protocol options and actions used

from __future__ import annotations

from collections import defaultdict
from typing import Callable
from typing import Literal
from typing import NamedTuple

from kilter.protocol.messages import ActionFlags
from kilter.protocol.messages import ProtocolFlags
from kilter.protocol.messages import Stage

from .session import Filter

@@ -24,12 +26,14 @@ __all__ = [
	"examine_sender", "examine_recipients",
	"examine_headers", "examine_body",
	"get_flags", "modify_flags",
	"get_macros", "request_macros",
]

Decorator = Callable[[Filter], Filter]
SIZES = Literal[ProtocolFlags.NONE, ProtocolFlags.MDS_256K, ProtocolFlags.MDS_1M]

FLAGS_ATTRIBUTE = "filter_flags"
MACRO_ATTRIBUTE = "filter_macros"

DEFAULT_UNSET = \
	ProtocolFlags.NO_CONNECT | ProtocolFlags.NO_HELO | \
@@ -60,13 +64,7 @@ def modify_flags(
	Return a decorator that modifies the given flags on a decorated filter
	"""
	def decorator(filtr: Filter) -> Filter:
		flags = _get_flags(filtr, FlagsTuple())
		flags = FlagsTuple(
			flags.unset_options|unset_options,
			flags.set_options|set_options,
			flags.set_actions|set_actions,
		)
		setattr(filtr, FLAGS_ATTRIBUTE, flags)
		_set_flags(filtr, set_options, unset_options, set_actions)
		return filtr
	return decorator

@@ -79,11 +77,50 @@ def get_flags(filtr: Filter) -> FlagsTuple:
	return _get_flags(filtr, default)


def _set_flags(
	filtr: Filter,
	set_options: ProtocolFlags = ProtocolFlags.NONE,
	unset_options: ProtocolFlags = ProtocolFlags.NONE,
	set_actions: ActionFlags = ActionFlags.NONE,
) -> None:
	flags = _get_flags(filtr, FlagsTuple())
	flags = FlagsTuple(
		flags.unset_options|unset_options,
		flags.set_options|set_options,
		flags.set_actions|set_actions,
	)
	setattr(filtr, FLAGS_ATTRIBUTE, flags)


def _get_flags(filtr: Filter, default: FlagsTuple) -> FlagsTuple:
	assert isinstance(getattr(filtr, FLAGS_ATTRIBUTE, default), FlagsTuple)
	return getattr(filtr, FLAGS_ATTRIBUTE, default)


def request_macros(stage: Stage, *names: str) -> Decorator:
	"""
	Return a decorator that adds the given macro requests to a decorated filter
	"""
	def decorator(filtr: Filter) -> Filter:
		_set_flags(filtr, set_actions=ActionFlags.SETSYMLIST)
		macros = get_macros(filtr)
		macros[stage].update(names)
		return filtr
	return decorator


def get_macros(filtr: Filter) -> defaultdict[Stage, set[str]]:
	"""
	Return the requested macros attached to a filter
	"""
	try:
		macros = getattr(filtr, MACRO_ATTRIBUTE)
	except AttributeError:
		setattr(filtr, MACRO_ATTRIBUTE, (macros := defaultdict(set)))
	assert isinstance(macros, defaultdict)
	return macros


def responds_to_connect() -> Decorator:
	"""
	Mark a filter as possibly delivering a non-continue response to Connect events
+47 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from unittest import TestCase
from kilter.protocol import Accept
from kilter.protocol import ActionFlags
from kilter.protocol import ProtocolFlags
from kilter.protocol import Stage
from kilter.service import Session
from kilter.service import options

@@ -328,3 +329,49 @@ class Tests(TestCase):
			assert MAX_DATA_SIZE_1M in resolve_opts(flags, MAX_DATA_SIZE_1M)

			assert CHANGE_BODY in flags.set_actions


class MacroTests(TestCase):
	"""
	Tests for macro requests decorator
	"""

	def test_undecorated(self) -> None:
		"""
		Check that `get_macros` returns a default mapping for undecorated filters
		"""
		@options.responds_to_connect()
		async def filter_flag_decorator(session: Session) -> Accept:
			return Accept()

		async def filter_no_decorator(session: Session) -> Accept:
			return Accept()

		for filtr in (filter_flag_decorator, filter_no_decorator):
			with self.subTest(filter=filtr):

				macros = options.get_macros(filtr)

				assert isinstance(macros, dict)
				assert len(macros) == 0

	def test_decorated(self) -> None:
		"""
		Check that `get_macros` returns a good mapping for decorated filters
		"""
		@options.request_macros(Stage.CONNECT, "spam", "ham")
		@options.request_macros(Stage.CONNECT, "eggs")
		@options.request_macros(Stage.HELO, "spam", "ham")
		@options.request_macros(Stage.HELO, "spam", "eggs")
		async def filtr(session: Session) -> Accept:
			return Accept()

		macros = options.get_macros(filtr)

		self.assertDictEqual(
			macros,
			{
				Stage.CONNECT: {"spam", "ham", "eggs"},
				Stage.HELO: {"spam", "ham", "eggs"},
			},
		)