Commit 212699ba authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Use typing_extensions.Self instead of a TypeVar

Closes #6
parent 5299aee9
Loading
Loading
Loading
Loading
Loading
+17 −57
Original line number Diff line number Diff line
# Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -33,6 +33,8 @@ from typing import ClassVar
from typing import TypeVar
from typing import cast

from typing_extensions import Self

from .exceptions import InsufficientSpace
from .exceptions import NeedsMore

@@ -230,9 +232,6 @@ class Message(metaclass=ABCMeta):
	>>> del buf[:10]
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Message")

	ident: ClassVar[bytes]

	_message_classes = dict[bytes, "type[Message]"]()
@@ -270,7 +269,7 @@ class Message(metaclass=ABCMeta):

	@classmethod
	@abstractmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		"""
		Construct an instance with values unpacked from a buffer

@@ -321,14 +320,11 @@ class NoDataMessage(Message):
	Base class implementing `Message` abstract methods for messages with no contents
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="NoDataMessage")

	def __repr__(self) -> str:
		return f"{self.__class__.__name__}()"

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		assert len(buf) == 0, "message has some data"
		return cls()

@@ -341,9 +337,6 @@ class BytesMessage(Message):
	Base class implementing `Message` abstract methods for messages with unstructured contents
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="BytesMessage")

	content: memoryview

	def __init__(self, content: bytes):
@@ -360,7 +353,7 @@ class BytesMessage(Message):
		return other.content == self.content

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		return cls(buf)

	def to_buffer(self, buf: FixedSizeBuffer) -> None:
@@ -383,9 +376,6 @@ class Negotiate(Message, ident=b"O"):
	wants before each stage.
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Negotiate")

	version: int

	# TODO: use set[Enum]?
@@ -397,7 +387,7 @@ class Negotiate(Message, ident=b"O"):
	_struct = Struct("!LLL")

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		opts = cast(tuple[int, int, int], cls._struct.unpack_from(buf))
		buf = buf[cls._struct.size:]
		macros = dict()
@@ -428,16 +418,13 @@ class Macro(Message, ident=b"D"):
	A message type for transferring symbol mappings prior to a stage event
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Macro")

	stage: bytes
	macros: Mapping[str, str]

	_struct = Struct("!c")

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		stage, *_ = cls._struct.unpack_from(buf)
		macros = {}
		with buf[1:] as buf:
@@ -470,9 +457,6 @@ class Connect(Message, ident=b"C"):
	or `None` for connections for which there is no known address (e.g. stdin).
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Connect")

	hostname: str
	address: IPv4Address|IPv6Address|Path|None = None
	port: int = 0
@@ -480,7 +464,7 @@ class Connect(Message, ident=b"C"):
	_struct = Struct("!H")

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		_hostname, buf = split_cstring(buf)
		hostname = _hostname.tobytes().decode("idna")
		family, buf = Family(buf[0:1].tobytes()), buf[1:]
@@ -519,13 +503,10 @@ class Helo(Message, ident=b"H"):
	An event message reporting a client sent an SMTP HELO/EHLO command
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Helo")

	hostname: str

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		hostname, _ = split_cstring(buf)
		return cls(hostname.tobytes().decode("idna"))

@@ -539,14 +520,11 @@ class EnvelopeFrom(Message, ident=b"M"):
	An event message reporting a client sent an SMTP "MAIL FROM" command
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="EnvelopeFrom")

	sender: bytes
	arguments: list[bytes] = field(default_factory=list)

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		args = cstring_iter(buf)
		return cls(next(args), [*args])

@@ -571,14 +549,11 @@ class EnvelopeRecipient(Message, ident=b"R"):
	A client must send at least one "RCPT TO" command, and can send multiple.
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="EnvelopeRecipient")

	recipient: bytes
	arguments: list[bytes] = field(default_factory=list)

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		args = cstring_iter(buf)
		return cls(next(args), [*args])

@@ -618,14 +593,11 @@ class Header(Message, ident=b"L"):
	Transfers a header name and value from an email to a filter
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Header")

	name: str
	value: bytes

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		name, buf = split_cstring(buf)
		value, buf = split_cstring(buf)
		assert len(buf) == 0
@@ -725,13 +697,10 @@ class _AddrCmd(Message):
	Base class implementing `Message` abstract methods for messages with a single address
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="_AddrCmd")

	address: str

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		address, buf = split_cstring(buf)
		assert len(buf) == 0
		return cls(address.tobytes().decode("utf-8"))
@@ -746,14 +715,11 @@ class _AddrParCmd(Message):
	Base class implementing `Message` abstract methods for messages with an address and arguments
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="_AddrParCmd")

	address: str
	args: str|None = None

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		args: memoryview|None = None
		address, buf = split_cstring(buf)
		if len(buf) > 0:
@@ -784,15 +750,12 @@ class ChangeHeader(Message, ident=b"m"):
	Message from a filter to request a header is modified or removed at a given index
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="ChangeHeader")

	index: int
	name: str
	value: bytes

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		index, *_ = LONG.unpack_from(buf)
		name, buf = split_cstring(buf[LONG.size:])
		value, buf = split_cstring(buf)
@@ -857,13 +820,10 @@ class Quarantine(Message, ident=b"q"):
	Request that a message is quarantined (blocked, but kept for review)
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Quarantine")

	reason: str

	@classmethod
	def from_buffer(cls: type[Self], buf: memoryview) -> Self:
	def from_buffer(cls, buf: memoryview) -> Self:
		reason, buf = split_cstring(buf)
		assert len(buf) == 0
		return cls(reason.tobytes().decode("utf-8"))