Loading wpa_supplicant/client/base.py +52 −34 Original line number Diff line number Diff line # Copyright 2019 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. Loading @@ -16,19 +16,28 @@ This module provides a base WPA-Supplicant client implementation """ from __future__ import annotations import contextlib import enum import logging import os import pathlib from re import compile as regex from types import TracebackType as Traceback from typing import Any from typing import AsyncContextManager from typing import Callable from typing import Dict from typing import Generator from typing import Optional from typing import Sequence from typing import Set from typing import Tuple from typing import Type import anyio from anyio.abc import SocketStream from .. import errors from .. import util Loading @@ -44,21 +53,20 @@ class EventPriority(enum.IntEnum): Event Message priorities """ def get_logger_level(self, *, _mapping={}): def get_logger_level(self, *, _mapping: Dict[EventPriority, int] = {}) -> int: """ Return a logging level matching the `wpa_supplicant` priority level """ if not _mapping: # fmt: off cls = type(self) _mapping.update({ self.MSGDUMP: logging.DEBUG, self.DEBUG: logging.DEBUG, self.INFO: logging.INFO, self.NOTICE: logging.INFO, self.WARNING: logging.WARNING, self.ERROR: logging.ERROR, cls.MSGDUMP: logging.DEBUG, cls.DEBUG: logging.DEBUG, cls.INFO: logging.INFO, cls.NOTICE: logging.INFO, cls.WARNING: logging.WARNING, cls.ERROR: logging.ERROR, }) # fmt: on return _mapping[self] MSGDUMP = 0 Loading @@ -79,22 +87,27 @@ class BaseClient: event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?") def __init__(self, *, logger=None): def __init__(self, *, logger: Optional[logging.Logger] = None): self.logger = logger or logging.getLogger(__package__) self.ctrl_dir = None self.sock = None self.sock: Optional[SocketStream] = None self._lock = anyio.create_lock() self._reply = ReplyManager() self._eventqueues = dict() self._eventqueues: Dict[str, Set[anyio.Queue]] = dict() self._eventcount = 0 async def __aenter__(self): async def __aenter__(self) -> BaseClient: return self async def __aexit__(self, *exc_info): async def __aexit__( self, _et: Optional[Type[BaseException]], _e: Optional[BaseException], _tb: Optional[Traceback], ) -> None: await self.disconnect() async def connect(self, path: PathLike): async def connect(self, path: PathLike) -> None: """ Connect to a WPA-Supplicant daemon through the given address """ Loading @@ -104,11 +117,11 @@ class BaseClient: if not isinstance(path, pathlib.Path): path = pathlib.Path(os.fspath(path)) async with anyio.fail_after(1): async with anyio.fail_after(1.0): self.sock = await util.connect_unix_datagram(path.as_posix()) await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG) async def disconnect(self): async def disconnect(self) -> None: """ Disconnect from the connected daemon, if connected """ Loading @@ -121,7 +134,7 @@ class BaseClient: *args: str, separator: str = consts.SEPARATOR_TAB, expect: str = consts.RESPONSE_OK, convert: Optional[Callable] = None, convert: Optional[Callable[[Any], Any]] = None, ) -> Any: """ Send a message and await a response Loading @@ -143,6 +156,9 @@ class BaseClient: a command that does not take arguments, or vice versa. Raises ValueError """ if self.sock is None: raise RuntimeError("Client is not connected") if args: message = f"{message} {separator.join(args)}" msgbytes = message.encode() Loading Loading @@ -171,7 +187,7 @@ class BaseClient: ) return None def attach(self): def attach(self) -> AsyncContextManager[None]: """ Return a context manager that handles attaching to the daemon's message queue """ Loading @@ -185,15 +201,16 @@ class BaseClient: with self._events_queue(events) as queue: while queue.empty(): await self._process(queue) return await queue.get() return await queue.get() # type: ignore async def _process(self, queue: anyio.Queue): async def _process(self, queue: anyio.Queue) -> None: async with self._lock: # Shortcut if the queue of interest has a message from another call # to _process() (probably in another coroutine) if not queue.empty(): return assert self.sock is not None msg = (await self.sock.recv(MAX_DGRAM_READ)).decode().strip() self.logger.debug("Received: %s", repr(msg)) Loading @@ -206,10 +223,10 @@ class BaseClient: self.logger.warning("Unexpected response message: %s", msg) return prio, name, msg = match.groups() prio = EventPriority(int(prio)) prio_, name, msg = match.groups() prio = EventPriority(int(prio_)) if name is None: if not name: self.logger.log(prio.get_logger_level(), msg) return Loading @@ -222,7 +239,7 @@ class BaseClient: await msgqueue.put((prio, name, msg)) @contextlib.contextmanager def _events_queue(self, events: Sequence[str]): def _events_queue(self, events: Sequence[str]) -> Generator[anyio.Queue, None, None]: evtqueues = self._eventqueues queue = anyio.create_queue(1) for evt in events: Loading @@ -238,17 +255,17 @@ class BaseClient: evtqueues[evt].remove(queue) class _AttachContext: def __init__(self, client): def __init__(self, client: BaseClient): self.client = client async def __aenter__(self): async def __aenter__(self) -> None: client = self.client assert client._eventcount >= 0 if client._eventcount == 0: await client.send_command(consts.COMMAND_ATTACH) client._eventcount += 1 async def __aexit__(self, *exc_info): async def __aexit__(self, *exc_info: Any) -> None: client = self.client assert client._eventcount > 0 client._eventcount -= 1 Loading @@ -264,19 +281,20 @@ class ReplyManager: A context manager supplying a locked reply queue """ def __init__(self): def __init__(self) -> None: self.lock = anyio.create_lock() self.queue = None self.queue: Optional[anyio.Queue] = None def __getattr__(self, name): def __getattr__(self, name: str) -> Any: return getattr(self.queue, name) async def __aenter__(self): async def __aenter__(self) -> anyio.Queue: await self.lock.__aenter__() self.queue = queue = anyio.create_queue(1) return queue async def __aexit__(self, *exc_info): async def __aexit__(self, *exc_info: Any) -> None: assert self.queue is not None self.queue, queue = None, self.queue await self.lock.__aexit__(*exc_info) assert queue.empty(), "Reply queue was not processed" wpa_supplicant/util.py +2 −2 Original line number Diff line number Diff line # Copyright 2019 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. Loading Loading @@ -57,4 +57,4 @@ def kv2dict(keyvalues: str) -> StringMap: """ Convert a list of line-terminated "key=value" substrings into a dictionary """ return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) # type: ignore Loading
wpa_supplicant/client/base.py +52 −34 Original line number Diff line number Diff line # Copyright 2019 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. Loading @@ -16,19 +16,28 @@ This module provides a base WPA-Supplicant client implementation """ from __future__ import annotations import contextlib import enum import logging import os import pathlib from re import compile as regex from types import TracebackType as Traceback from typing import Any from typing import AsyncContextManager from typing import Callable from typing import Dict from typing import Generator from typing import Optional from typing import Sequence from typing import Set from typing import Tuple from typing import Type import anyio from anyio.abc import SocketStream from .. import errors from .. import util Loading @@ -44,21 +53,20 @@ class EventPriority(enum.IntEnum): Event Message priorities """ def get_logger_level(self, *, _mapping={}): def get_logger_level(self, *, _mapping: Dict[EventPriority, int] = {}) -> int: """ Return a logging level matching the `wpa_supplicant` priority level """ if not _mapping: # fmt: off cls = type(self) _mapping.update({ self.MSGDUMP: logging.DEBUG, self.DEBUG: logging.DEBUG, self.INFO: logging.INFO, self.NOTICE: logging.INFO, self.WARNING: logging.WARNING, self.ERROR: logging.ERROR, cls.MSGDUMP: logging.DEBUG, cls.DEBUG: logging.DEBUG, cls.INFO: logging.INFO, cls.NOTICE: logging.INFO, cls.WARNING: logging.WARNING, cls.ERROR: logging.ERROR, }) # fmt: on return _mapping[self] MSGDUMP = 0 Loading @@ -79,22 +87,27 @@ class BaseClient: event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?") def __init__(self, *, logger=None): def __init__(self, *, logger: Optional[logging.Logger] = None): self.logger = logger or logging.getLogger(__package__) self.ctrl_dir = None self.sock = None self.sock: Optional[SocketStream] = None self._lock = anyio.create_lock() self._reply = ReplyManager() self._eventqueues = dict() self._eventqueues: Dict[str, Set[anyio.Queue]] = dict() self._eventcount = 0 async def __aenter__(self): async def __aenter__(self) -> BaseClient: return self async def __aexit__(self, *exc_info): async def __aexit__( self, _et: Optional[Type[BaseException]], _e: Optional[BaseException], _tb: Optional[Traceback], ) -> None: await self.disconnect() async def connect(self, path: PathLike): async def connect(self, path: PathLike) -> None: """ Connect to a WPA-Supplicant daemon through the given address """ Loading @@ -104,11 +117,11 @@ class BaseClient: if not isinstance(path, pathlib.Path): path = pathlib.Path(os.fspath(path)) async with anyio.fail_after(1): async with anyio.fail_after(1.0): self.sock = await util.connect_unix_datagram(path.as_posix()) await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG) async def disconnect(self): async def disconnect(self) -> None: """ Disconnect from the connected daemon, if connected """ Loading @@ -121,7 +134,7 @@ class BaseClient: *args: str, separator: str = consts.SEPARATOR_TAB, expect: str = consts.RESPONSE_OK, convert: Optional[Callable] = None, convert: Optional[Callable[[Any], Any]] = None, ) -> Any: """ Send a message and await a response Loading @@ -143,6 +156,9 @@ class BaseClient: a command that does not take arguments, or vice versa. Raises ValueError """ if self.sock is None: raise RuntimeError("Client is not connected") if args: message = f"{message} {separator.join(args)}" msgbytes = message.encode() Loading Loading @@ -171,7 +187,7 @@ class BaseClient: ) return None def attach(self): def attach(self) -> AsyncContextManager[None]: """ Return a context manager that handles attaching to the daemon's message queue """ Loading @@ -185,15 +201,16 @@ class BaseClient: with self._events_queue(events) as queue: while queue.empty(): await self._process(queue) return await queue.get() return await queue.get() # type: ignore async def _process(self, queue: anyio.Queue): async def _process(self, queue: anyio.Queue) -> None: async with self._lock: # Shortcut if the queue of interest has a message from another call # to _process() (probably in another coroutine) if not queue.empty(): return assert self.sock is not None msg = (await self.sock.recv(MAX_DGRAM_READ)).decode().strip() self.logger.debug("Received: %s", repr(msg)) Loading @@ -206,10 +223,10 @@ class BaseClient: self.logger.warning("Unexpected response message: %s", msg) return prio, name, msg = match.groups() prio = EventPriority(int(prio)) prio_, name, msg = match.groups() prio = EventPriority(int(prio_)) if name is None: if not name: self.logger.log(prio.get_logger_level(), msg) return Loading @@ -222,7 +239,7 @@ class BaseClient: await msgqueue.put((prio, name, msg)) @contextlib.contextmanager def _events_queue(self, events: Sequence[str]): def _events_queue(self, events: Sequence[str]) -> Generator[anyio.Queue, None, None]: evtqueues = self._eventqueues queue = anyio.create_queue(1) for evt in events: Loading @@ -238,17 +255,17 @@ class BaseClient: evtqueues[evt].remove(queue) class _AttachContext: def __init__(self, client): def __init__(self, client: BaseClient): self.client = client async def __aenter__(self): async def __aenter__(self) -> None: client = self.client assert client._eventcount >= 0 if client._eventcount == 0: await client.send_command(consts.COMMAND_ATTACH) client._eventcount += 1 async def __aexit__(self, *exc_info): async def __aexit__(self, *exc_info: Any) -> None: client = self.client assert client._eventcount > 0 client._eventcount -= 1 Loading @@ -264,19 +281,20 @@ class ReplyManager: A context manager supplying a locked reply queue """ def __init__(self): def __init__(self) -> None: self.lock = anyio.create_lock() self.queue = None self.queue: Optional[anyio.Queue] = None def __getattr__(self, name): def __getattr__(self, name: str) -> Any: return getattr(self.queue, name) async def __aenter__(self): async def __aenter__(self) -> anyio.Queue: await self.lock.__aenter__() self.queue = queue = anyio.create_queue(1) return queue async def __aexit__(self, *exc_info): async def __aexit__(self, *exc_info: Any) -> None: assert self.queue is not None self.queue, queue = None, self.queue await self.lock.__aexit__(*exc_info) assert queue.empty(), "Reply queue was not processed"
wpa_supplicant/util.py +2 −2 Original line number Diff line number Diff line # Copyright 2019 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. Loading Loading @@ -57,4 +57,4 @@ def kv2dict(keyvalues: str) -> StringMap: """ Convert a list of line-terminated "key=value" substrings into a dictionary """ return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) # type: ignore