Verified Commit 717d4f63 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Refine the allowed responses of a session.Filter

… and remove the work-around for
https://github.com/python/mypy/issues/14242
parent 8c57c6ad
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ from .options import get_flags
from .options import get_macros
from .session import Aborted
from .session import Filter
from .session import FilterResponse
from .session import Session
from .util import Broadcast
from .util import qualname
@@ -45,8 +46,6 @@ MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message]
kiB: Final = 2**10
MiB: Final = 2**20

# TODO: Convert to Union type alias once python/mypy#14242 is fixed
_VALID_FINAL_RESPONSES: Final = Reject, Discard, Accept, TemporaryFailure, ReplyCode
_VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \
	Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort

@@ -378,7 +377,7 @@ class _TaskRunner:
				except Exception:
					_logger.exception(f"Error in filter {qualname(fltr)}")
					final_resp = TemporaryFailure()
				if not isinstance(final_resp, _VALID_FINAL_RESPONSES):
				if not isinstance(final_resp, FilterResponse):
					warn(f"expected a valid response from {qualname(fltr)}, got {final_resp}")
					final_resp = TemporaryFailure()

+4 −2
Original line number Diff line number Diff line
@@ -23,16 +23,18 @@ from types import TracebackType
from typing import AsyncContextManager
from typing import Literal
from typing import Protocol
from typing import TypeAlias
from warnings import warn

from typing_extensions import Self

from ..protocol.core import EditMessage
from ..protocol.core import EventMessage
from ..protocol.core import ResponseMessage
from ..protocol.messages import *
from . import util

FilterResponse: TypeAlias = Accept | Reject | Discard | ReplyCode


class Aborted(BaseException):
	"""
@@ -45,7 +47,7 @@ class Filter(Protocol):
	Filters are callables that accept a `Session` and return a response
	"""

	async def __call__(self, session: Session, /) -> ResponseMessage: ...  # noqa: D102
	async def __call__(self, session: Session, /) -> FilterResponse: ...  # noqa: D102


class Sender(Protocol):