Verified Commit 8121b89a authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Modernise tests

parent 0f1145c8
Loading
Loading
Loading
Loading
+14 −60
Original line number Diff line number Diff line
@@ -16,73 +16,27 @@
Anyio helpers for unit tests
"""

import sys
from functools import wraps
from typing import Awaitable
from typing import Callable
from typing import Literal
from typing import Tuple
from typing import Union
from unittest import TestCase
from unittest import mock
from warnings import warn

import anyio

try:
	import trio as _  # noqa
	USE_TRIO = True
except ImportError:
	USE_TRIO = False

Backend = Union[Literal['asyncio'], Literal['trio']]
def _delay_side_effect(delay: float) -> Awaitable[None]:
	async def coro(*a: object, **k: object) -> None:
		await anyio.sleep(delay)
	return coro

py_version = sys.version_info[:2]

AsyncTestFunc = Callable[..., Awaitable[None]]
TestFunc = Callable[..., None]


def with_anyio(*backends: Backend, timeout: int = 10) -> Callable[[AsyncTestFunc], TestFunc]:
	"""
	Create a wrapping decorator to run asynchronous test functions
	"""
	if not backends:
		backends = ('asyncio',)

	def decorator(testfunc: AsyncTestFunc) -> TestFunc:
		async def test_async_wrapper(tc: TestCase, args: Tuple[mock.Mock]) -> None:
			with anyio.fail_after(timeout):
				await testfunc(tc, *args)

		@wraps(testfunc)
		def test_wrapper(tc: TestCase, *args: mock.Mock) -> None:
			for backend in backends:
				if backend == 'trio' and not USE_TRIO:
					warn(
						f"not running {testfunc.__name__} with trio; package is missing",
def patch_connect(delay: float = 0.0) -> mock._patch:
	return mock.patch(
		"wpa_supplicant.client.base.connect_unix_datagram",
		side_effect=_delay_side_effect(delay),
	)
					continue
				with tc.subTest(f"backend: {backend}"):
					anyio.run(test_async_wrapper, tc, args, backend=backend)

		return test_wrapper

	return decorator


class AsyncMock(mock.Mock):
	"""
	A Mock class that acts as a coroutine when called
	"""

	def __init__(self, *args: object, delay: float = 0.0, **kwargs: object):
		mock._safe_super(AsyncMock, self).__init__(*args, **kwargs)  # type: ignore
		self.delay = delay

	async def __call__(_mock_self, *args: object, **kwargs: object) -> object:
		_mock_self._mock_check_sig(*args, **kwargs)
		if py_version >= (3, 8):
			_mock_self._increment_mock_call(*args, **kwargs)
		await anyio.sleep(_mock_self.delay)
		return _mock_self._mock_call(*args, **kwargs)
def patch_send(delay: float = 0.0) -> mock._patch:
	return mock.patch(
		"wpa_supplicant.client.base.BaseClient.send_command",
		side_effect=_delay_side_effect(delay),
	)
+5 −8
Original line number Diff line number Diff line
#  Copyright 2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -18,22 +18,20 @@ Test connecting and communicating with a server

import os
import sys
from unittest import TestCase
import unittest

from tests._anyio import with_anyio
from tests.integration.util import start_server
from wpa_supplicant.client import GlobalClient


class Tests(TestCase):
class Tests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests against live wpa_suppplicant servers

	The 'wpa_supplicant' executable is required in a PATH directory for these tests to work.
	"""

	@with_anyio('asyncio', 'trio')
	async def test_connect(self):
	async def test_connect(self) -> None:
		"""
		Test connecting to the global wpa_supplicant control socket
		"""
