Commit ace81cef authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Fix typing for .client.base ready for Anyio updates

parent 389027f2
Loading
Loading
Loading
Loading
+52 −34
Original line number Diff line number Diff line
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2019-2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -16,19 +16,28 @@
This module provides a base WPA-Supplicant client implementation
"""

from __future__ import annotations

import contextlib
import enum
import logging
import os
import pathlib
from re import compile as regex
from types import TracebackType as Traceback
from typing import Any
from typing import AsyncContextManager
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type

import anyio
from anyio.abc import SocketStream

from .. import errors
from .. import util
@@ -44,21 +53,20 @@ class EventPriority(enum.IntEnum):
	Event Message priorities
	"""

	def get_logger_level(self, *, _mapping={}):
	def get_logger_level(self, *, _mapping: Dict[EventPriority, int] = {}) -> int:
		"""
		Return a logging level matching the `wpa_supplicant` priority level
		"""
		if not _mapping:
			# fmt: off
			cls = type(self)
			_mapping.update({
				self.MSGDUMP: logging.DEBUG,
				self.DEBUG: logging.DEBUG,
				self.INFO: logging.INFO,
				self.NOTICE: logging.INFO,
				self.WARNING: logging.WARNING,
				self.ERROR: logging.ERROR,
				cls.MSGDUMP: logging.DEBUG,
				cls.DEBUG: logging.DEBUG,
				cls.INFO: logging.INFO,
				cls.NOTICE: logging.INFO,
				cls.WARNING: logging.WARNING,
				cls.ERROR: logging.ERROR,
			})
			# fmt: on
		return _mapping[self]

	MSGDUMP = 0
@@ -79,22 +87,27 @@ class BaseClient:

	event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")

	def __init__(self, *, logger=None):
	def __init__(self, *, logger: Optional[logging.Logger] = None):
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock = None
		self.sock: Optional[SocketStream] = None
		self._lock = anyio.create_lock()
		self._reply = ReplyManager()
		self._eventqueues = dict()
		self._eventqueues: Dict[str, Set[anyio.Queue]] = dict()
		self._eventcount = 0

	async def __aenter__(self):
	async def __aenter__(self) -> BaseClient:
		return self

	async def __aexit__(self, *exc_info):
	async def __aexit__(
		self,
		_et: Optional[Type[BaseException]],
		_e: Optional[BaseException],
		_tb: Optional[Traceback],
	) -> None:
		await self.disconnect()

	async def connect(self, path: PathLike):
	async def connect(self, path: PathLike) -> None:
		"""
		Connect to a WPA-Supplicant daemon through the given address
		"""
@@ -104,11 +117,11 @@ class BaseClient:
		if not isinstance(path, pathlib.Path):
			path = pathlib.Path(os.fspath(path))

		async with anyio.fail_after(1):
		async with anyio.fail_after(1.0):
			self.sock = await util.connect_unix_datagram(path.as_posix())
			await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)

	async def disconnect(self):
	async def disconnect(self) -> None:
		"""
		Disconnect from the connected daemon, if connected
		"""
@@ -121,7 +134,7 @@ class BaseClient:
		*args: str,
		separator: str = consts.SEPARATOR_TAB,
		expect: str = consts.RESPONSE_OK,
		convert: Optional[Callable] = None,
		convert: Optional[Callable[[Any], Any]] = None,
	) -> Any:
		"""
		Send a message and await a response
@@ -143,6 +156,9 @@ class BaseClient:
		  a command that does not take arguments, or vice versa.
		  Raises ValueError
		"""
		if self.sock is None:
			raise RuntimeError("Client is not connected")

		if args:
			message = f"{message} {separator.join(args)}"
		msgbytes = message.encode()
@@ -171,7 +187,7 @@ class BaseClient:
			)
		return None

	def attach(self):
	def attach(self) -> AsyncContextManager[None]:
		"""
		Return a context manager that handles attaching to the daemon's message queue
		"""
@@ -185,15 +201,16 @@ class BaseClient:
			with self._events_queue(events) as queue:
				while queue.empty():
					await self._process(queue)
				return await queue.get()
				return await queue.get()  # type: ignore

	async def _process(self, queue: anyio.Queue):
	async def _process(self, queue: anyio.Queue) -> None:
		async with self._lock:
			# Shortcut if the queue of interest has a message from another call
			# to _process() (probably in another coroutine)
			if not queue.empty():
				return

			assert self.sock is not None
			msg = (await self.sock.recv(MAX_DGRAM_READ)).decode().strip()

		self.logger.debug("Received: %s", repr(msg))
@@ -206,10 +223,10 @@ class BaseClient:
				self.logger.warning("Unexpected response message: %s", msg)
			return

		prio, name, msg = match.groups()
		prio = EventPriority(int(prio))
		prio_, name, msg = match.groups()
		prio = EventPriority(int(prio_))

		if name is None:
		if not name:
			self.logger.log(prio.get_logger_level(), msg)
			return

@@ -222,7 +239,7 @@ class BaseClient:
				await msgqueue.put((prio, name, msg))

	@contextlib.contextmanager
	def _events_queue(self, events: Sequence[str]):
	def _events_queue(self, events: Sequence[str]) -> Generator[anyio.Queue, None, None]:
		evtqueues = self._eventqueues
		queue = anyio.create_queue(1)
		for evt in events:
@@ -238,17 +255,17 @@ class BaseClient:
				evtqueues[evt].remove(queue)

	class _AttachContext:
		def __init__(self, client):
		def __init__(self, client: BaseClient):
			self.client = client

		async def __aenter__(self):
		async def __aenter__(self) -> None:
			client = self.client
			assert client._eventcount >= 0
			if client._eventcount == 0:
				await client.send_command(consts.COMMAND_ATTACH)
			client._eventcount += 1

		async def __aexit__(self, *exc_info):
		async def __aexit__(self, *exc_info: Any) -> None:
			client = self.client
			assert client._eventcount > 0
			client._eventcount -= 1
@@ -264,19 +281,20 @@ class ReplyManager:
	A context manager supplying a locked reply queue
	"""

	def __init__(self):
	def __init__(self) -> None:
		self.lock = anyio.create_lock()
		self.queue = None
		self.queue: Optional[anyio.Queue] = None

	def __getattr__(self, name):
	def __getattr__(self, name: str) -> Any:
		return getattr(self.queue, name)

	async def __aenter__(self):
	async def __aenter__(self) -> anyio.Queue:
		await self.lock.__aenter__()
		self.queue = queue = anyio.create_queue(1)
		return queue

	async def __aexit__(self, *exc_info):
	async def __aexit__(self, *exc_info: Any) -> None:
		assert self.queue is not None
		self.queue, queue = None, self.queue
		await self.lock.__aexit__(*exc_info)
		assert queue.empty(), "Reply queue was not processed"
+2 −2
Original line number Diff line number Diff line
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2019-2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -57,4 +57,4 @@ def kv2dict(keyvalues: str) -> StringMap:
	"""
	Convert a list of line-terminated "key=value" substrings into a dictionary
	"""
	return dict(kv.split("=", 1) for kv in keyvalues.splitlines())
	return dict(kv.split("=", 1) for kv in keyvalues.splitlines())  # type: ignore