Commit b40495a2 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Implement validation of returned Negotiate messages

Closes #2
parent d4654bc0
Loading
Loading
Loading
Loading
+46 −13
Original line number Diff line number Diff line
@@ -177,17 +177,28 @@ MTA_EVENT_RESPONSES = {
	messages.Close.ident: None,
}

UPDATE_RESPONSES = {
	messages.ChangeHeader.ident,
	messages.AddHeader.ident,
	messages.InsertHeader.ident,
	messages.ChangeSender.ident,
	messages.AddRecipient.ident,
	messages.AddRecipientPar.ident,
	messages.RemoveRecipient.ident,
	messages.ReplaceBody.ident,
	messages.Progress.ident,
	messages.Quarantine.ident,
UPDATE_FLAG_MAP = {
	messages.ChangeHeader.ident:        ActionFlags.CHANGE_HEADERS,
	messages.AddHeader.ident:           ActionFlags.ADD_HEADERS,
	messages.InsertHeader.ident:        ActionFlags.ADD_HEADERS,
	messages.ChangeSender.ident:        ActionFlags.CHANGE_FROM,
	messages.AddRecipient.ident:        ActionFlags.ADD_RECIPIENT,
	messages.AddRecipientPar.ident:     ActionFlags.ADD_RECIPIENT_PAR,
	messages.RemoveRecipient.ident:     ActionFlags.DELETE_RECIPIENT,
	messages.ReplaceBody.ident:         ActionFlags.CHANGE_BODY,
	messages.Quarantine.ident:          ActionFlags.QUARANTINE,
}

NR_FLAG_MAP = {
	messages.Connect.ident:             ProtocolFlags.NR_CONNECT,
	messages.Helo.ident:                ProtocolFlags.NR_HELO,
	messages.EnvelopeFrom.ident:        ProtocolFlags.NR_SENDER,
	messages.EnvelopeRecipient.ident:   ProtocolFlags.NR_RECIPIENT,
	messages.Data.ident:                ProtocolFlags.NR_DATA,
	messages.Unknown.ident:             ProtocolFlags.NR_UNKNOWN,
	messages.Header.ident:              ProtocolFlags.NR_HEADER,
	messages.EndOfHeaders.ident:        ProtocolFlags.NR_EOH,
	messages.Body.ident:                ProtocolFlags.NR_BODY,
}


@@ -201,7 +212,10 @@ class FilterProtocol:

	def __init__(self) -> None:
		self.nr = set[bytes]()
		self.actions = set[bytes]([messages.Progress.ident])
		self.state: tuple[messages.Message, set[bytes]]|None = None
		self._optflags: int = 0
		self._actflags: int = 0

	def read_from(
		self,
@@ -268,8 +282,11 @@ class FilterProtocol:
		if isinstance(message, messages.Negotiate):
			self._check_mta_flags(message)
		event, responses = self.state
		if message.ident in UPDATE_RESPONSES and isinstance(event, messages.EndOfMessage):
		if isinstance(event, messages.EndOfMessage):
			if message.ident in self.actions:
				return
			if message.ident in UPDATE_FLAG_MAP:
				raise UnexpectedMessage(message)
		if message.ident not in responses:
			raise InvalidMessage(message, event)
		self.state = None
@@ -278,6 +295,8 @@ class FilterProtocol:
		"""
		Store the option flags offered by an MTA for later checking
		"""
		self._optflags = message.protocol_flags
		self._actflags = message.action_flags

	def _check_mta_flags(self, message: messages.Negotiate) -> None:
		"""
@@ -286,3 +305,17 @@ class FilterProtocol:
		Filters cannot request options an MTA did not send, and any no-response (NR)
		flags need to be recorded for checking.
		"""
		# ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty
		if message.macros:
			if not ActionFlags.SETSYMLIST & message.action_flags:
				message.action_flags |= ActionFlags.SETSYMLIST
				warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4)
			if not ActionFlags.SETSYMLIST & self._actflags:
				raise ValueError("requesting symbols (macros) is not offered by the MTA")

		if (flags := message.protocol_flags & ~self._optflags):
			raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}")
		if (flags := message.action_flags & ~self._actflags):
			raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}")
		self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags)
		self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags)
+0 −11
Original line number Diff line number Diff line
@@ -99,7 +99,6 @@ class FilterProtocolTests(unittest.TestCase):
			for _ in FilterProtocol().read_from(buf):
				pass

	@unittest.expectedFailure
	def test_no_response(self) -> None:
		"""
		Check that a following message can be read immediately when no response is expected
@@ -154,7 +153,6 @@ class FilterProtocolTests(unittest.TestCase):
		else:
			self.fail("Connect not read")

	@unittest.expectedFailure
	def test_write_unexpected_response_nr(self) -> None:
		"""
		Check that writing a message when no response is expected raises UnexpectedMessage
@@ -212,7 +210,6 @@ class FilterProtocolTests(unittest.TestCase):
		else:
			self.fail("Connect not read")

	@unittest.expectedFailure
	def test_write_disallowed_update(self) -> None:
		"""
		Check that writing updates without negotiation raises UnexpectedMessage
@@ -315,7 +312,6 @@ class FilterProtocolTests(unittest.TestCase):
		else:
			self.fail("Connect not read")

	@unittest.expectedFailure
	def test_disallowed_opts(self) -> None:
		"""
		Check that requesting protocol options not offered by the MTA results in ValueError
@@ -333,7 +329,6 @@ class FilterProtocolTests(unittest.TestCase):
				Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K),
			)

	@unittest.expectedFailure
	def test_disallowed_actions(self) -> None:
		"""
		Check that requesting protocol actions not offered by the MTA results in ValueError
@@ -351,7 +346,6 @@ class FilterProtocolTests(unittest.TestCase):
				Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00),
			)

	@unittest.expectedFailure
	def test_unrequested_action(self) -> None:
		"""
		Check that sending an action that was not requested raises UnexpectedMessage
@@ -378,7 +372,6 @@ class FilterProtocolTests(unittest.TestCase):
							ReplaceBody(b""),
						)

	@unittest.expectedFailure
	def test_action(self) -> None:
		"""
		Check that sending an allowed modification action raised no issues
@@ -406,7 +399,6 @@ class FilterProtocolTests(unittest.TestCase):
						)
					assert len(warn_cm) == 0

	@unittest.expectedFailure
	def test_action_bad(self) -> None:
		"""
		Check that sending a disallowed message after EOM raises UnexpectedMessage
@@ -433,7 +425,6 @@ class FilterProtocolTests(unittest.TestCase):
							Skip(),
						)

	@unittest.expectedFailure
	def test_setsymlist_implicit(self) -> None:
		"""
		Check that sending a mapping of symbol lists sets SETSYMLIST and issues a warning
@@ -451,7 +442,6 @@ class FilterProtocolTests(unittest.TestCase):
				Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}),
			)

	@unittest.expectedFailure
	def test_setsymlist_disallowed(self) -> None:
		"""
		Check that sending symbol lists when not offered by an MTA raises ValueError
@@ -469,7 +459,6 @@ class FilterProtocolTests(unittest.TestCase):
				Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}),
			)

	@unittest.expectedFailure
	def test_setsymlist_implicit_disallowed(self) -> None:
		"""
		Check that sending symbol lists when not offered by an MTA raises ValueError