@@ -42,8 +40,7 @@ class Tests(TestCase):
			ifaces = await client.list_interfaces()
			assert len(ifaces) == 0

	@with_anyio('asyncio', 'trio')
	async def test_new_interface(self):
	async def test_new_interface(self) -> None:
		"""
		Test adding a wireless interface and scanning for stations

+47 −75
Original line number Diff line number Diff line
#  Copyright 2019-2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2019-2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -17,74 +17,62 @@ Test cases for wpa_supplicant.client.base.BaseClient
"""

import unittest
from unittest import mock
from unittest.mock import AsyncMock

import anyio

from tests import _anyio as anyio_mock
from tests._anyio import patch_connect
from tests._anyio import patch_send
from wpa_supplicant import errors
from wpa_supplicant.client import base


@mock.patch(
	"wpa_supplicant.client.base.connect_unix_datagram",
	new_callable=anyio_mock.AsyncMock,
)
@mock.patch(
	"wpa_supplicant.client.base.BaseClient.send_command",
	new_callable=anyio_mock.AsyncMock,
)
class ConnectTests(unittest.TestCase):
class ConnectTests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests for the connect() method
	"""

	@anyio_mock.with_anyio()
	async def test_connect(self, _, connect_mock):
	async def test_connect(self) -> None:
		"""
		Check connect() calls socket.connect()
		"""
		with patch_connect() as connect_mock, patch_send():
			async with base.BaseClient() as client:
				await client.connect("foo")

		connect_mock.assert_called_once_with("foo")
		connect_mock.assert_awaited_once_with("foo")

	@anyio_mock.with_anyio()
	async def test_connect_timeout_1(self, _, connect_mock):
	async def test_connect_timeout_1(self) -> None:
		"""
		Check a socket.connect() delay causes TimeoutError to be raised
		"""
		connect_mock.delay = 2

		with patch_connect(2.0), patch_send():
			async with base.BaseClient() as client:
				with self.assertRaises(TimeoutError):
					await client.connect("foo")

	@anyio_mock.with_anyio()
	async def test_connect_timeout_2(self, send_mock, _):
	async def test_connect_timeout_2(self) -> None:
		"""
		Check a send/recv delay causes a TimeoutError to be raised
		"""
		send_mock.delay = 2

		with patch_connect(), patch_send(2.0):
			async with base.BaseClient() as client:
				with self.assertRaises(TimeoutError):
					await client.connect("foo")


class SendMessageTests(unittest.TestCase):
class SendMessageTests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests for the send_command() method
	"""

	def setUp(self):
	def setUp(self) -> None:
		self.client = client = base.BaseClient()
		client.sock = anyio_mock.AsyncMock(spec=anyio.abc.SocketStream)
		client.sock = AsyncMock(spec=anyio.abc.SocketStream)
		client.sock.send.return_value = None
		assert isinstance(client.sock, anyio.abc.SocketStream)

	@anyio_mock.with_anyio()
	async def test_simple(self):
	async def test_simple(self) -> None:
		"""
		Check that a response is processed after a command
		"""
@@ -92,8 +80,7 @@ class SendMessageTests(unittest.TestCase):
			client.sock.receive.return_value = b"OK"
			assert await client.send_command("SOME_COMMAND") is None

	@anyio_mock.with_anyio()
	async def test_simple_expect(self):
	async def test_simple_expect(self) -> None:
		"""
		Check that an alternate expected response is processed
		"""
@@ -101,8 +88,7 @@ class SendMessageTests(unittest.TestCase):
			client.sock.receive.return_value = b"PONG"
			assert await client.send_command("PING", expect="PONG") is None

	@anyio_mock.with_anyio()
	async def test_simple_no_expect(self):
	async def test_simple_no_expect(self) -> None:
		"""
		Check that an unexpected response raises an UnexpectedResponseError
		"""
@@ -113,8 +99,7 @@ class SendMessageTests(unittest.TestCase):
			with self.assertRaises(errors.UnexpectedResponseError):
				await client.send_command("PING", expect="PONG")

	@anyio_mock.with_anyio()
	async def test_simple_convert(self):
	async def test_simple_convert(self) -> None:
		"""
		Check that a response is passed through a converter if given
		"""
