Commit e1e95129 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Make Negotiate flag fields enum types

parent c4026c69
Loading
Loading
Loading
Loading
+10 −10
Original line number Diff line number Diff line
@@ -214,8 +214,8 @@ class FilterProtocol:
		self.nr = set[bytes]()
		self.actions = set[bytes]([messages.Progress.ident])
		self.state: tuple[messages.Message, set[bytes]]|None = None
		self._optflags: int = 0
		self._actflags: int = 0
		self._optflags = ProtocolFlags(0)
		self._actflags = ActionFlags(0)

	def read_from(
		self,
@@ -307,15 +307,15 @@ class FilterProtocol:
		"""
		# ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty
		if message.macros:
			if not ActionFlags.SETSYMLIST & message.action_flags:
			if ActionFlags.SETSYMLIST not in message.action_flags:
				message.action_flags |= ActionFlags.SETSYMLIST
				warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4)
			if not ActionFlags.SETSYMLIST & self._actflags:
			if ActionFlags.SETSYMLIST not in self._actflags:
				raise ValueError("requesting symbols (macros) is not offered by the MTA")

		if (flags := message.protocol_flags & ~self._optflags):
			raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}")
		if (flags := message.action_flags & ~self._actflags):
			raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}")
		self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags)
		self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags)
		if (pflags := message.protocol_flags & ~self._optflags):
			raise ValueError(f"requested options not offered by the MTA: {pflags!r}")
		if (aflags := message.action_flags & ~self._actflags):
			raise ValueError(f"requested actions not offered by the MTA: {aflags!r}")
		self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag in message.protocol_flags)
		self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag in message.action_flags)
+9 −6
Original line number Diff line number Diff line
@@ -143,6 +143,8 @@ class ActionFlags(BitField):
	https://pythonhosted.org/pymilter/milter_api/smfi_register.html#flags
	"""

	NONE = 0x0

	ADD_HEADERS = ADDHDRS = 0x1
	CHANGE_HEADERS = CHGHDRS = 0x10
	CHANGE_BODY = CHGBODY = 0x2
@@ -162,6 +164,8 @@ class ProtocolFlags(BitField):
	https://pythonhosted.org/pymilter/milter_api/xxfi_negotiate.html
	"""

	NONE = 0x0

	NO_CONNECT = 0x1
	NO_HELO = 0x2
	NO_SENDER = NO_MAIL = 0x4
@@ -380,9 +384,8 @@ class Negotiate(Message, ident=b"O"):

	version: int

	# TODO: use set[Enum]?
	action_flags: int
	protocol_flags: int
	action_flags: ActionFlags
	protocol_flags: ProtocolFlags

	macros: Mapping[Stage, Collection[str]] = field(default_factory=dict)

@@ -390,14 +393,14 @@ class Negotiate(Message, ident=b"O"):

	@classmethod
	def from_buffer(cls, buf: memoryview) -> Self:
		opts = cast(tuple[int, int, int], cls._struct.unpack_from(buf))
		version, actions, options = cast(tuple[int, int, int], cls._struct.unpack_from(buf))
		buf = buf[cls._struct.size:]
		macros = dict()
		while len(buf) > 0:
			stage, *_ = LONG.unpack_from(buf)
			names, buf = split_cstring(buf[LONG.size:])
			macros[Stage(stage)] = names.tobytes().decode("utf-8").split()
		return cls(*opts, macros)
			macros[Stage(stage)] = str(names, "utf-8").split()
		return cls(version, ActionFlags(actions), ProtocolFlags(options), macros)

	def to_buffer(self, buf: FixedSizeBuffer) -> None:
		self._struct.pack_into(
+1 −1
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ async def process_client(client: trio.SocketStream, nursery: trio.Nursery) -> No
				logging.info(f"RECEIVED {message!r}")
				match message:
					case messages.Negotiate():
						message.protocol_flags = 0
						message.protocol_flags = messages.ProtocolFlags.NONE
						await send_channel.send(message)
					case messages.Macro() | messages.Abort() | messages.Close():
						continue
+36 −33
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@ from kilter.protocol.exceptions import UnknownMessage
from kilter.protocol.messages import NoDataMessage
from kilter.protocol.messages import *

ALL_ACTION_FLAGS = ActionFlags(0x1ff)
ALL_PROTOCOL_FLAGS = ProtocolFlags(0xfffff)


class FilterProtocolTests(unittest.TestCase):
	"""
@@ -29,7 +32,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Macro(b"\x00", dict(spam="ham")).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)

