Commit 2aca37a7 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Update abort tests for runner

parent fe4da122
Loading
Loading
Loading
Loading
Loading
+37 −7
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import trio.testing
from kilter.protocol import *
from kilter.service import Runner
from kilter.service import Session
from kilter.service.session import Aborted

from . import AsyncTestCase
from .mock_stream import MockMessageStream
@@ -238,15 +239,17 @@ class RunnerTests(AsyncTestCase):
		"""
		Check that a runner closes cleanly when it receives an Abort
		"""
		cancelled = False
		aborted = False
		helo = ""

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal cancelled
			nonlocal aborted
			nonlocal helo
			try:
				await session.helo()
			except trio.Cancelled:
				cancelled = True
				helo = await session.helo()
			except Aborted:
				aborted = True
				raise
			return Accept()

@@ -256,9 +259,36 @@ class RunnerTests(AsyncTestCase):

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			await stream_mock.abort()
			assert [] == await stream_mock.send_msg(Abort())

			await stream_mock.send_and_expect(Helo("test.example.com"), Accept)

		assert aborted
		assert helo == "test.example.com"

	async def test_abort_close(self) -> None:
		"""
		Check that a runner closes and does not restart when it receives an Abort + Close
		"""
		runs = 0

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal runs
			runs += 1
			await session.helo()
			return Accept()

		async with trio.open_nursery() as tg, MockMessageStream() as stream_mock:
			tg.start_soon(test_filter, stream_mock.peer_stream)
			await trio.testing.wait_all_tasks_blocked()

			await stream_mock.send_and_expect(Negotiate(6, 0x1ff, 0), Negotiate)
			await stream_mock.send_and_expect(Connect("test.example.com"), Continue)
			assert [] == await stream_mock.send_msg(Abort())
			assert [] == await stream_mock.send_msg(Close())

		assert cancelled
		assert runs == 1

	async def test_bad_response(self) -> None:
		"""