@@ -127,8 +112,7 @@ class SendMessageTests(unittest.TestCase):
				["FOO", "BAR", "BAZ"],
			)

	@anyio_mock.with_anyio()
	async def test_simple_convert_over_expect(self):
	async def test_simple_convert_over_expect(self) -> None:
		"""
		Check that 'convert' overrides 'expect'
		"""
@@ -141,8 +125,7 @@ class SendMessageTests(unittest.TestCase):
				["FOO", "BAR", "BAZ"],
			)

	@anyio_mock.with_anyio()
	async def test_simple_fail(self):
	async def test_simple_fail(self) -> None:
		"""
		Check that a response of 'FAIL' causes CommandFailed to be raised
		"""
@@ -151,8 +134,7 @@ class SendMessageTests(unittest.TestCase):
			with self.assertRaises(errors.CommandFailed):
				await client.send_command("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_simple_bad_command(self):
	async def test_simple_bad_command(self) -> None:
		"""
		Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
		"""
@@ -161,8 +143,7 @@ class SendMessageTests(unittest.TestCase):
			with self.assertRaises(ValueError):
				await client.send_command("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_interleaved(self):
	async def test_interleaved(self) -> None:
		"""
		Check that messages are processed alongside replies
		"""
@@ -175,8 +156,7 @@ class SendMessageTests(unittest.TestCase):
			]
			assert await client.send_command("SOME_COMMAND") is None

	@anyio_mock.with_anyio()
	async def test_unexpected(self):
	async def test_unexpected(self) -> None:
		"""
		Check that unexpected replies are logged cleanly
		"""
@@ -190,8 +170,7 @@ class SendMessageTests(unittest.TestCase):
			]
			assert await client.event("CTRL-EVENT-EXAMPLE")

	@anyio_mock.with_anyio()
	async def test_unconnected(self):
	async def test_unconnected(self) -> None:
		"""
		Check that calling send_command() on an unconnected client raises RuntimeError
		"""
@@ -200,8 +179,7 @@ class SendMessageTests(unittest.TestCase):
		with self.assertRaises(RuntimeError):
			await client.send_command("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_multi_task(self):
	async def test_multi_task(self) -> None:
		"""
		Check that calling send_command() from multiple tasks works as expected
		"""
@@ -213,7 +191,7 @@ class SendMessageTests(unittest.TestCase):
			(0.0, b"OK"),           # Response to DETACH
		])

		async def recv():
		async def recv() -> bytes:
			delay, data = next(recv_responses)
			await anyio.sleep(delay)
			return data
