Verified Commit 46aae465 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Upgrade anyio to 4.x

parent 062206b9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -86,10 +86,10 @@ repos:
  rev: v1.15.0
  hooks:
  - id: mypy
    args: [kilter/service, tests]
    args: [kilter/service, tests, --python-version=3.11]
    pass_filenames: false
    additional_dependencies:
    - anyio ~=3.1
    - anyio ~=4.0
    - kilter.protocol ~=0.6.0
    - sphinx
    - trio-typing
+2 −2
Original line number Diff line number Diff line
@@ -405,6 +405,6 @@ class _TaskRunner:


def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	rsend, lrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	lsend, rrecv = anyio.create_memory_object_stream[Message](1)
	rsend, lrecv = anyio.create_memory_object_stream[Message](1)
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)
+3 −3
Original line number Diff line number Diff line
@@ -11,9 +11,9 @@ license = {file = "LICENCE.txt"}
readme = "README.md"
dynamic = ["version", "description"]

requires-python = "~=3.10"
requires-python = "~=3.11"
dependencies = [
	"anyio ~=3.0",
	"anyio ~=4.0",
	"async-generator ~=1.2",
	"kilter.protocol ~=0.6.0",
	"typing-extensions ~=4.0",
@@ -27,7 +27,7 @@ classifiers = [

[project.optional-dependencies]
tests = [
	"trio <0.22",  # Until anyio supports BaseExceptionGroup
	"trio",
]
docs = [
	"sphinx ~=5.0",
+47 −0
Original line number Diff line number Diff line
@@ -2,22 +2,39 @@
A package of tests for kilter.service modules
"""

from __future__ import annotations

import functools
import os
from collections.abc import Callable
from collections.abc import Coroutine
from collections.abc import Iterator
from contextlib import contextmanager
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Any
from typing import Protocol
from typing import Self
from typing import TypeVar
from unittest import TestCase

import trio

E = TypeVar("E", bound=BaseException)

SyncTest = Callable[[TestCase], None]
AsyncTest = Callable[[TestCase], Coroutine[Any, Any, None]]

LIMIT_SCALE_FACTOR = float(os.environ.get("LIMIT_SCALE_FACTOR", 1))


if TYPE_CHECKING:
	class AssertRaisesContext(Protocol[E]):  # noqa: D101
		exception: E
		expected: type[BaseException] | tuple[type[BaseException], ...]
		msg: str|None


class AsyncTestCase(TestCase):
	"""
	A variation of `unittest.TestCase` with support for awaitable (async) test functions
@@ -30,6 +47,28 @@ class AsyncTestCase(TestCase):
			if name.startswith("test_") and iscoroutinefunction(value):
				setattr(cls, name, _syncwrap(value, time_limit * LIMIT_SCALE_FACTOR))

	@contextmanager
	def assertRaises(  # type: ignore[override]
		self,
		expected_exception: type[E]|tuple[type[E], ...],
		*,
		msg: str|None = None,
	) -> Iterator[AssertRaisesContext[E]]:
		"""
		Return a context manager that asserts a given exception is raised with the context

		Extends the base assertRaises with support for ExceptionGroups.  If at most one leaf
		exception is raised in the group and it matches the expected type, it will be
		treated as a successful failure.
		"""
		with super().assertRaises(expected_exception, msg=msg) as context:
			try:
				yield context
			except* expected_exception as grp:
				exc = [*_leaf_exc(grp)]
				assert len(exc) == 1
				raise exc[0] from grp


def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest:
	@functools.wraps(test)
@@ -41,3 +80,11 @@ def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest:
				raise TimeoutError
		trio.run(limiter)
	return wrap


def _leaf_exc(group: BaseExceptionGroup) -> Iterator[BaseException]:
	for exc in group.exceptions:
		if isinstance(exc, BaseExceptionGroup):
			yield from _leaf_exc(exc)
		else:
			yield exc
+2 −2
Original line number Diff line number Diff line
@@ -50,8 +50,8 @@ class MockMessageStream:
		self.closed = False

	async def __aenter__(self) -> Self:
		send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes)
		send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes)
		send_obj, recv_bytes = anyio.create_memory_object_stream[bytes](5)
		send_bytes, recv_obj = anyio.create_memory_object_stream[bytes](5)

		self._stream = StapledObjectStream(send_obj, recv_obj)
		self.peer_stream = StapledByteStream(