@@ -40,7 +43,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x01, 0x11f),
						Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags(0x13e)),
					)
				case Connect():
					protocol.write_to(
@@ -79,7 +82,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Macro(b"\x00", dict(spam="ham")).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)

@@ -105,7 +108,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		Helo("example.com").pack(buf)

@@ -116,7 +119,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT),
						Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT),
					)
				case Connect():
					pass
@@ -127,7 +130,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)

		protocol = FilterProtocol()
@@ -137,7 +140,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, 0x00),
						Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE),
					)
				case Connect():
					protocol.write_to(
@@ -159,7 +162,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)

		protocol = FilterProtocol()
@@ -169,7 +172,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT),
						Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT),
					)
				case Connect():
					with self.assertRaises(UnexpectedMessage):
@@ -187,7 +190,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		Helo("example.com").pack(buf)

@@ -198,7 +201,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, ActionFlags.ADD_HEADERS, 0x00),
						Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE),
					)
				case Connect():
					with self.assertRaises(UnexpectedMessage):
@@ -216,7 +219,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		Helo("example.com").pack(buf)
		Data().pack(buf)
@@ -230,7 +233,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, 0x00),
						Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE),
					)
				case EndOfMessage():
					with self.assertRaises(UnexpectedMessage):
@@ -253,7 +256,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		Helo("example.com").pack(buf)
		Data().pack(buf)
@@ -267,7 +270,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, ActionFlags.ADD_HEADERS, 0x00),
						Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE),
					)
				case EndOfMessage():
					protocol.write_to(
@@ -290,7 +293,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(100)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)

		protocol = FilterProtocol()
@@ -300,7 +303,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, ActionFlags.ADD_HEADERS, 0x00),
						Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE),
					)
				case Connect():
					with self.assertRaises(InvalidMessage):
@@ -318,7 +321,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine
@@ -326,7 +329,7 @@ class FilterProtocolTests(unittest.TestCase):
		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(20),
				Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K),
				Negotiate(6, ActionFlags.NONE, ProtocolFlags.MAX_DATA_SIZE_256K),
			)

	def test_disallowed_actions(self) -> None:
@@ -335,7 +338,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.CHANGE_BODY, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.CHANGE_BODY, ALL_PROTOCOL_FLAGS).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine
@@ -343,7 +346,7 @@ class FilterProtocolTests(unittest.TestCase):
		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(20),
				Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00),
				Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, ProtocolFlags.NONE),
			)

	def test_unrequested_action(self) -> None:
@@ -352,7 +355,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

@@ -363,7 +366,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, 0xfffff),
						Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS),
					)
				case EndOfMessage():
					with self.assertRaises(UnexpectedMessage):
@@ -378,7 +381,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

@@ -389,7 +392,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x1ff, 0xfffff),
						Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS),
					)
				case EndOfMessage():
					with catch_warnings(record=True) as warn_cm:
@@ -405,7 +408,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

@@ -416,7 +419,7 @@ class FilterProtocolTests(unittest.TestCase):
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x1ff, 0xfffff),
						Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS),
					)
				case EndOfMessage():
					with self.assertRaises(UnexpectedMessage):
@@ -431,7 +434,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine
@@ -439,7 +442,7 @@ class FilterProtocolTests(unittest.TestCase):
		with self.assertWarns(UserWarning):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}),
				Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}),
			)

	def test_setsymlist_disallowed(self) -> None:
@@ -448,7 +451,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine
@@ -456,7 +459,7 @@ class FilterProtocolTests(unittest.TestCase):
		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}),
				Negotiate(6, ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}),
			)

	def test_setsymlist_implicit_disallowed(self) -> None:
@@ -465,7 +468,7 @@ class FilterProtocolTests(unittest.TestCase):
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf)
		Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine
@@ -473,5 +476,5 @@ class FilterProtocolTests(unittest.TestCase):
		with self.assertWarns(UserWarning), self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}),
				Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}),
			)