Commit dd9e9817 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add Broadcast synchronisation class

parent fee60e05
Loading
Loading
Loading
Loading

kilter/service/util.py

0 → 100644
+57 −0
Original line number Diff line number Diff line
# Copyright 2022 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""
Common helper utilities
"""

from __future__ import annotations

from typing import Generic
from typing import Optional
from typing import TypeVar

import anyio

T = TypeVar("T")


class Broadcast(anyio.Condition, Generic[T]):
	"""
	A reliable, blocking message queue for delivering to multiple listening tasks

	Listeners must acquire the lock (by using the `Broadcast` instance as a context manager)
	before calling `Broadcast.receive()` or it will fail.  If a listener is repeatedly
	awaiting messages in a loop, the loop should be inside the locked context or messages
	may be lost to race conditions.
	"""

	def __init__(self) -> None:
		super().__init__()
		self.obj: Optional[T] = None

	async def send(self, obj: T) -> None:
		"""
		Send a message object and block until all listeners have received it
		"""
		async with self:
			self.obj = obj
			self.notify_all()
		await anyio.sleep(0.0)  # ensure listeners have opportunity to wait for locks
		while 1:
			async with self:
				if self.statistics().lock_statistics.tasks_waiting:
					continue
				self.obj = None
				break

	async def receive(self) -> T:
		"""
		Listen for a single message and return it once it arrives
		"""
		await self.wait()
		assert self.obj is not None
		return self.obj
+1 −0
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ classifiers = [
tests = [
	"coverage[toml]",
	"kodo.plugins.cover_test_context @ https://code.kodo.org.uk/dom/cover-plugin-test-context/-/archive/main/cover-plugin-test-context-main.zip",
	"trio",
]

[project.urls]

tests/__init__.py

0 → 100644
+43 −0
Original line number Diff line number Diff line
"""
A package of tests for kilter.service modules
"""

import functools
import inspect
from collections.abc import Callable
from collections.abc import Coroutine
from inspect import iscoroutinefunction
from types import CoroutineType
from types import FunctionType
from typing import Any
from unittest import TestCase

import trio

SyncTest = Callable[[TestCase], None]
AsyncTest = Callable[[TestCase], Coroutine[Any, Any, None]]


class AsyncTestCase(TestCase):
	"""
	A variation of `unittest.TestCase` with support for awaitable (async) test functions
	"""

	@classmethod
	def __init_subclass__(cls, time_limit: float = 1.0, **kwargs: Any) -> None:
		super().__init_subclass__(**kwargs)
		for name, value in ((n, getattr(cls, n)) for n in dir(cls)):
			if name.startswith("test_") and iscoroutinefunction(value):
				setattr(cls, name, _syncwrap(value, time_limit))


def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest:
	@functools.wraps(test)
	def wrap(self: TestCase) -> None:
		async def limiter() -> None:
			with trio.move_on_after(time_limit) as cancel_scope:
				await test(self)
			if cancel_scope.cancelled_caught:
				raise TimeoutError
		trio.run(limiter)
	return wrap
+87 −0
Original line number Diff line number Diff line
import trio.testing

from kilter.service.util import Broadcast

from . import AsyncTestCase


class BroadcastTests(AsyncTestCase):
	"""
	Tests for the kilter.service.sync.Broadcast class
	"""

	async def test_send_no_listeners(self) -> None:
		"""
		Check that sending a message with no listeners does not block
		"""
		broadcast = Broadcast[int]()

		with trio.move_on_after(2.0) as cancel_scope:
			await broadcast.send(1)

		assert not cancel_scope.cancelled_caught

	async def test_send_one_listener(self) -> None:
		"""
		Check that sending a message to a single listener works
		"""
		broadcast = Broadcast[int]()
		messages = list[int]()

		async def listener() -> None:
			async with broadcast:
				messages.append(await broadcast.receive())

		async with trio.open_nursery() as task_group:
			task_group.start_soon(listener)
			await trio.testing.wait_all_tasks_blocked()

			await broadcast.send(1)
			await broadcast.send(2)

		assert messages == [1]

	async def test_send_multiple_listeners(self) -> None:
		"""
		Check that sending a message to multiple listeners works
		"""
		broadcast = Broadcast[int]()
		messages = list[int]()

		async def listener() -> None:
			async with broadcast:
				messages.append(await broadcast.receive())

		async with trio.open_nursery() as task_group:
			for _ in range(4):
				task_group.start_soon(listener)
			await trio.testing.wait_all_tasks_blocked()

			await broadcast.send(1)
			await broadcast.send(2)

		assert messages == [1, 1, 1, 1]

	async def test_recieve_loop(self) -> None:
		"""
		Check that receiving multiple messages in a loop works
		"""
		broadcast = Broadcast[int]()
		messages = list[int|str]()

		async def listener() -> None:
			async with broadcast:
				msg = 0
				while msg < 4:
					msg = await broadcast.receive()
					messages.append(msg)

		async with trio.open_nursery() as task_group:
			task_group.start_soon(listener)
			task_group.start_soon(listener)
			await trio.testing.wait_all_tasks_blocked()

			for n in range(1, 10):  # Deliberately higher than the listeners go
				await broadcast.send(n)

		assert messages == [1, 1, 2, 2, 3, 3, 4, 4]