Verified Commit 54cf151d authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Make Mypy check tests/ and update tests for kilter.protocol 0.3.0

parent aaeba43e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -93,7 +93,7 @@ repos:
  rev: v1.4.1
  hooks:
  - id: mypy
    args: [--follow-imports=silent, kilter/service]
    args: [--follow-imports=silent, kilter/service, tests]
    pass_filenames: false
    additional_dependencies:
    - anyio ~=3.1
+25 −16
Original line number Diff line number Diff line
@@ -9,6 +9,15 @@ from . import AsyncTestCase
from .mock_stream import MockMessageStream


def make_negotiate(options: int = 0, actions: int = 0x1ff) -> Negotiate:
	"""
	Construct a Negotiate message from integer flags

	Defaults to all actions, and no protocol options.
	"""
	return Negotiate(6, ActionFlags(actions), ProtocolFlags(options))


class RunnerTests(AsyncTestCase):
	"""
	Tests for the Runner class
@@ -31,8 +40,8 @@ class RunnerTests(AsyncTestCase):
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0xff3ff),
				Negotiate(6, 0x1ff, 0x00000),
				Negotiate(6, ActionFlags(0x1ff), ProtocolFlags(0xff3ff)),
				Negotiate(6, ActionFlags(0x1ff), ProtocolFlags.NONE),
			)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

@@ -53,7 +62,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Reject)

	async def test_post_header(self) -> None:
@@ -69,7 +78,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Helo("test.example.com"), Continue)
			await stream_mock.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept)
@@ -93,7 +102,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Continue)
@@ -127,8 +136,8 @@ class RunnerTests(AsyncTestCase):
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0xff3ff | ProtocolFlags.SKIP),
				Negotiate(6, 0x1ff, 0x00000 | ProtocolFlags.SKIP),
				make_negotiate(options=0xff3ff | ProtocolFlags.SKIP),
				Negotiate(6, ActionFlags(0x1ff), ProtocolFlags.SKIP),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
@@ -167,7 +176,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Body(b"This is a "), Continue)
			await stream_mock.send_and_expect(Body(b"message sent "), Continue)
@@ -219,8 +228,8 @@ class RunnerTests(AsyncTestCase):
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(
				Negotiate(6, 0x1ff, 0xff3ff | ProtocolFlags.SKIP),
				Negotiate(6, 0x1ff, 0x00000 | ProtocolFlags.SKIP),
				make_negotiate(options=0xff3ff | ProtocolFlags.SKIP),
				Negotiate(6, ActionFlags(0x1ff), ProtocolFlags.SKIP),
			)

			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
@@ -259,7 +268,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.send_and_expect(Helo("test.example.com"), Continue)

@@ -288,7 +297,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			assert [] == await stream_mock.send_msg(Abort())
			assert [] == await stream_mock.send_msg(Close())
@@ -297,9 +306,9 @@ class RunnerTests(AsyncTestCase):

	async def test_bad_response(self) -> None:
		"""
		Check that a runner closes cleanly when it receives an Abort
		Check that when a filter returns an invalid response, it is converted to a failure
		"""
		@Runner
		@Runner  # type: ignore[arg-type]
		async def test_filter(session: Session) -> Skip:
			await session.helo()
			return Skip()
@@ -308,7 +317,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)

			with self.assertWarns(UserWarning) as wcm:
@@ -331,7 +340,7 @@ class RunnerTests(AsyncTestCase):
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(make_negotiate(), Negotiate)
			await stream_mock.send_and_expect(
				Macro(Connect.ident, {"{spam}": "yes", "{eggs}": "yes"}),
			)
+19 −19
Original line number Diff line number Diff line
@@ -25,38 +25,38 @@ class SessionTests(AsyncTestCase):
		Check that the phase progresses correctly when messages are delivered
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		assert session.phase == Phase.CONNECT
		self.assertEqual(session.phase, Phase.CONNECT)

		await session.deliver(Helo("example.com"))
		assert session.phase == Phase.MAIL
		self.assertEqual(session.phase, Phase.MAIL)

		await session.deliver(EnvelopeFrom(b"test@example.com"))
		assert session.phase == Phase.ENVELOPE
		self.assertEqual(session.phase, Phase.ENVELOPE)

		await session.deliver(Data())
		assert session.phase == Phase.HEADERS
		self.assertEqual(session.phase, Phase.HEADERS)

		await session.deliver(Body(b""))
		assert session.phase == Phase.BODY
		self.assertEqual(session.phase, Phase.BODY)

		await session.deliver(EndOfMessage(b""))
		assert session.phase == Phase.POST
		self.assertEqual(session.phase, Phase.POST)

	async def test_deliver_phases_2(self) -> None:
		"""
		Check that the phase progresses correctly when messages are delivered
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		assert session.phase == Phase.CONNECT
		self.assertEqual(session.phase, Phase.CONNECT)

		await session.deliver(EnvelopeRecipient(b"test@example.com", []))
		assert session.phase == Phase.ENVELOPE
		self.assertEqual(session.phase, Phase.ENVELOPE)

		await session.deliver(Header("To", b"test@example.com"))
		assert session.phase == Phase.HEADERS
		self.assertEqual(session.phase, Phase.HEADERS)

		await session.deliver(EndOfHeaders())
		assert session.phase == Phase.BODY
		self.assertEqual(session.phase, Phase.BODY)

	async def test_receive_ignore(self) -> None:
		"""
@@ -307,9 +307,9 @@ class SessionTests(AsyncTestCase):

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			self.assertEqual(session.phase, Phase.CONNECT)
			await session.change_sender("test@example.com")
			assert session.phase == Phase.POST
			self.assertEqual(session.phase, Phase.POST)
			await session.change_sender("test@example.com", "SPAM")

		async with trio.open_nursery() as tg:
@@ -331,9 +331,9 @@ class SessionTests(AsyncTestCase):

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			self.assertEqual(session.phase, Phase.CONNECT)
			await session.add_recipient("test@example.com")
			assert session.phase == Phase.POST
			self.assertEqual(session.phase, Phase.POST)
			await session.add_recipient("test@example.com", "SPAM")

		async with trio.open_nursery() as tg:
@@ -355,9 +355,9 @@ class SessionTests(AsyncTestCase):

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			self.assertEqual(session.phase, Phase.CONNECT)
			await session.remove_recipient("test@example.com")
			assert session.phase == Phase.POST
			self.assertEqual(session.phase, Phase.POST)

		async with trio.open_nursery() as tg:
			tg.start_soon(test_filter)
@@ -421,12 +421,12 @@ class SessionTests(AsyncTestCase):

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			self.assertEqual(session.phase, Phase.CONNECT)
			assert await session.helo() == "test.example.org"
			assert session.phase == Phase.MAIL
			self.assertEqual(session.phase, Phase.MAIL)
			with self.assertRaises(Aborted):
				await session.extension("MAIL")
			assert session.phase == Phase.CONNECT
			self.assertEqual(session.phase, Phase.CONNECT)
			assert await session.helo() == "test.example.com"

		async with trio.open_nursery() as tg: