Commit d4654bc0 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add FilterProtocol tests for options and actions not allowed by MTA

parent 321616d6
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ Choosing messages to transmit to an MTA is left to a higher level, as is handlin

from __future__ import annotations

from typing import Iterable
from typing import Iterator
from typing import Sequence
from typing import TypeAlias
from typing import Union
@@ -206,7 +206,7 @@ class FilterProtocol:
	def read_from(
		self,
		buf: FixedSizeBuffer,
	) -> Iterable[MTAMessage]:
	) -> Iterator[MTAMessage]:
		"""
		Return an iterator yielding each complete message from a buffer

+173 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from __future__ import annotations

import unittest
from ipaddress import IPv4Address
from warnings import catch_warnings

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import FilterProtocol
@@ -313,3 +314,175 @@ class FilterProtocolTests(unittest.TestCase):
					break
		else:
			self.fail("Connect not read")

	@unittest.expectedFailure
	def test_disallowed_opts(self) -> None:
		"""
		Check that requesting protocol options not offered by the MTA results in ValueError
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff, ProtocolFlags.MAX_DATA_SIZE_1M).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine

		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(20),
				Negotiate(6, 0x00, ProtocolFlags.MAX_DATA_SIZE_256K),
			)

	@unittest.expectedFailure
	def test_disallowed_actions(self) -> None:
		"""
		Check that requesting protocol actions not offered by the MTA results in ValueError
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.CHANGE_BODY, 0xfffff).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine

		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(20),
				Negotiate(6, ActionFlags.CHANGE_BODY|ActionFlags.CHANGE_HEADERS, 0x00),
			)

	@unittest.expectedFailure
	def test_unrequested_action(self) -> None:
		"""
		Check that sending an action that was not requested raises UnexpectedMessage
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

		protocol = FilterProtocol()

		for msg in protocol.read_from(buf):
			match msg:
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x00, 0xfffff),
					)
				case EndOfMessage():
					with self.assertRaises(UnexpectedMessage):
						protocol.write_to(
							SimpleBuffer(20),
							ReplaceBody(b""),
						)

	@unittest.expectedFailure
	def test_action(self) -> None:
		"""
		Check that sending an allowed modification action raised no issues
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

		protocol = FilterProtocol()

		for msg in protocol.read_from(buf):
			match msg:
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x1ff, 0xfffff),
					)
				case EndOfMessage():
					with catch_warnings(record=True) as warn_cm:
						protocol.write_to(
							SimpleBuffer(20),
							ReplaceBody(b""),
						)
					assert len(warn_cm) == 0

	@unittest.expectedFailure
	def test_action_bad(self) -> None:
		"""
		Check that sending a disallowed message after EOM raises UnexpectedMessage
		"""
		# Prepare input messages
		buf = SimpleBuffer(60)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)
		Connect("example.com", IPv4Address("10.1.1.1"), 11111).pack(buf)
		EndOfMessage(b"").pack(buf)

		protocol = FilterProtocol()

		for msg in protocol.read_from(buf):
			match msg:
				case Negotiate():
					protocol.write_to(
						SimpleBuffer(20),
						Negotiate(6, 0x1ff, 0xfffff),
					)
				case EndOfMessage():
					with self.assertRaises(UnexpectedMessage):
						protocol.write_to(
							SimpleBuffer(20),
							Skip(),
						)

	@unittest.expectedFailure
	def test_setsymlist_implicit(self) -> None:
		"""
		Check that sending a mapping of symbol lists sets SETSYMLIST and issues a warning
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff, 0xfffff).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine

		with self.assertWarns(UserWarning):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}),
			)

	@unittest.expectedFailure
	def test_setsymlist_disallowed(self) -> None:
		"""
		Check that sending symbol lists when not offered by an MTA raises ValueError
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine

		with self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, ActionFlags.SETSYMLIST, 0xfffff, {Stage.CONNECT: {"spam"}}),
			)

	@unittest.expectedFailure
	def test_setsymlist_implicit_disallowed(self) -> None:
		"""
		Check that sending symbol lists when not offered by an MTA raises ValueError
		"""
		# Prepare input messages
		buf = SimpleBuffer(20)
		Negotiate(6, 0x1ff & ~ActionFlags.SETSYMLIST, 0xfffff).pack(buf)

		protocol = FilterProtocol()
		next(protocol.read_from(buf))  # Prime the state machine

		with self.assertWarns(UserWarning), self.assertRaises(ValueError):
			protocol.write_to(
				SimpleBuffer(40),
				Negotiate(6, 0x00, 0xfffff, {Stage.CONNECT: {"spam"}}),
			)