@@ -222,7 +200,7 @@ class SendMessageTests(unittest.TestCase):
			client.sock.receive.side_effect = recv

			@task_group.start_soon
			async def wait_for_event():
			async def wait_for_event() -> None:
				self.assertTupleEqual(
					await client.event("CTRL-FOO"),
					(base.EventPriority.INFO, "CTRL-FOO", None),
@@ -235,8 +213,7 @@ class SendMessageTests(unittest.TestCase):
			# At this point the response to SOME_COMMAND1 is still delayed
			await client.send_command("SOME_COMMAND2", expect="REPLY2")

	@anyio_mock.with_anyio()
	async def test_multi_task_decode_error(self):
	async def test_multi_task_decode_error(self) -> None:
		"""
		Check that decode errors closes the socket and causes all tasks to raise EOFError
		"""
@@ -251,27 +228,26 @@ class SendMessageTests(unittest.TestCase):
			client.sock.receive.side_effect = recv_responses

			@task_group.start_soon
			async def wait_for_event():
			async def wait_for_event() -> None:
				with self.assertRaises(anyio.ClosedResourceError):
					await client.event("CTRL-FOO"),
					await client.event("CTRL-FOO")
			await anyio.sleep(0.1)  # Ensure send_command("ATTACH") has been sent

			with self.assertRaises(anyio.ClosedResourceError):
				await client.send_command("SOME_COMMAND", expect="REPLY")


class EventTests(unittest.TestCase):
class EventTests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests for the event() method
	"""

	def setUp(self):
	def setUp(self) -> None:
		self.client = client = base.BaseClient()
		client.sock = anyio_mock.AsyncMock()
		client.sock = AsyncMock()
		client.sock.send.return_value = None

	@anyio_mock.with_anyio()
	async def test_simple(self):
	async def test_simple(self) -> None:
		"""
		Check that an awaited message is returned when is arrives
		"""
@@ -287,8 +263,7 @@ class EventTests(unittest.TestCase):
				assert evt == "CTRL-EVENT-EXAMPLE"
				assert args is None

	@anyio_mock.with_anyio()
	async def test_multiple(self):
	async def test_multiple(self) -> None:
		"""
		Check that an awaited messages is returned when it arrives between others
		"""
@@ -307,8 +282,7 @@ class EventTests(unittest.TestCase):
				assert evt == "CTRL-EVENT-EXAMPLE"
				assert args is None

	@anyio_mock.with_anyio()
	async def test_wait_multiple(self):
	async def test_wait_multiple(self) -> None:
		"""
		Check that the first of several awaited events is returned
		"""
@@ -330,8 +304,7 @@ class EventTests(unittest.TestCase):
				assert evt == "CTRL-EVENT-EXAMPLE3"
				assert args is None

	@anyio_mock.with_anyio()
	async def test_interleaved(self):
	async def test_interleaved(self) -> None:
		"""
		Check that messages are processed as well as replies
		"""
@@ -357,8 +330,7 @@ class EventTests(unittest.TestCase):

				assert await client.send_command("SOME_COMMAND", expect="FOO") is None

	@anyio_mock.with_anyio()
	async def test_unconnected(self):
	async def test_unconnected(self) -> None:
		"""
		Check that calling event() on an unconnected client raises RuntimeError
		"""
+14 −20
Original line number Diff line number Diff line
#  Copyright 2019-2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2019-2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -18,25 +18,24 @@ Test cases for wpa_supplicant.client.GlobalClient

import pathlib
import unittest
from unittest.mock import AsyncMock
from unittest.mock import patch

from tests import _anyio as anyio_mock
from wpa_supplicant.client import GlobalClient
from wpa_supplicant.client import InterfaceClient


class InterfaceMethodsTests(unittest.TestCase):
class InterfaceMethodsTests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests for the *_interface(s?) methods
	"""

	def setUp(self):
	def setUp(self) -> None:
		self.client = client = GlobalClient()
		client.sock = anyio_mock.AsyncMock()
		client.sock = AsyncMock()
		client.sock.send.return_value = None

	@anyio_mock.with_anyio()
	async def test_connect(self):
	async def test_connect(self) -> None:
		"""
		Check that connect sets ctrl_dir
		"""
@@ -45,7 +44,7 @@ class InterfaceMethodsTests(unittest.TestCase):

		with patch(
			"wpa_supplicant.client.base.BaseClient.connect",
			new_callable=anyio_mock.AsyncMock,
			new_callable=AsyncMock,
		):
			await client1.connect("/tmp/foo/bar")
			await client2.connect(pathlib.Path("/tmp/foo/bar"))
@@ -56,8 +55,7 @@ class InterfaceMethodsTests(unittest.TestCase):
		assert client1.ctrl_dir == pathlib.Path("/tmp/foo")
		assert client2.ctrl_dir == pathlib.Path("/tmp/foo")

	@anyio_mock.with_anyio()
	async def test_list_interfaces(self):
	async def test_list_interfaces(self) -> None:
		"""
		Check list_interfaces() processes lines of names in a list
		"""
@@ -76,8 +74,7 @@ class InterfaceMethodsTests(unittest.TestCase):

			client.sock.send.assert_called_once_with(b"INTERFACES")

	@anyio_mock.with_anyio()
	async def test_add_interface(self):
	async def test_add_interface(self) -> None:
		"""
		Check add_interface() sends the correct arguments
		"""
@@ -93,10 +90,9 @@ class InterfaceMethodsTests(unittest.TestCase):

	@patch(
		"wpa_supplicant.client.interfaces.InterfaceClient.connect",
		new_callable=anyio_mock.AsyncMock,
		new_callable=AsyncMock,
	)
	@anyio_mock.with_anyio()
	async def test_connect_interface(self, connect_mock):
	async def test_connect_interface(self, connect_mock: AsyncMock) -> None:
		"""
		Check connect_interface() returns a connected InterfaceClient
		"""
@@ -116,10 +112,9 @@ class InterfaceMethodsTests(unittest.TestCase):

	@patch(
		"wpa_supplicant.client.interfaces.InterfaceClient.connect",
		new_callable=anyio_mock.AsyncMock,
		new_callable=AsyncMock,
	)
	@anyio_mock.with_anyio()
	async def test_connect_interface_with_add(self, connect_mock):
	async def test_connect_interface_with_add(self, connect_mock: AsyncMock) -> None:
		"""
		Check connect_interface() adds the interface when not already managed
		"""
@@ -140,8 +135,7 @@ class InterfaceMethodsTests(unittest.TestCase):
			self.assertTupleEqual(args[0][0], (b"INTERFACES",))
			assert args[1][0][0].startswith(b"INTERFACE_ADD enp1s0\t")

	@anyio_mock.with_anyio()
	async def test_unconnected(self):
	async def test_unconnected(self) -> None:
		"""
		Check that calling add_interface() on an unconnected client raises RuntimeError

+10 −12
Original line number Diff line number Diff line
#  Copyright 2019-2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2019-2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -17,26 +17,27 @@ Test cases for wpa_supplicant.client.interfaces.InterfaceClient
"""

import unittest
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import AsyncMock
from unittest.mock import call

from tests import _anyio as anyio_mock
from wpa_supplicant import config
from wpa_supplicant.client import interfaces


class MethodsTests(unittest.TestCase):
class MethodsTests(unittest.IsolatedAsyncioTestCase):
	"""
	Tests for InterfaceClient methods
	"""

	def setUp(self):
	def setUp(self) -> None:
		self.client = client = interfaces.InterfaceClient()
		client.sock = anyio_mock.AsyncMock()
		client.sock = AsyncMock()
		client.sock.send.return_value = None

	@contextmanager
	def subTest(self, *args, reset=[], **kwargs):
	def subTest(self, *args: object, reset: list[AsyncMock] = [], **kwargs: object) -> Iterator[None]:
		with super().subTest(*args, **kwargs):
			try:
				yield
@@ -44,8 +45,7 @@ class MethodsTests(unittest.TestCase):
				for mock in reset:
					mock.reset_mock()

	@anyio_mock.with_anyio()
	async def test_scan(self):
	async def test_scan(self) -> None:
		"""
		Check that a scan command waits for a notification then terminates correctly
		"""
@@ -64,8 +64,7 @@ class MethodsTests(unittest.TestCase):
				self.assertIsInstance(bss, dict)
				self.assertIn("good", bss)

	@anyio_mock.with_anyio()
	async def test_set_network(self):
	async def test_set_network(self) -> None:
		"""
		Check that set_network sends values to the daemon and raises TypeError for bad types
		"""
@@ -105,8 +104,7 @@ class MethodsTests(unittest.TestCase):
					self.assertRaises(TypeError):
				await client.set_network("0", "key_mgmt", 1)

	@anyio_mock.with_anyio()
	async def test_add_network(self):
	async def test_add_network(self) -> None:
		"""
		Check that add_network adds a new network and configures it
		"""