Loading kilter/protocol/core.py +46 −13 Original line number Diff line number Diff line Loading @@ -177,17 +177,28 @@ MTA_EVENT_RESPONSES = { messages.Close.ident: None, } UPDATE_RESPONSES = { messages.ChangeHeader.ident, messages.AddHeader.ident, messages.InsertHeader.ident, messages.ChangeSender.ident, messages.AddRecipient.ident, messages.AddRecipientPar.ident, messages.RemoveRecipient.ident, messages.ReplaceBody.ident, messages.Progress.ident, messages.Quarantine.ident, UPDATE_FLAG_MAP = { messages.ChangeHeader.ident: ActionFlags.CHANGE_HEADERS, messages.AddHeader.ident: ActionFlags.ADD_HEADERS, messages.InsertHeader.ident: ActionFlags.ADD_HEADERS, messages.ChangeSender.ident: ActionFlags.CHANGE_FROM, messages.AddRecipient.ident: ActionFlags.ADD_RECIPIENT, messages.AddRecipientPar.ident: ActionFlags.ADD_RECIPIENT_PAR, messages.RemoveRecipient.ident: ActionFlags.DELETE_RECIPIENT, messages.ReplaceBody.ident: ActionFlags.CHANGE_BODY, messages.Quarantine.ident: ActionFlags.QUARANTINE, } NR_FLAG_MAP = { messages.Connect.ident: ProtocolFlags.NR_CONNECT, messages.Helo.ident: ProtocolFlags.NR_HELO, messages.EnvelopeFrom.ident: ProtocolFlags.NR_SENDER, messages.EnvelopeRecipient.ident: ProtocolFlags.NR_RECIPIENT, messages.Data.ident: ProtocolFlags.NR_DATA, messages.Unknown.ident: ProtocolFlags.NR_UNKNOWN, messages.Header.ident: ProtocolFlags.NR_HEADER, messages.EndOfHeaders.ident: ProtocolFlags.NR_EOH, messages.Body.ident: ProtocolFlags.NR_BODY, } Loading @@ -201,7 +212,10 @@ class FilterProtocol: def __init__(self) -> None: self.nr = set[bytes]() self.actions = set[bytes]([messages.Progress.ident]) self.state: tuple[messages.Message, set[bytes]]|None = None self._optflags: int = 0 self._actflags: int = 0 def read_from( self, Loading Loading @@ -268,8 +282,11 @@ class FilterProtocol: if isinstance(message, messages.Negotiate): self._check_mta_flags(message) event, responses = self.state if message.ident in UPDATE_RESPONSES and isinstance(event, messages.EndOfMessage): if isinstance(event, messages.EndOfMessage): if message.ident in self.actions: return if message.ident in UPDATE_FLAG_MAP: raise UnexpectedMessage(message) if message.ident not in responses: raise InvalidMessage(message, event) self.state = None Loading @@ -278,6 +295,8 @@ class FilterProtocol: """ Store the option flags offered by an MTA for later checking """ self._optflags = message.protocol_flags self._actflags = message.action_flags def _check_mta_flags(self, message: messages.Negotiate) -> None: """ Loading @@ -286,3 +305,17 @@ class FilterProtocol: Filters cannot request options an MTA did not send, and any no-response (NR) flags need to be recorded for checking. """ # ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty if message.macros: if not ActionFlags.SETSYMLIST & message.action_flags: message.action_flags |= ActionFlags.SETSYMLIST warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4) if not ActionFlags.SETSYMLIST & self._actflags: raise ValueError("requesting symbols (macros) is not offered by the MTA") if (flags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}") if (flags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags) tests/test_core_filter_protocol.py +0 −11 Original line number Diff line number Diff line Loading @@ -99,7 +99,6 @@ class FilterProtocolTests(unittest.TestCase): for _ in FilterProtocol().read_from(buf): pass @unittest.expectedFailure def test_no_response(self) -> None: """ Check that a following message can be read immediately when no response is expected Loading Loading @@ -154,7 +153,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_write_unexpected_response_nr(self) -> None: """ Check that writing a message when no response is expected raises UnexpectedMessage Loading Loading @@ -212,7 +210,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_write_disallowed_update(self) -> None: """ Check that writing updates without negotiation raises UnexpectedMessage Loading Loading @@ -315,7 +312,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_disallowed_opts(self) -> None: """ Check that requesting protocol options not offered by the MTA results in ValueError Loading @@ -333,7 +329,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K), ) @unittest.expectedFailure def test_disallowed_actions(self) -> None: """ Check that requesting protocol actions not offered by the MTA results in ValueError Loading @@ -351,7 +346,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00), ) @unittest.expectedFailure def test_unrequested_action(self) -> None: """ Check that sending an action that was not requested raises UnexpectedMessage Loading @@ -378,7 +372,6 @@ class FilterProtocolTests(unittest.TestCase): ReplaceBody(b""), ) @unittest.expectedFailure def test_action(self) -> None: """ Check that sending an allowed modification action raised no issues Loading Loading @@ -406,7 +399,6 @@ class FilterProtocolTests(unittest.TestCase): ) assert len(warn_cm) == 0 @unittest.expectedFailure def test_action_bad(self) -> None: """ Check that sending a disallowed message after EOM raises UnexpectedMessage Loading @@ -433,7 +425,6 @@ class FilterProtocolTests(unittest.TestCase): Skip(), ) @unittest.expectedFailure def test_setsymlist_implicit(self) -> None: """ Check that sending a mapping of symbol lists sets SETSYMLIST and issues a warning Loading @@ -451,7 +442,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), ) @unittest.expectedFailure def test_setsymlist_disallowed(self) -> None: """ Check that sending symbol lists when not offered by an MTA raises ValueError Loading @@ -469,7 +459,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}), ) @unittest.expectedFailure def test_setsymlist_implicit_disallowed(self) -> None: """ Check that sending symbol lists when not offered by an MTA raises ValueError Loading Loading
kilter/protocol/core.py +46 −13 Original line number Diff line number Diff line Loading @@ -177,17 +177,28 @@ MTA_EVENT_RESPONSES = { messages.Close.ident: None, } UPDATE_RESPONSES = { messages.ChangeHeader.ident, messages.AddHeader.ident, messages.InsertHeader.ident, messages.ChangeSender.ident, messages.AddRecipient.ident, messages.AddRecipientPar.ident, messages.RemoveRecipient.ident, messages.ReplaceBody.ident, messages.Progress.ident, messages.Quarantine.ident, UPDATE_FLAG_MAP = { messages.ChangeHeader.ident: ActionFlags.CHANGE_HEADERS, messages.AddHeader.ident: ActionFlags.ADD_HEADERS, messages.InsertHeader.ident: ActionFlags.ADD_HEADERS, messages.ChangeSender.ident: ActionFlags.CHANGE_FROM, messages.AddRecipient.ident: ActionFlags.ADD_RECIPIENT, messages.AddRecipientPar.ident: ActionFlags.ADD_RECIPIENT_PAR, messages.RemoveRecipient.ident: ActionFlags.DELETE_RECIPIENT, messages.ReplaceBody.ident: ActionFlags.CHANGE_BODY, messages.Quarantine.ident: ActionFlags.QUARANTINE, } NR_FLAG_MAP = { messages.Connect.ident: ProtocolFlags.NR_CONNECT, messages.Helo.ident: ProtocolFlags.NR_HELO, messages.EnvelopeFrom.ident: ProtocolFlags.NR_SENDER, messages.EnvelopeRecipient.ident: ProtocolFlags.NR_RECIPIENT, messages.Data.ident: ProtocolFlags.NR_DATA, messages.Unknown.ident: ProtocolFlags.NR_UNKNOWN, messages.Header.ident: ProtocolFlags.NR_HEADER, messages.EndOfHeaders.ident: ProtocolFlags.NR_EOH, messages.Body.ident: ProtocolFlags.NR_BODY, } Loading @@ -201,7 +212,10 @@ class FilterProtocol: def __init__(self) -> None: self.nr = set[bytes]() self.actions = set[bytes]([messages.Progress.ident]) self.state: tuple[messages.Message, set[bytes]]|None = None self._optflags: int = 0 self._actflags: int = 0 def read_from( self, Loading Loading @@ -268,8 +282,11 @@ class FilterProtocol: if isinstance(message, messages.Negotiate): self._check_mta_flags(message) event, responses = self.state if message.ident in UPDATE_RESPONSES and isinstance(event, messages.EndOfMessage): if isinstance(event, messages.EndOfMessage): if message.ident in self.actions: return if message.ident in UPDATE_FLAG_MAP: raise UnexpectedMessage(message) if message.ident not in responses: raise InvalidMessage(message, event) self.state = None Loading @@ -278,6 +295,8 @@ class FilterProtocol: """ Store the option flags offered by an MTA for later checking """ self._optflags = message.protocol_flags self._actflags = message.action_flags def _check_mta_flags(self, message: messages.Negotiate) -> None: """ Loading @@ -286,3 +305,17 @@ class FilterProtocol: Filters cannot request options an MTA did not send, and any no-response (NR) flags need to be recorded for checking. """ # ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty if message.macros: if not ActionFlags.SETSYMLIST & message.action_flags: message.action_flags |= ActionFlags.SETSYMLIST warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4) if not ActionFlags.SETSYMLIST & self._actflags: raise ValueError("requesting symbols (macros) is not offered by the MTA") if (flags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}") if (flags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags)
tests/test_core_filter_protocol.py +0 −11 Original line number Diff line number Diff line Loading @@ -99,7 +99,6 @@ class FilterProtocolTests(unittest.TestCase): for _ in FilterProtocol().read_from(buf): pass @unittest.expectedFailure def test_no_response(self) -> None: """ Check that a following message can be read immediately when no response is expected Loading Loading @@ -154,7 +153,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_write_unexpected_response_nr(self) -> None: """ Check that writing a message when no response is expected raises UnexpectedMessage Loading Loading @@ -212,7 +210,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_write_disallowed_update(self) -> None: """ Check that writing updates without negotiation raises UnexpectedMessage Loading Loading @@ -315,7 +312,6 @@ class FilterProtocolTests(unittest.TestCase): else: self.fail("Connect not read") @unittest.expectedFailure def test_disallowed_opts(self) -> None: """ Check that requesting protocol options not offered by the MTA results in ValueError Loading @@ -333,7 +329,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K), ) @unittest.expectedFailure def test_disallowed_actions(self) -> None: """ Check that requesting protocol actions not offered by the MTA results in ValueError Loading @@ -351,7 +346,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00), ) @unittest.expectedFailure def test_unrequested_action(self) -> None: """ Check that sending an action that was not requested raises UnexpectedMessage Loading @@ -378,7 +372,6 @@ class FilterProtocolTests(unittest.TestCase): ReplaceBody(b""), ) @unittest.expectedFailure def test_action(self) -> None: """ Check that sending an allowed modification action raised no issues Loading Loading @@ -406,7 +399,6 @@ class FilterProtocolTests(unittest.TestCase): ) assert len(warn_cm) == 0 @unittest.expectedFailure def test_action_bad(self) -> None: """ Check that sending a disallowed message after EOM raises UnexpectedMessage Loading @@ -433,7 +425,6 @@ class FilterProtocolTests(unittest.TestCase): Skip(), ) @unittest.expectedFailure def test_setsymlist_implicit(self) -> None: """ Check that sending a mapping of symbol lists sets SETSYMLIST and issues a warning Loading @@ -451,7 +442,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), ) @unittest.expectedFailure def test_setsymlist_disallowed(self) -> None: """ Check that sending symbol lists when not offered by an MTA raises ValueError Loading @@ -469,7 +459,6 @@ class FilterProtocolTests(unittest.TestCase): Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}), ) @unittest.expectedFailure def test_setsymlist_implicit_disallowed(self) -> None: """ Check that sending symbol lists when not offered by an MTA raises ValueError Loading