Verified Commit 5dc45110 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add tests for needs_response() and sending Skip

parent f4d879a6
Loading
Loading
Loading
Loading
Loading
+90 −0
Original line number Diff line number Diff line
@@ -491,3 +491,93 @@ class FilterProtocolTests(unittest.TestCase):
				SimpleBuffer(40),
				Negotiate(6, ActionFlags.NONE, ALL_PROTOCOL_FLAGS, {Stage.CONNECT: {"spam"}}),
			)

	def test_negotiated(self) -> None:
		"""
		Check that attributes expose the correct negotiated information
		"""
		protocol = negotiated_protocol(
			ActionFlags.ADD_HEADERS|ActionFlags.CHANGE_HEADERS,
			ProtocolFlags.SKIP|ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO,
		)

		assert protocol.skip == True
		assert AddHeader.ident in protocol.actions
		assert ChangeHeader.ident in protocol.actions
		assert InsertHeader.ident in protocol.actions
		self.assertSetEqual(protocol.nr, {Connect.ident, Helo.ident})

	def test_negotiated_none(self) -> None:
		"""
		Check that attributes expose the correct negotiated information
		"""
		protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE)

		assert protocol.skip == False
		assert len(protocol.nr) == 0

	def test_needs_response(self) -> None:
		"""
		Check that needs_response returns True for selected messages
		"""
		protocol = negotiated_protocol(
			ActionFlags.NONE,
			ProtocolFlags.NR_CONNECT|ProtocolFlags.NR_HELO,
		)

		assert protocol.needs_response(Negotiate(0, ActionFlags.NONE, ProtocolFlags.NONE))
		assert not protocol.needs_response(Connect("example.com"))
		assert not protocol.needs_response(Helo("example.com"))
		assert not protocol.needs_response(Macro(0, {}))
		assert not protocol.needs_response(Abort())
		assert not protocol.needs_response(Close())
		assert protocol.needs_response(EnvelopeRecipient(b"spam@example.com"))
		assert protocol.needs_response(Data())

	def test_skip(self) -> None:
		"""
		Check that sending Skip when negotiated is allowed
		"""
		protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP)

		buf = SimpleBuffer(20)
		Body(b"spam").pack(buf)
		next(protocol.read_from(buf))

		protocol.write_to(SimpleBuffer(20), Skip())

	def test_skip_wrong_state(self) -> None:
		"""
		Check that sending Skip before one would be valid raises UnexpectedMessage
		"""
		protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.SKIP)

		with self.assertRaises(UnexpectedMessage):
			protocol.write_to(SimpleBuffer(20), Skip())

	def test_skip_not_negotiated(self) -> None:
		"""
		Check that sending Skip when not negotiated raises UnexpectedMessage
		"""
		protocol = negotiated_protocol(ActionFlags.NONE, ProtocolFlags.NONE)

		buf = SimpleBuffer(20)
		Body(b"spam").pack(buf)
		next(protocol.read_from(buf))

		with self.assertRaises(UnexpectedMessage):
			protocol.write_to(SimpleBuffer(20), Skip())


def negotiated_protocol(actions: ActionFlags, options: ProtocolFlags) -> FilterProtocol:
	"""
	Return a post-negotiation FilterProtocol which has negotiated for the given features
	"""
	buf = SimpleBuffer(20)
	Negotiate(6, ActionFlags.ALL, ALL_PROTOCOL_FLAGS).pack(buf)

	protocol = FilterProtocol()
	next(protocol.read_from(buf))  # Prime the state machine
	protocol.write_to(SimpleBuffer(20), Negotiate(6, actions, options))

	return protocol