Unverified Commit a90e91eb authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Remove hack for Unix datagram sockets

parent 9f7c099f
Loading
Loading
Loading
Loading
Loading

kodo/wpa_supplicant/_anyio.py

deleted100644 → 0
+0 −42
Original line number Diff line number Diff line
#  Copyright 2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Work-arounds for lack of AF_UNIX datagram socket support in Anyio
"""

from __future__ import annotations

import errno
import tempfile
from contextlib import suppress
from os import PathLike

from anyio import create_connected_unix_datagram_socket
from anyio.abc import ConnectedUNIXDatagramSocket as DatagramSocket


async def connect_unix_datagram(path: str | PathLike[str]) -> DatagramSocket:
	"""
	Return an AnyIO socket connected to a Unix datagram socket

	This behaviour is currently missing from AnyIO.
	"""
	for _ in range(10):
		fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
		with suppress(FileExistsError):
			return await create_connected_unix_datagram_socket(path, local_path=fname)
	raise FileExistsError(
		errno.EEXIST, "No usable temporary filename found",
	)
+8 −6
Original line number Diff line number Diff line
@@ -22,18 +22,19 @@ import enum
import logging
import os
import sys
import tempfile
from collections.abc import AsyncIterator
from collections.abc import Callable
from contextlib import asynccontextmanager
from pathlib import Path
from re import compile as regex
from types import TracebackType as Traceback
from typing import overload

import anyio
from anyio.abc import ConnectedUNIXDatagramSocket

from .. import errors
from .._anyio import DatagramSocket
from .._anyio import connect_unix_datagram
from . import consts

type EventInfo = tuple["EventPriority", str, str | None]
@@ -87,10 +88,10 @@ class BaseClient:

	event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")

	def __init__(self, *, logger: logging.Logger | None = None) -> None:
	def __init__(self, *, logger: logging.Logger | None = None, ctrl_dir: Path | None = None) -> None:
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock: DatagramSocket | None = None
		self.ctrl_dir = ctrl_dir or Path(tempfile.mkdtemp())
		self.sock: ConnectedUNIXDatagramSocket | None = None
		self._lock = anyio.Lock()
		self._condition = anyio.Condition()
		self._handler_active = False
@@ -117,7 +118,8 @@ class BaseClient:
			raise RuntimeError("cannot connect to multiple daemons")

		with anyio.fail_after(1.0):
			self.sock = await connect_unix_datagram(os.fspath(path))
			sock_path = self.ctrl_dir / f"kodo.{os.getpid()}.sock"
			self.sock = await anyio.create_connected_unix_datagram_socket(path, local_path=sock_path)
			await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)

	async def disconnect(self) -> None:
+17 −7
Original line number Diff line number Diff line
@@ -16,28 +16,38 @@
Anyio helpers for unit tests
"""

from __future__ import annotations

from collections.abc import Awaitable
from collections.abc import Callable
from unittest import mock

import anyio

type Patch[T] = mock._patch[T]  # pyright: ignore[reportPrivateUsage]


def _delay_side_effect(delay: float) -> Awaitable[None]:
	async def coro(*a: object, **k: object) -> None:
def _delay_side_effect[R](delay: float, response: R) -> Callable[..., Awaitable[R]]:
	async def coro(*a: object, **k: object) -> R:
		await anyio.sleep(delay)
		return response

	return coro


def patch_connect(delay: float = 0.0) -> mock._patch:
def patch_connect(delay: float = 0.0) -> Patch[mock.AsyncMock]:
	return mock.patch(
		"kodo.wpa_supplicant.client.base.connect_unix_datagram",
		side_effect=_delay_side_effect(delay),
		"anyio.create_connected_unix_datagram_socket",
		mock.AsyncMock(
			side_effect=_delay_side_effect(delay, mock.AsyncMock()),
		),
	)


def patch_send(delay: float = 0.0) -> mock._patch:
def patch_send(delay: float = 0.0) -> Patch[mock.AsyncMock]:
	return mock.patch(
		"kodo.wpa_supplicant.client.base.BaseClient.send_command",
		side_effect=_delay_side_effect(delay),
		mock.AsyncMock(
			side_effect=_delay_side_effect(delay, None),
		),
	)