Loading tests/test_core_filter_protocol.py +90 −0 Original line number Diff line number Diff line Loading @@ -491,3 +491,93 @@ class FilterProtocolTests(unittest.TestCase): SimpleBuffer(40), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_negotiated(self) -> None: """ Check that attributes expose the correct negotiated information """ protocol = negotiated_protocol( ActionFlags.ADD_HEADERS|ActionFlags.CHANGE_HEADERS, ProtocolFlags.SKIP|ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO, ) assert protocol.skip == True assert AddHeader.ident in protocol.actions assert ChangeHeader.ident in protocol.actions assert InsertHeader.ident in protocol.actions self.assertSetEqual(protocol.nr, {Connect.ident, Helo.ident}) def test_negotiated_none(self) -> None: """ Check that attributes expose the correct negotiated information """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE) assert protocol.skip == False assert len(protocol.nr) == 0 def test_needs_response(self) -> None: """ Check that needs_response returns True for selected messages """ protocol = negotiated_protocol( ActionFlags.NONE, ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO, ) assert protocol.needs_response(Negotiate(0, ActionFlags.NONE, ProtocolFlags.NONE)) assert not protocol.needs_response(Connect("example.com")) assert not protocol.needs_response(Helo("example.com")) assert not protocol.needs_response(Macro(0, {})) assert not protocol.needs_response(Abort()) assert not protocol.needs_response(Close()) assert protocol.needs_response(EnvelopeRecipient(b"spam@example.com")) assert protocol.needs_response(Data()) def test_skip(self) -> None: """ Check that sending Skip when negotiated is allowed """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP) buf = SimpleBuffer(20) Body(b"spam").pack(buf) next(protocol.read_from(buf)) protocol.write_to(SimpleBuffer(20), Skip()) def test_skip_wrong_state(self) -> None: """ Check that sending Skip before one would be valid raises UnexpectedMessage """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP) with self.assertRaises(UnexpectedMessage): protocol.write_to(SimpleBuffer(20), Skip()) def test_skip_not_negotiated(self) -> None: """ Check that sending Skip when not negotiated raises UnexpectedMessage """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE) buf = SimpleBuffer(20) Body(b"spam").pack(buf) next(protocol.read_from(buf)) with self.assertRaises(UnexpectedMessage): protocol.write_to(SimpleBuffer(20), Skip()) def negotiated_protocol(actions: ActionFlags, options: ProtocolFlags) -> FilterProtocol: """ Return a post-negotiation FilterProtocol which has negotiated for the given features """ buf = SimpleBuffer(20) Negotiate(6, ActionFlags.ALL, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine protocol.write_to(SimpleBuffer(20), Negotiate(6, actions, options)) return protocol Loading
tests/test_core_filter_protocol.py +90 −0 Original line number Diff line number Diff line Loading @@ -491,3 +491,93 @@ class FilterProtocolTests(unittest.TestCase): SimpleBuffer(40), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_negotiated(self) -> None: """ Check that attributes expose the correct negotiated information """ protocol = negotiated_protocol( ActionFlags.ADD_HEADERS|ActionFlags.CHANGE_HEADERS, ProtocolFlags.SKIP|ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO, ) assert protocol.skip == True assert AddHeader.ident in protocol.actions assert ChangeHeader.ident in protocol.actions assert InsertHeader.ident in protocol.actions self.assertSetEqual(protocol.nr, {Connect.ident, Helo.ident}) def test_negotiated_none(self) -> None: """ Check that attributes expose the correct negotiated information """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE) assert protocol.skip == False assert len(protocol.nr) == 0 def test_needs_response(self) -> None: """ Check that needs_response returns True for selected messages """ protocol = negotiated_protocol( ActionFlags.NONE, ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO, ) assert protocol.needs_response(Negotiate(0, ActionFlags.NONE, ProtocolFlags.NONE)) assert not protocol.needs_response(Connect("example.com")) assert not protocol.needs_response(Helo("example.com")) assert not protocol.needs_response(Macro(0, {})) assert not protocol.needs_response(Abort()) assert not protocol.needs_response(Close()) assert protocol.needs_response(EnvelopeRecipient(b"spam@example.com")) assert protocol.needs_response(Data()) def test_skip(self) -> None: """ Check that sending Skip when negotiated is allowed """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP) buf = SimpleBuffer(20) Body(b"spam").pack(buf) next(protocol.read_from(buf)) protocol.write_to(SimpleBuffer(20), Skip()) def test_skip_wrong_state(self) -> None: """ Check that sending Skip before one would be valid raises UnexpectedMessage """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP) with self.assertRaises(UnexpectedMessage): protocol.write_to(SimpleBuffer(20), Skip()) def test_skip_not_negotiated(self) -> None: """ Check that sending Skip when not negotiated raises UnexpectedMessage """ protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE) buf = SimpleBuffer(20) Body(b"spam").pack(buf) next(protocol.read_from(buf)) with self.assertRaises(UnexpectedMessage): protocol.write_to(SimpleBuffer(20), Skip()) def negotiated_protocol(actions: ActionFlags, options: ProtocolFlags) -> FilterProtocol: """ Return a post-negotiation FilterProtocol which has negotiated for the given features """ buf = SimpleBuffer(20) Negotiate(6, ActionFlags.ALL, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine protocol.write_to(SimpleBuffer(20), Negotiate(6, actions, options)) return protocol