Loading kilter/protocol/core.py +10 −10 Original line number Diff line number Diff line Loading @@ -214,8 +214,8 @@ class FilterProtocol: self.nr = set[bytes]() self.actions = set[bytes]([messages.Progress.ident]) self.state: tuple[messages.Message, set[bytes]]|None = None self._optflags: int = 0 self._actflags: int = 0 self._optflags = ProtocolFlags(0) self._actflags = ActionFlags(0) def read_from( self, Loading Loading @@ -307,15 +307,15 @@ class FilterProtocol: """ # ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty if message.macros: if not ActionFlags.SETSYMLIST & message.action_flags: if ActionFlags.SETSYMLIST not in message.action_flags: message.action_flags |= ActionFlags.SETSYMLIST warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4) if not ActionFlags.SETSYMLIST & self._actflags: if ActionFlags.SETSYMLIST not in self._actflags: raise ValueError("requesting symbols (macros) is not offered by the MTA") if (flags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}") if (flags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags) if (pflags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {pflags!r}") if (aflags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {aflags!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag in message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag in message.action_flags) kilter/protocol/messages.py +9 −6 Original line number Diff line number Diff line Loading @@ -143,6 +143,8 @@ class ActionFlags(BitField): https://pythonhosted.org/pymilter/milter_api/smfi_register.html#flags """ NONE = 0x0 ADD_HEADERS = ADDHDRS = 0x1 CHANGE_HEADERS = CHGHDRS = 0x10 CHANGE_BODY = CHGBODY = 0x2 Loading @@ -162,6 +164,8 @@ class ProtocolFlags(BitField): https://pythonhosted.org/pymilter/milter_api/xxfi_negotiate.html """ NONE = 0x0 NO_CONNECT = 0x1 NO_HELO = 0x2 NO_SENDER = NO_MAIL = 0x4 Loading Loading @@ -380,9 +384,8 @@ class Negotiate(Message, ident=b"O"): version: int # TODO: use set[Enum]? action_flags: int protocol_flags: int action_flags: ActionFlags protocol_flags: ProtocolFlags macros: Mapping[Stage, Collection[str]] = field(default_factory=dict) Loading @@ -390,14 +393,14 @@ class Negotiate(Message, ident=b"O"): @classmethod def from_buffer(cls, buf: memoryview) -> Self: opts = cast(tuple[int, int, int], cls._struct.unpack_from(buf)) version, actions, options = cast(tuple[int, int, int], cls._struct.unpack_from(buf)) buf = buf[cls._struct.size:] macros = dict() while len(buf) > 0: stage, *_ = LONG.unpack_from(buf) names, buf = split_cstring(buf[LONG.size:]) macros[Stage(stage)] = names.tobytes().decode("utf-8").split() return cls(*opts, macros) macros[Stage(stage)] = str(names, "utf-8").split() return cls(version, ActionFlags(actions), ProtocolFlags(options), macros) def to_buffer(self, buf: FixedSizeBuffer) -> None: self._struct.pack_into( Loading tests/example_filter.py +1 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ async def process_client(client: trio.SocketStream, nursery: trio.Nursery) -> No logging.info(f"RECEIVED {message!r}") match message: case messages.Negotiate(): message.protocol_flags = 0 message.protocol_flags = messages.ProtocolFlags.NONE await send_channel.send(message) case messages.Macro() | messages.Abort() | messages.Close(): continue Loading tests/test_core_filter_protocol.py +36 −33 Original line number Diff line number Diff line Loading @@ -17,6 +17,9 @@ from kilter.protocol.exceptions import UnknownMessage from kilter.protocol.messages import NoDataMessage from kilter.protocol.messages import * ALL_ACTION_FLAGS = ActionFlags(0x1ff) ALL_PROTOCOL_FLAGS = ProtocolFlags(0xfffff) class FilterProtocolTests(unittest.TestCase): """ Loading @@ -29,7 +32,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Macro(b"\x00", dict(spam="ham")).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Loading @@ -40,7 +43,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x01, 0x11f), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags(0x13e)), ) case Connect(): protocol.write_to( Loading Loading @@ -79,7 +82,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Macro(b"\x00", dict(spam="ham")).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Loading @@ -105,7 +108,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Loading @@ -116,7 +119,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT), ) case Connect(): pass Loading @@ -127,7 +130,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -137,7 +140,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0x00), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE), ) case Connect(): protocol.write_to( Loading @@ -159,7 +162,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -169,7 +172,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT), ) case Connect(): with self.assertRaises(UnexpectedMessage): Loading @@ -187,7 +190,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Loading @@ -198,7 +201,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case Connect(): with self.assertRaises(UnexpectedMessage): Loading @@ -216,7 +219,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Data().pack(buf) Loading @@ -230,7 +233,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0x00), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -253,7 +256,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Data().pack(buf) Loading @@ -267,7 +270,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case EndOfMessage(): protocol.write_to( Loading @@ -290,7 +293,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -300,7 +303,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case Connect(): with self.assertRaises(InvalidMessage): Loading @@ -318,7 +321,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -326,7 +329,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K), Negotiate(6, ActionFlags.NONE, ProtocolFlags.MAX_DATA_SIZE_256K), ) def test_disallowed_actions(self) -> None: Loading @@ -335,7 +338,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.CHANGE_BODY, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.CHANGE_BODY, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -343,7 +346,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00), Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, ProtocolFlags.NONE), ) def test_unrequested_action(self) -> None: Loading @@ -352,7 +355,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -363,7 +366,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0xfffff), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -378,7 +381,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -389,7 +392,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x1ff, 0xfffff), Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with catch_warnings(record=True) as warn_cm: Loading @@ -405,7 +408,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -416,7 +419,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x1ff, 0xfffff), Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -431,7 +434,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -439,7 +442,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertWarns(UserWarning): protocol.write_to( SimpleBuffer(40), Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_setsymlist_disallowed(self) -> None: Loading @@ -448,7 +451,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -456,7 +459,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(40), Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_setsymlist_implicit_disallowed(self) -> None: Loading @@ -465,7 +468,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -473,5 +476,5 @@ class FilterProtocolTests(unittest.TestCase): with self.assertWarns(UserWarning), self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(40), Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) Loading
kilter/protocol/core.py +10 −10 Original line number Diff line number Diff line Loading @@ -214,8 +214,8 @@ class FilterProtocol: self.nr = set[bytes]() self.actions = set[bytes]([messages.Progress.ident]) self.state: tuple[messages.Message, set[bytes]]|None = None self._optflags: int = 0 self._actflags: int = 0 self._optflags = ProtocolFlags(0) self._actflags = ActionFlags(0) def read_from( self, Loading Loading @@ -307,15 +307,15 @@ class FilterProtocol: """ # ActionFlag.SETSYMLIST must be set if Negotiate.macros is not empty if message.macros: if not ActionFlags.SETSYMLIST & message.action_flags: if ActionFlags.SETSYMLIST not in message.action_flags: message.action_flags |= ActionFlags.SETSYMLIST warn(f"adding {ActionFlags.SETSYMLIST!r} to {message}", stacklevel=4) if not ActionFlags.SETSYMLIST & self._actflags: if ActionFlags.SETSYMLIST not in self._actflags: raise ValueError("requesting symbols (macros) is not offered by the MTA") if (flags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {ProtocolFlags(flags)!r}") if (flags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {ActionFlags(flags)!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag & message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag & message.action_flags) if (pflags := message.protocol_flags & ~self._optflags): raise ValueError(f"requested options not offered by the MTA: {pflags!r}") if (aflags := message.action_flags & ~self._actflags): raise ValueError(f"requested actions not offered by the MTA: {aflags!r}") self.nr.update(ident for ident, flag in NR_FLAG_MAP.items() if flag in message.protocol_flags) self.actions.update(ident for ident, flag in UPDATE_FLAG_MAP.items() if flag in message.action_flags)
kilter/protocol/messages.py +9 −6 Original line number Diff line number Diff line Loading @@ -143,6 +143,8 @@ class ActionFlags(BitField): https://pythonhosted.org/pymilter/milter_api/smfi_register.html#flags """ NONE = 0x0 ADD_HEADERS = ADDHDRS = 0x1 CHANGE_HEADERS = CHGHDRS = 0x10 CHANGE_BODY = CHGBODY = 0x2 Loading @@ -162,6 +164,8 @@ class ProtocolFlags(BitField): https://pythonhosted.org/pymilter/milter_api/xxfi_negotiate.html """ NONE = 0x0 NO_CONNECT = 0x1 NO_HELO = 0x2 NO_SENDER = NO_MAIL = 0x4 Loading Loading @@ -380,9 +384,8 @@ class Negotiate(Message, ident=b"O"): version: int # TODO: use set[Enum]? action_flags: int protocol_flags: int action_flags: ActionFlags protocol_flags: ProtocolFlags macros: Mapping[Stage, Collection[str]] = field(default_factory=dict) Loading @@ -390,14 +393,14 @@ class Negotiate(Message, ident=b"O"): @classmethod def from_buffer(cls, buf: memoryview) -> Self: opts = cast(tuple[int, int, int], cls._struct.unpack_from(buf)) version, actions, options = cast(tuple[int, int, int], cls._struct.unpack_from(buf)) buf = buf[cls._struct.size:] macros = dict() while len(buf) > 0: stage, *_ = LONG.unpack_from(buf) names, buf = split_cstring(buf[LONG.size:]) macros[Stage(stage)] = names.tobytes().decode("utf-8").split() return cls(*opts, macros) macros[Stage(stage)] = str(names, "utf-8").split() return cls(version, ActionFlags(actions), ProtocolFlags(options), macros) def to_buffer(self, buf: FixedSizeBuffer) -> None: self._struct.pack_into( Loading
tests/example_filter.py +1 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ async def process_client(client: trio.SocketStream, nursery: trio.Nursery) -> No logging.info(f"RECEIVED {message!r}") match message: case messages.Negotiate(): message.protocol_flags = 0 message.protocol_flags = messages.ProtocolFlags.NONE await send_channel.send(message) case messages.Macro() | messages.Abort() | messages.Close(): continue Loading
tests/test_core_filter_protocol.py +36 −33 Original line number Diff line number Diff line Loading @@ -17,6 +17,9 @@ from kilter.protocol.exceptions import UnknownMessage from kilter.protocol.messages import NoDataMessage from kilter.protocol.messages import * ALL_ACTION_FLAGS = ActionFlags(0x1ff) ALL_PROTOCOL_FLAGS = ProtocolFlags(0xfffff) class FilterProtocolTests(unittest.TestCase): """ Loading @@ -29,7 +32,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Macro(b"\x00", dict(spam="ham")).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Loading @@ -40,7 +43,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x01, 0x11f), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags(0x13e)), ) case Connect(): protocol.write_to( Loading Loading @@ -79,7 +82,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Macro(b"\x00", dict(spam="ham")).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Loading @@ -105,7 +108,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Loading @@ -116,7 +119,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT), ) case Connect(): pass Loading @@ -127,7 +130,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -137,7 +140,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0x00), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE), ) case Connect(): protocol.write_to( Loading @@ -159,7 +162,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -169,7 +172,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.NR_CONNECT), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NR_CONNECT), ) case Connect(): with self.assertRaises(UnexpectedMessage): Loading @@ -187,7 +190,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Loading @@ -198,7 +201,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case Connect(): with self.assertRaises(UnexpectedMessage): Loading @@ -216,7 +219,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Data().pack(buf) Loading @@ -230,7 +233,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0x00), Negotiate(6, ActionFlags.NONE, ProtocolFlags.NONE), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -253,7 +256,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) Helo("example.com").pack(buf) Data().pack(buf) Loading @@ -267,7 +270,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case EndOfMessage(): protocol.write_to( Loading @@ -290,7 +293,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(100) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) protocol = FilterProtocol() Loading @@ -300,7 +303,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.ADD_HEADERS, 0x00), Negotiate(6, ActionFlags.ADD_HEADERS, ProtocolFlags.NONE), ) case Connect(): with self.assertRaises(InvalidMessage): Loading @@ -318,7 +321,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -326,7 +329,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K), Negotiate(6, ActionFlags.NONE, ProtocolFlags.MAX_DATA_SIZE_256K), ) def test_disallowed_actions(self) -> None: Loading @@ -335,7 +338,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.CHANGE_BODY, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.CHANGE_BODY, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -343,7 +346,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(20), Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00), Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, ProtocolFlags.NONE), ) def test_unrequested_action(self) -> None: Loading @@ -352,7 +355,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -363,7 +366,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x00, 0xfffff), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -378,7 +381,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -389,7 +392,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x1ff, 0xfffff), Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with catch_warnings(record=True) as warn_cm: Loading @@ -405,7 +408,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(60) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf) EndOfMessage(b"").pack(buf) Loading @@ -416,7 +419,7 @@ class FilterProtocolTests(unittest.TestCase): case Negotiate(): protocol.write_to( SimpleBuffer(20), Negotiate(6, 0x1ff, 0xfffff), Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS), ) case EndOfMessage(): with self.assertRaises(UnexpectedMessage): Loading @@ -431,7 +434,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -439,7 +442,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertWarns(UserWarning): protocol.write_to( SimpleBuffer(40), Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_setsymlist_disallowed(self) -> None: Loading @@ -448,7 +451,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -456,7 +459,7 @@ class FilterProtocolTests(unittest.TestCase): with self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(40), Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), ) def test_setsymlist_implicit_disallowed(self) -> None: Loading @@ -465,7 +468,7 @@ class FilterProtocolTests(unittest.TestCase): """ # Prepare input messages buf = SimpleBuffer(20) Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf) Negotiate(6, ALL_ACTION_FLAGS & ~ActionFlags.SETSYMLIST, ALL_PROTOCOL_FLAGS).pack(buf) protocol = FilterProtocol() next(protocol.read_from(buf)) # Prime the state machine Loading @@ -473,5 +476,5 @@ class FilterProtocolTests(unittest.TestCase): with self.assertWarns(UserWarning), self.assertRaises(ValueError): protocol.write_to( SimpleBuffer(40), Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}), Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}), )