Loading .pre-commit-config.yaml +2 −2 Original line number Diff line number Diff line Loading @@ -86,10 +86,10 @@ repos: rev: v1.15.0 hooks: - id: mypy args: [kilter/service, tests] args: [kilter/service, tests, --python-version=3.11] pass_filenames: false additional_dependencies: - anyio ~=3.1 - anyio ~=4.0 - kilter.protocol ~=0.6.0 - sphinx - trio-typing kilter/service/runner.py +2 −2 Original line number Diff line number Diff line Loading @@ -405,6 +405,6 @@ class _TaskRunner: def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore lsend, rrecv = anyio.create_memory_object_stream[Message](1) rsend, lrecv = anyio.create_memory_object_stream[Message](1) return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) pyproject.toml +3 −3 Original line number Diff line number Diff line Loading @@ -11,9 +11,9 @@ license = {file = "LICENCE.txt"} readme = "README.md" dynamic = ["version", "description"] requires-python = "~=3.10" requires-python = "~=3.11" dependencies = [ "anyio ~=3.0", "anyio ~=4.0", "async-generator ~=1.2", "kilter.protocol ~=0.6.0", "typing-extensions ~=4.0", Loading @@ -27,7 +27,7 @@ classifiers = [ [project.optional-dependencies] tests = [ "trio <0.22", # Until anyio supports BaseExceptionGroup "trio", ] docs = [ "sphinx ~=5.0", Loading tests/__init__.py +47 −0 Original line number Diff line number Diff line Loading @@ -2,22 +2,39 @@ A package of tests for kilter.service modules """ from __future__ import annotations import functools import os from collections.abc import Callable from collections.abc import Coroutine from collections.abc import Iterator from contextlib import contextmanager from inspect import iscoroutinefunction from typing import TYPE_CHECKING from typing import Any from typing import Protocol from typing import Self from typing import TypeVar from unittest import TestCase import trio E = TypeVar("E", bound=BaseException) SyncTest = Callable[[TestCase], None] AsyncTest = Callable[[TestCase], Coroutine[Any, Any, None]] LIMIT_SCALE_FACTOR = float(os.environ.get("LIMIT_SCALE_FACTOR", 1)) if TYPE_CHECKING: class AssertRaisesContext(Protocol[E]): # noqa: D101 exception: E expected: type[BaseException] | tuple[type[BaseException], ...] msg: str|None class AsyncTestCase(TestCase): """ A variation of `unittest.TestCase` with support for awaitable (async) test functions Loading @@ -30,6 +47,28 @@ class AsyncTestCase(TestCase): if name.startswith("test_") and iscoroutinefunction(value): setattr(cls, name, _syncwrap(value, time_limit * LIMIT_SCALE_FACTOR)) @contextmanager def assertRaises( # type: ignore[override] self, expected_exception: type[E]|tuple[type[E], ...], *, msg: str|None = None, ) -> Iterator[AssertRaisesContext[E]]: """ Return a context manager that asserts a given exception is raised with the context Extends the base assertRaises with support for ExceptionGroups. If at most one leaf exception is raised in the group and it matches the expected type, it will be treated as a successful failure. """ with super().assertRaises(expected_exception, msg=msg) as context: try: yield context except* expected_exception as grp: exc = [*_leaf_exc(grp)] assert len(exc) == 1 raise exc[0] from grp def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: @functools.wraps(test) Loading @@ -41,3 +80,11 @@ def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: raise TimeoutError trio.run(limiter) return wrap def _leaf_exc(group: BaseExceptionGroup) -> Iterator[BaseException]: for exc in group.exceptions: if isinstance(exc, BaseExceptionGroup): yield from _leaf_exc(exc) else: yield exc tests/mock_stream.py +2 −2 Original line number Diff line number Diff line Loading @@ -50,8 +50,8 @@ class MockMessageStream: self.closed = False async def __aenter__(self) -> Self: send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes) send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes) send_obj, recv_bytes = anyio.create_memory_object_stream[bytes](5) send_bytes, recv_obj = anyio.create_memory_object_stream[bytes](5) self._stream = StapledObjectStream(send_obj, recv_obj) self.peer_stream = StapledByteStream( Loading Loading
.pre-commit-config.yaml +2 −2 Original line number Diff line number Diff line Loading @@ -86,10 +86,10 @@ repos: rev: v1.15.0 hooks: - id: mypy args: [kilter/service, tests] args: [kilter/service, tests, --python-version=3.11] pass_filenames: false additional_dependencies: - anyio ~=3.1 - anyio ~=4.0 - kilter.protocol ~=0.6.0 - sphinx - trio-typing
kilter/service/runner.py +2 −2 Original line number Diff line number Diff line Loading @@ -405,6 +405,6 @@ class _TaskRunner: def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore lsend, rrecv = anyio.create_memory_object_stream[Message](1) rsend, lrecv = anyio.create_memory_object_stream[Message](1) return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)
pyproject.toml +3 −3 Original line number Diff line number Diff line Loading @@ -11,9 +11,9 @@ license = {file = "LICENCE.txt"} readme = "README.md" dynamic = ["version", "description"] requires-python = "~=3.10" requires-python = "~=3.11" dependencies = [ "anyio ~=3.0", "anyio ~=4.0", "async-generator ~=1.2", "kilter.protocol ~=0.6.0", "typing-extensions ~=4.0", Loading @@ -27,7 +27,7 @@ classifiers = [ [project.optional-dependencies] tests = [ "trio <0.22", # Until anyio supports BaseExceptionGroup "trio", ] docs = [ "sphinx ~=5.0", Loading
tests/__init__.py +47 −0 Original line number Diff line number Diff line Loading @@ -2,22 +2,39 @@ A package of tests for kilter.service modules """ from __future__ import annotations import functools import os from collections.abc import Callable from collections.abc import Coroutine from collections.abc import Iterator from contextlib import contextmanager from inspect import iscoroutinefunction from typing import TYPE_CHECKING from typing import Any from typing import Protocol from typing import Self from typing import TypeVar from unittest import TestCase import trio E = TypeVar("E", bound=BaseException) SyncTest = Callable[[TestCase], None] AsyncTest = Callable[[TestCase], Coroutine[Any, Any, None]] LIMIT_SCALE_FACTOR = float(os.environ.get("LIMIT_SCALE_FACTOR", 1)) if TYPE_CHECKING: class AssertRaisesContext(Protocol[E]): # noqa: D101 exception: E expected: type[BaseException] | tuple[type[BaseException], ...] msg: str|None class AsyncTestCase(TestCase): """ A variation of `unittest.TestCase` with support for awaitable (async) test functions Loading @@ -30,6 +47,28 @@ class AsyncTestCase(TestCase): if name.startswith("test_") and iscoroutinefunction(value): setattr(cls, name, _syncwrap(value, time_limit * LIMIT_SCALE_FACTOR)) @contextmanager def assertRaises( # type: ignore[override] self, expected_exception: type[E]|tuple[type[E], ...], *, msg: str|None = None, ) -> Iterator[AssertRaisesContext[E]]: """ Return a context manager that asserts a given exception is raised with the context Extends the base assertRaises with support for ExceptionGroups. If at most one leaf exception is raised in the group and it matches the expected type, it will be treated as a successful failure. """ with super().assertRaises(expected_exception, msg=msg) as context: try: yield context except* expected_exception as grp: exc = [*_leaf_exc(grp)] assert len(exc) == 1 raise exc[0] from grp def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: @functools.wraps(test) Loading @@ -41,3 +80,11 @@ def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: raise TimeoutError trio.run(limiter) return wrap def _leaf_exc(group: BaseExceptionGroup) -> Iterator[BaseException]: for exc in group.exceptions: if isinstance(exc, BaseExceptionGroup): yield from _leaf_exc(exc) else: yield exc
tests/mock_stream.py +2 −2 Original line number Diff line number Diff line Loading @@ -50,8 +50,8 @@ class MockMessageStream: self.closed = False async def __aenter__(self) -> Self: send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes) send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes) send_obj, recv_bytes = anyio.create_memory_object_stream[bytes](5) send_bytes, recv_obj = anyio.create_memory_object_stream[bytes](5) self._stream = StapledObjectStream(send_obj, recv_obj) self.peer_stream = StapledByteStream( Loading