Unverified Commit d9b3e34c authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Fix typing in test modules

parent f97b820a
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ class GlobalClient(BaseClient):

	ctrl_dir = None

	async def connect(self, path: PathLike[str]) -> None:
	async def connect(self, path: PathLike[str] | str) -> None:
		if not isinstance(path, pathlib.Path):
			path = pathlib.Path(path)
		await super().connect(path)
+3 −2
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ from contextlib import asynccontextmanager
from pathlib import Path
from re import compile as regex
from types import TracebackType as Traceback
from typing import Self
from typing import overload

import anyio
@@ -99,7 +100,7 @@ class BaseClient:
		self._event: EventInfo | None
		self._eventcount = 0

	async def __aenter__(self) -> BaseClient:
	async def __aenter__(self) -> Self:
		return self

	async def __aexit__(
@@ -110,7 +111,7 @@ class BaseClient:
	) -> None:
		await self.disconnect()

	async def connect(self, path: os.PathLike[str]) -> None:
	async def connect(self, path: os.PathLike[str] | str) -> None:
		"""
		Connect to a WPA-Supplicant daemon through the given address
		"""
+1 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ class InterfaceClient(BaseClient):

	name = None

	async def connect(self, path: PathLike[str]) -> None:
	async def connect(self, path: PathLike[str] | str) -> None:
		"""
		Connect to an interface UNIX port
		"""
+42 −0
Original line number Diff line number Diff line
#  Copyright 2026  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Iterable
from unittest.mock import AsyncMock

import anyio.abc

from kodo.wpa_supplicant.client import base

type MockResponse = bytes | BaseException | type[BaseException]
type MockResponseCallable = Callable[[], Awaitable[MockResponse]]


class MockClientMaker[ClientType: base.BaseClient]:
	def __init__(self, client_type: type[ClientType]) -> None:
		self.client_type = client_type

	def __call__(
		self, mock_resp: MockResponse | Iterable[MockResponse] | MockResponseCallable
	) -> ClientType:
		client = self.client_type()
		client.sock = AsyncMock(spec=anyio.abc.SocketStream)
		client.sock.send.return_value = None
		if isinstance(mock_resp, bytes):
			client.sock.receive.return_value = mock_resp
		else:
			client.sock.receive.side_effect = mock_resp
		return client
+71 −89
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@ Test cases for kodo.wpa_supplicant.client.base.BaseClient
"""

import unittest
from unittest.mock import AsyncMock

import anyio

@@ -26,6 +25,10 @@ from kodo.wpa_supplicant.client import base
from tests._anyio import patch_connect
from tests._anyio import patch_send

from . import MockClientMaker

mock_client = MockClientMaker(base.BaseClient)


class ConnectTests(unittest.IsolatedAsyncioTestCase):
	"""
@@ -34,14 +37,12 @@ class ConnectTests(unittest.IsolatedAsyncioTestCase):

	async def test_connect(self) -> None:
		"""
		Check connect() calls socket.connect()
		Check connect() returns when there is no delay
		"""
		with patch_connect() as connect_mock, patch_send():
		with patch_connect(), patch_send():
			async with base.BaseClient() as client:
				await client.connect("foo")

		connect_mock.assert_awaited_once_with("foo")

	async def test_connect_timeout_1(self) -> None:
		"""
		Check a socket.connect() delay causes TimeoutError to be raised
@@ -66,34 +67,25 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
	Tests for the send_command() method
	"""

	def setUp(self) -> None:
		self.client = client = base.BaseClient()
		client.sock = AsyncMock(spec=anyio.abc.SocketStream)
		client.sock.send.return_value = None
		assert isinstance(client.sock, anyio.abc.SocketStream)

	async def test_simple(self) -> None:
		"""
		Check that a response is processed after a command
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"OK"
		async with mock_client(b"OK") as client:
			assert await client.send_command("SOME_COMMAND") is None

	async def test_simple_expect(self) -> None:
		"""
		Check that an alternate expected response is processed
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"PONG"
		async with mock_client(b"PONG") as client:
			assert await client.send_command("PING", expect="PONG") is None

	async def test_simple_no_expect(self) -> None:
		"""
		Check that an unexpected response raises an UnexpectedResponseError
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"DING"
		async with mock_client(b"DING") as client:
			with self.assertRaises(errors.UnexpectedResponseError):
				await client.send_command("PING")
			with self.assertRaises(errors.UnexpectedResponseError):
@@ -103,8 +95,7 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that a response is passed through a converter if given
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"FOO\nBAR\nBAZ\n"
		async with mock_client(b"FOO\nBAR\nBAZ\n") as client:
			self.assertListEqual(
				await client.send_command(
					"SOME_COMMAND", convert=lambda x: x.splitlines(),
@@ -116,8 +107,7 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that 'convert' overrides 'expect'
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"FOO\nBAR\nBAZ\n"
		async with mock_client(b"FOO\nBAR\nBAZ\n") as client:
			self.assertListEqual(
				await client.send_command(
					"SOME_COMMAND", convert=lambda x: x.splitlines(), expect="PONG",
@@ -129,8 +119,7 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that a response of 'FAIL' causes CommandFailed to be raised
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"FAIL"
		async with mock_client(b"FAIL") as client:
			with self.assertRaises(errors.CommandFailed):
				await client.send_command("SOME_COMMAND")

@@ -138,8 +127,7 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
		"""
		async with self.client as client:
			client.sock.receive.return_value = b"UNKNOWN COMMAND"
		async with mock_client(b"UNKNOWN COMMAND") as client:
			with self.assertRaises(ValueError):
				await client.send_command("SOME_COMMAND")

