Commit 027f3dc6 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Use type-extensions.Self instead of TypeVars

Closes #11
parent a665b5db
Loading
Loading
Loading
Loading
+2 −6
Original line number Diff line number Diff line
@@ -16,15 +16,14 @@ from __future__ import annotations

import logging
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from typing import Final
from typing import TypeAlias
from typing import TypeVar
from warnings import warn

import anyio.abc
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing
from typing_extensions import Self

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import EditMessage
@@ -182,15 +181,12 @@ class Runner:

class _TaskRunner:

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="_TaskRunner")

	def __init__(self, tasks: anyio.abc.TaskGroup):
		self.tasks = tasks
		self.filters = list[tuple[Filter, Session]]()
		self.channels = list[MessageChannel]()

	async def __aenter__(self: Self) -> Self:
	async def __aenter__(self) -> Self:
		return self

	async def __aexit__(self, *_: object) -> None:
+4 −10
Original line number Diff line number Diff line
@@ -18,13 +18,13 @@ from ipaddress import IPv4Address
from ipaddress import IPv6Address
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING
from typing import AsyncContextManager
from typing import Literal
from typing import Protocol
from typing import TypeVar
from warnings import warn

from typing_extensions import Self

from ..protocol.core import EditMessage
from ..protocol.core import EventMessage
from ..protocol.core import ResponseMessage
@@ -147,9 +147,6 @@ class Session:
	The kernel of a filter, providing an API for filters to access messages from an MTA
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="Session")

	host: str
	"""
	A hostname from a reverse address lookup performed when a client connects
@@ -211,7 +208,7 @@ class Session:
		# so some phases will be skipped; checks should not try to exactly match a phase.
		self.phase = Phase.CONNECT

	async def __aenter__(self: Self) -> Self:
	async def __aenter__(self) -> Self:
		await self.broadcast.__aenter__()
		return self

@@ -452,13 +449,10 @@ class HeaderIterator(AsyncGenerator[Header, None]):
	Iterator for headers obtained by using a `HeadersAccessor` as a context manager
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="HeaderIterator")

	def __init__(self, aiter: AsyncGenerator[Header, None]):
		self._aiter = aiter

	def __aiter__(self: Self) -> Self:
	def __aiter__(self) -> Self:
		return self

	async def __anext__(self) -> Header:  # noqa: D102
+1 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ dependencies = [
	"anyio ~=3.0",
	"async-generator ~=1.2",
	"kilter.protocol ~=0.2.1",
	"typing-extensions ~=4.0",
]
classifiers = [
	"Development Status :: 1 - Planning",
+2 −6
Original line number Diff line number Diff line
@@ -7,15 +7,14 @@ from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
from types import TracebackType
from typing import TYPE_CHECKING
from typing import AsyncContextManager
from typing import TypeVar

import anyio
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.stapled import StapledByteStream
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing
from typing_extensions import Self

from kilter.protocol import *
from kilter.protocol.buffer import SimpleBuffer
@@ -45,14 +44,11 @@ class MockMessageStream:
	A mock of the right-side of an `anyio.abc.ByteStream` with test support on the left side
	"""

	if TYPE_CHECKING:
		Self = TypeVar("Self", bound="MockMessageStream")

	def __init__(self) -> None:
		self.buffer = SimpleBuffer(1024)
		self.closed = False

	async def __aenter__(self: Self) -> Self:
	async def __aenter__(self) -> Self:
		send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes)
		send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes)