Verified Commit f8113a6c authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Switch to using anyio's connected unix datagram type

This was added in 4.0 so we can use it now.
parent 0cba3203
Loading
Loading
Loading
Loading
+3 −104
Original line number Diff line number Diff line
@@ -19,40 +19,12 @@ Work-arounds for lack of AF_UNIX datagram socket support in Anyio
from __future__ import annotations

import errno
import os
import socket
import tempfile
from contextlib import suppress
from os import PathLike
from typing import Awaitable
from typing import Callable
from typing import Protocol

import sniffio

ConnectorFn = Callable[[str, str], Awaitable['DatagramSocket']]

connectors: dict[str, ConnectorFn] = {}


class DatagramSocket(Protocol):

	@property
	def _raw_socket(self) -> socket.socket: ...

	async def aclose(self) -> None: ...

	async def receive(self) -> bytes: ...

	async def send(self, item: bytes) -> None: ...


class ConnectedUNIXMixin:

	async def aclose(self: DatagramSocket) -> None:
		path = self._raw_socket.getsockname()
		await super().aclose()  # type: ignore  # Mypy doesn't handle super() well in mixins
		os.unlink(path)
from anyio import create_connected_unix_datagram_socket
from anyio.abc import ConnectedUNIXDatagramSocket as DatagramSocket


async def connect_unix_datagram(path: str | PathLike[str]) -> DatagramSocket:
@@ -64,80 +36,7 @@ async def connect_unix_datagram(path: str | PathLike[str]) -> DatagramSocket:
	for _ in range(10):
		fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
		with suppress(FileExistsError):
			async_lib = sniffio.current_async_library()
			connector = connectors[async_lib]
			return await connector(fname, os.fspath(path))
			return await create_connected_unix_datagram_socket(path, local_path=fname)
	raise FileExistsError(
		errno.EEXIST, "No usable temporary filename found",
	)


try:
	import trio
except ImportError: ...
else:
	from anyio._backends import _trio

	class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):
		...

	async def trio_connect_unix_datagram(
		local_path: str,
		remote_path: str,
	) -> TrioConnectedUNIXSocket:
		sock = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
		await sock.bind(local_path)
		try:
			await sock.connect(remote_path)
		except BaseException:  # pragma: no cover
			sock.close()
			raise
		else:
			return TrioConnectedUNIXSocket(sock)

	connectors['trio'] = trio_connect_unix_datagram


# asyncio is in the stdlib, but lets make the layout match trio 😉
try:
	import asyncio
except ImportError: ...
else:
	from anyio._backends import _asyncio

	class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):
		...

	async def asyncio_connect_unix_datagram(
		local_path: str,
		remote_path: str,
	) -> AsyncioConnectedUNIXSocket:
		await asyncio.sleep(0.0)
		loop = asyncio.get_running_loop()
		sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
		sock.setblocking(False)
		sock.bind(local_path)
		while True:
			try:
				sock.connect(remote_path)
			except BlockingIOError:
				future: asyncio.Future[None] = asyncio.Future()
				loop.add_writer(sock, future.set_result, None)
				future.add_done_callback(lambda _: loop.remove_writer(sock))
				await future
			except BaseException:
				sock.close()
				raise
			else:
				break

		transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
			_asyncio.DatagramProtocol,
			sock=sock,
		)
		if protocol.exception:
			transport.close()
			raise protocol.exception
		return AsyncioConnectedUNIXSocket(transport, protocol)

	connectors['asyncio'] = asyncio_connect_unix_datagram