@@ -147,28 +135,28 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that messages are processed alongside replies
		"""
		async with self.client as client:
			client.sock.receive.side_effect = [
		responses = [
			b"<2>SOME-MESSAGE",
			b"<1>SOME-OTHER-MESSAGE with|args",
			b"OK",
			b"<2>SOME-MESSAGE",
		]
		async with mock_client(responses) as client:
			assert await client.send_command("SOME_COMMAND") is None

	async def test_unexpected(self) -> None:
		"""
		Check that unexpected replies are logged cleanly
		"""
		async with self.client as client:
			client.sock.receive.side_effect = [
		responses = [
			b"OK",  # Response to "ATTACH"
			b"UNEXPECTED1",
			b"UNEXPECTED2",
			b"<2>CTRL-EVENT-EXAMPLE",
			b"OK",  # Response to "DETACH"
		]
			assert await client.event("CTRL-EVENT-EXAMPLE")
		async with mock_client(responses) as client:
			await client.event("CTRL-EVENT-EXAMPLE")

	async def test_unconnected(self) -> None:
		"""
@@ -196,15 +184,15 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
			await anyio.sleep(delay)
			return data

		async with self.client as client, anyio.create_task_group() as task_group:
			client.sock.receive.side_effect = recv
		async with mock_client(recv) as client, anyio.create_task_group() as task_group:

			@task_group.start_soon
			async def wait_for_event() -> None:
				self.assertTupleEqual(
					await client.event("CTRL-FOO"),
					(base.EventPriority.INFO, "CTRL-FOO", None),
				)
			task_group.start_soon(wait_for_event)

			await anyio.sleep(0.1)  # Ensure send_command("ATTACH") has been sent

			task_group.start_soon(client.send_command, "SOME_COMMAND1")
@@ -224,13 +212,13 @@ class SendMessageTests(unittest.IsolatedAsyncioTestCase):
			anyio.EndOfStream,
		]

		async with self.client as client, anyio.create_task_group() as task_group:
			client.sock.receive.side_effect = recv_responses
		async with mock_client(recv_responses) as client, anyio.create_task_group() as task_group:

			@task_group.start_soon
			async def wait_for_event() -> None:
				with self.assertRaises(anyio.ClosedResourceError):
					await client.event("CTRL-FOO")
			task_group.start_soon(wait_for_event)

			await anyio.sleep(0.1)  # Ensure send_command("ATTACH") has been sent

			with self.assertRaises(anyio.ClosedResourceError):
@@ -242,22 +230,17 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
	Tests for the event() method
	"""

	def setUp(self) -> None:
		self.client = client = base.BaseClient()
		client.sock = AsyncMock()
		client.sock.send.return_value = None

	async def test_simple(self) -> None:
		"""
		Check that an awaited message is returned when is arrives
		"""
		with anyio.fail_after(2):
			async with self.client as client:
				client.sock.receive.side_effect = [
		responses = [
			b"OK",  # Respond to ATTACH
			b"<2>CTRL-EVENT-EXAMPLE",
			b"OK",  # Respond to DETACH
		]
		with anyio.fail_after(2):
			async with mock_client(responses) as client:
				prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
				assert prio == 2
				assert evt == "CTRL-EVENT-EXAMPLE"
@@ -267,9 +250,7 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that an awaited messages is returned when it arrives between others
		"""
		with anyio.fail_after(2):
			async with self.client as client:
				client.sock.receive.side_effect = [
		responses = [
			b"OK",  # Respond to ATTACH
			b"<1>OTHER-MESSAGE",
			b"<2>CTRL-EVENT-OTHER",
@@ -277,6 +258,8 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
			b"OK",  # Respond to DETACH
			b"<3>OTHER-MESSAGE",
		]
		with anyio.fail_after(2):
			async with mock_client(responses) as client:
				prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
				assert prio == 4
				assert evt == "CTRL-EVENT-EXAMPLE"
@@ -286,9 +269,7 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that the first of several awaited events is returned
		"""
		with anyio.fail_after(2):
			async with self.client as client:
				client.sock.receive.side_effect = [
		responses = [
			b"OK",  # Respond to ATTACH
			b"<1>OTHER-MESSAGE",
			b"<2>CTRL-EVENT-OTHER",
@@ -297,6 +278,8 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
			b"OK",  # Respond to DETACH
			b"<3>CTRL-EVENT-OTHER",
		]
		with anyio.fail_after(2):
			async with mock_client(responses) as client:
				prio, evt, args = await client.event(
					"CTRL-EVENT-EXAMPLE1", "CTRL-EVENT-EXAMPLE2", "CTRL-EVENT-EXAMPLE3",
				)
@@ -308,9 +291,7 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
		"""
		Check that messages are processed as well as replies
		"""
		with anyio.fail_after(2):
			async with self.client as client:
				client.sock.receive.side_effect = [
		responses = [
			b"<1>OTHER-MESSAGE",
			b"OK",  # Respond to SOME_COMMAND
			b"OK",  # Respond to ATTACH
@@ -320,7 +301,8 @@ class EventTests(unittest.IsolatedAsyncioTestCase):
			b"OK",  # Respond to DETACH
			b"FOO",
		]

		with anyio.fail_after(2):
			async with mock_client(responses) as client:
				assert await client.send_command("SOME_COMMAND") is None

				prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
Loading