Commit 968d05eb authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Refactor ._anyio after enabling follow imports in mypy

parent 768a7245
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -97,7 +97,6 @@ strict = true
warn_unused_configs = True
warn_unreachable = true
implicit_reexport = true
follow_imports = skip

[flake8]
max-line-length = 92
+31 −24
Original line number Diff line number Diff line
@@ -16,42 +16,50 @@
Work-arounds for lack of AF_UNIX datagram socket support in Anyio
"""

import abc
from __future__ import annotations

import errno
import os
import socket
import tempfile
from contextlib import suppress
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Dict
from typing import Protocol
from typing import cast

import anyio.abc

try:
	from anyio import _get_asynclib
except ImportError:
	from anyio._core._eventloop import get_asynclib as _get_asynclib
import sniffio

from .types import PathLike

ConnectorFn = Callable[[PathLike, PathLike], Coroutine[Any, Any, 'DatagramSocket']]

connectors: Dict[str, ConnectorFn] = {}

class ConnectedUNIXAbstract(abc.ABC):

	_raw_socket: socket.socket
class DatagramSocket(Protocol):

	@abc.abstractmethod
	async def aclose(self) -> None:
		pass
	@property
	def _raw_socket(self) -> socket.socket: ...

	async def aclose(self) -> None: ...

class ConnectedUNIXMixin(ConnectedUNIXAbstract):
	async def receive(self) -> bytes: ...

	async def aclose(self) -> None:
	async def send(self, item: bytes) -> None: ...


class ConnectedUNIXMixin:

	async def aclose(self: DatagramSocket) -> None:
		path = self._raw_socket.getsockname()
		await super().aclose()
		await super().aclose()  # type: ignore  # Mypy doesn't handle super() well in mixins
		os.unlink(path)


async def connect_unix_datagram(path: PathLike) -> anyio.abc.SocketStream:
async def connect_unix_datagram(path: PathLike) -> DatagramSocket:
	"""
	Return an AnyIO socket connected to a Unix datagram socket

@@ -60,10 +68,9 @@ async def connect_unix_datagram(path: PathLike) -> anyio.abc.SocketStream:
	for _ in range(10):
		fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
		with suppress(FileExistsError):
			return await _get_asynclib().connect_unix_datagram(
				local_path=fname,
				remote_path=path,
			)
			async_lib = sniffio.current_async_library()
			connector = connectors[async_lib]
			return await connector(fname, path)
	raise FileExistsError(
		errno.EEXIST, "No usable temporary filename found",
	)
@@ -75,7 +82,7 @@ except ImportError: ...
else:
	from anyio._backends import _trio

	class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):  # type: ignore
	class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):
		...

	async def trio_connect_unix_datagram(
@@ -92,7 +99,7 @@ else:
		else:
			return TrioConnectedUNIXSocket(sock)

	_trio.connect_unix_datagram = trio_connect_unix_datagram
	connectors['trio'] = trio_connect_unix_datagram


# asyncio is in the stdlib, but lets make the layout match trio 😉
@@ -102,7 +109,7 @@ except ImportError: ...
else:
	from anyio._backends import _asyncio

	class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):  # type: ignore
	class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):
		...

	async def asyncio_connect_unix_datagram(
@@ -139,4 +146,4 @@ else:
			raise protocol.exception
		return AsyncioConnectedUNIXSocket(transport, protocol)

	_asyncio.connect_unix_datagram = asyncio_connect_unix_datagram
	connectors['asyncio'] = asyncio_connect_unix_datagram
+2 −2
Original line number Diff line number Diff line
@@ -36,9 +36,9 @@ from typing import Union
from typing import overload

import anyio
from anyio.abc import SocketStream

from .. import errors
from .._anyio import DatagramSocket
from .._anyio import connect_unix_datagram
from ..types import PathLike
from . import consts
@@ -98,7 +98,7 @@ class BaseClient:
	def __init__(self, *, logger: Optional[logging.Logger] = None):
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock: Optional[SocketStream] = None
		self.sock: Optional[DatagramSocket] = None
		self._lock = anyio.Lock()
		self._condition = anyio.Condition()
		self._handler_active = False