Commit 3aaaf369 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Solve the run-after-abort problem

parent b03205b4
Loading
Loading
Loading
Loading
+11 −8
Original line number Diff line number Diff line
@@ -101,6 +101,7 @@ class Runner:
		proto = FilterProtocol()
		sender = _sender(client, proto)
		macro: Macro|None = None
		aborted = False

		await sender.asend(None)  # type: ignore # initialise

@@ -130,14 +131,20 @@ class Runner:
							# Connect messages.
							await runner.set_macros(macro)
						case Connect():
							await sender.asend(await self._connect(message, sender, runner, macro))
							await self._prepare_filters(message, sender, runner)
							if macro:
								await runner.set_macros(macro)
							await sender.asend(await runner.start(True, self.use_skip))
						case Abort():
							aborted = True
							await runner.abort(message)
							await runner.start(False, self.use_skip)
						case Close():
							await runner.aclose()
							return
						case _:
							if aborted:
								aborted = False
								await runner.start(False, self.use_skip)
							# TODO: Upgrade and remove ignores once python/mypy#14242 is in
							# TODO: Should remove assert once kilter.protocol#5 is resolved
							# Type narrowing should do the job adequately
@@ -161,20 +168,16 @@ class Runner:

		return resp

	async def _connect(
	async def _prepare_filters(
		self,
		message: Connect,
		sender: Sender,
		runner: _TaskRunner,
		macro: Macro|None,
	) -> ResponseMessage:
	) -> None:
		_logger.info(f"Client connected from {message.hostname}")
		for fltr in self.filters:
			session = Session(message, sender, _Broadcast())
			runner.add_filter(fltr, session)
		if macro:
			await runner.set_macros(macro)
		return await runner.start(True, self.use_skip)


class _TaskRunner:
+0 −3
Original line number Diff line number Diff line
import unittest

import trio.testing

from kilter.protocol import *
@@ -268,7 +266,6 @@ class RunnerTests(AsyncTestCase):
		assert aborted
		assert helo == "test.example.com"

	@unittest.expectedFailure
	async def test_abort_close(self) -> None:
		"""
		Check that a runner closes and does not restart when it receives an Abort + Close