Commit 2d7b5f2d authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add working but lite client code

parent 471ad354
Loading
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -32,7 +32,9 @@ setup_requires =
install_requires =
  anyio >=1.0
tests_require =
  anyio[curio]
  coverage <5
  mocket[speedups]
  nose2[coverage_plugin]
test_suite = nose2.collector.collector
lint_rcfile = setup.cfg

tests/unit/__init__.py

0 → 100644
+0 −0

Empty file added.

tests/unit/anyio.py

0 → 100644
+53 −0
Original line number Diff line number Diff line
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Anyio helpers for unit tests
"""

from functools import wraps
from unittest import mock

import anyio


def with_anyio(timeout=10):
	"""
	Create a wrapping decorator to run asynchronous test functions
	"""
	def decorator(testfunc):
		async def test_async_wrapper(args):
			async with anyio.fail_after(timeout):
				return await testfunc(*args)
		@wraps(testfunc)
		def test_wrapper(*args):
			return anyio.run(test_async_wrapper, args)
		return test_wrapper
	return decorator


class AsyncMock(mock.Mock):
	"""
	A Mock class that acts as a coroutine when called
	"""

	def __init__(self, *args, delay=0, **kwargs):
		mock._safe_super(AsyncMock, self).__init__(*args, **kwargs)
		self.delay = delay

	# pylint: disable=no-self-argument
	async def __call__(_mock_self, *args, **kwargs):
		_mock_self._mock_check_sig(*args, **kwargs)
		await anyio.sleep(_mock_self.delay)
		return _mock_self._mock_call(*args, **kwargs)
+279 −0
Original line number Diff line number Diff line
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Test cases for wpa_supplicant.client.base.BaseClient
"""

import unittest
from unittest import mock

import anyio

from tests.unit import anyio as anyio_mock

from wpa_supplicant import errors
from wpa_supplicant.client import base


@mock.patch('wpa_supplicant.util.connect_unix_datagram', new_callable=anyio_mock.AsyncMock)
@mock.patch('wpa_supplicant.client.base.BaseClient.send_message', new_callable=anyio_mock.AsyncMock)
class ConnectTests(unittest.TestCase):
	"""
	Tests for the connect() method
	"""

	@anyio_mock.with_anyio()
	async def test_connect(self, _, connect_mock):
		"""
		Check connect() calls socket.connect()
		"""
		async with base.BaseClient() as client:
			await client.connect("foo")

		connect_mock.assert_called_once_with("foo")

	@anyio_mock.with_anyio()
	async def test_connect_timeout_1(self, _, connect_mock):
		"""
		Check a socket.connect() delay causes TimeoutError to be raised
		"""
		connect_mock.delay = 2

		async with base.BaseClient() as client:
			with self.assertRaises(TimeoutError):
				await client.connect("foo")

	@anyio_mock.with_anyio()
	async def test_connect_timeout_2(self, send_mock, _):
		"""
		Check a send/recv delay causes a TimeoutError to be raised
		"""
		send_mock.delay = 2

		async with base.BaseClient() as client:
			with self.assertRaises(TimeoutError):
				await client.connect("foo")


class SendMessageTests(unittest.TestCase):
	"""
	Tests for the send_message() method
	"""

	def setUp(self):
		async def _make_client():
			return base.BaseClient()
		self.client = client = anyio.run(_make_client)
		client.sock = anyio_mock.AsyncMock()
		client.sock.send.side_effect = len

	@anyio_mock.with_anyio()
	async def test_simple(self):
		"""
		Check that a response is processed after a command
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"OK"
			assert await client.send_message("SOME_COMMAND") is None

	@anyio_mock.with_anyio()
	async def test_simple_expect(self):
		"""
		Check that an alternate expected response is processed
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"PONG"
			assert await client.send_message("PING", expect="PONG") is None

	@anyio_mock.with_anyio()
	async def test_simple_no_expect(self):
		"""
		Check that an unexpected response raises an UnexpectedResponseError
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"DING"
			with self.assertRaises(errors.UnexpectedResponseError):
				await client.send_message("PING")
			with self.assertRaises(errors.UnexpectedResponseError):
				await client.send_message("PING", expect="PONG")

	@anyio_mock.with_anyio()
	async def test_simple_convert(self):
		"""
		Check that a response is passed through a converter if given
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"FOO\nBAR\nBAZ\n"
			self.assertListEqual(
				await client.send_message(
					"SOME_COMMAND",
					convert=lambda x: x.splitlines(),
				),
				["FOO", "BAR", "BAZ"]
			)

	@anyio_mock.with_anyio()
	async def test_simple_convert_over_expect(self):
		"""
		Check that 'convert' overrides 'expect'
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"FOO\nBAR\nBAZ\n"
			self.assertListEqual(
				await client.send_message(
					"SOME_COMMAND",
					convert=lambda x: x.splitlines(),
					expect="PONG",
				),
				["FOO", "BAR", "BAZ"]
			)

	@anyio_mock.with_anyio()
	async def test_simple_fail(self):
		"""
		Check that a response of 'FAIL' causes CommandFailed to be raised
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"FAIL"
			with self.assertRaises(errors.CommandFailed):
				await client.send_message("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_simple_bad_command(self):
		"""
		Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
		"""
		async with self.client as client:
			client.sock.recv.return_value = b"UNKNOWN COMMAND"
			with self.assertRaises(ValueError):
				await client.send_message("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_simple_too_large(self):
		"""
		Check that MessageTooLargeError is raised if send() reports an incomplete send
		"""
		async with self.client as client:
			client.sock.send.side_effect = lambda b: len(b) / 2
			with self.assertRaises(errors.MessageTooLargeError):
				await client.send_message("SOME_COMMAND")

	@anyio_mock.with_anyio()
	async def test_interleaved(self):
		"""
		Check that messages are processed alongside replies
		"""
		async with self.client as client:
			client.sock.recv.side_effect = [
				b"<2>SOME-MESSAGE",
				b"<1>SOME-OTHER-MESSAGE with|args",
				b"OK",
				b"<2>SOME-MESSAGE"
			]
			assert await client.send_message("SOME_COMMAND") is None


class EventTests(unittest.TestCase):
	"""
	Tests for the event() method
	"""

	def setUp(self):
		async def _make_client():
			return base.BaseClient()
		self.client = client = anyio.run(_make_client)
		client.sock = anyio_mock.AsyncMock()
		client.sock.send.side_effect = len

	@anyio_mock.with_anyio()
	async def test_simple(self):
		"""
		Check that an awaited message is returned when is arrives
		"""
		async with self.client as client, anyio.fail_after(2):
			client.sock.recv.side_effect = [
				b"OK",  # Respond to ATTACH
				b"<2>EXAMPLE-EVENT",
				b"OK",  # Respond to DETACH
			]
			prio, evt, args = await client.event("EXAMPLE-EVENT")
			assert prio == 2
			assert evt == "EXAMPLE-EVENT"
			assert args is None

	@anyio_mock.with_anyio()
	async def test_multiple(self):
		"""
		Check that an awaited messages is returned when it arrives between others
		"""
		async with self.client as client, anyio.fail_after(2):
			client.sock.recv.side_effect = [
				b"OK",  # Respond to ATTACH
				b"<1>OTHER-MESSAGE",
				b"<2>OTHER-MESSAGE",
				b"<4>EXAMPLE-EVENT",
				b"OK",  # Respond to DETACH
				b"<3>OTHER-MESSAGE",
			]
			prio, evt, args = await client.event("EXAMPLE-EVENT")
			assert prio == 4
			assert evt == "EXAMPLE-EVENT"
			assert args is None

	@anyio_mock.with_anyio()
	async def test_wait_multiple(self):
		"""
		Check that the first of several awaited events is returned
		"""
		async with self.client as client, anyio.fail_after(2):
			client.sock.recv.side_effect = [
				b"OK",  # Respond to ATTACH
				b"<1>OTHER-MESSAGE",
				b"<2>OTHER-MESSAGE",
				b"<4>EXAMPLE-EVENT3",
				b"<4>EXAMPLE-EVENT1",
				b"OK",  # Respond to DETACH
				b"<3>OTHER-MESSAGE",
			]
			prio, evt, args = await client.event("EXAMPLE-EVENT1", "EXAMPLE-EVENT2", "EXAMPLE-EVENT3")
			assert prio == 4
			assert evt == "EXAMPLE-EVENT3"
			assert args is None

	@anyio_mock.with_anyio()
	async def test_interleaved(self):
		"""
		Check that messages are processed as well as replies
		"""
		async with self.client as client, anyio.fail_after(2):
			client.sock.recv.side_effect = [
				b"<1>OTHER-MESSAGE",
				b"OK",  # Respond to SOME_COMMAND
				b"OK",  # Respond to ATTACH
				b"<2>OTHER-MESSAGE",
				b"<4>EXAMPLE-EVENT",
				b"<3>OTHER-MESSAGE",
				b"OK",  # Respond to DETACH
				b"FOO",
			]

			assert await client.send_message("SOME_COMMAND") is None

			prio, evt, args = await client.event("EXAMPLE-EVENT")
			assert prio == 4
			assert evt == "EXAMPLE-EVENT"
			assert args is None

			assert await client.send_message("SOME_COMMAND", expect="FOO") is None
+138 −0
Original line number Diff line number Diff line
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Test cases for wpa_supplicant.client.master.MasterClient
"""

import pathlib
import unittest
from unittest.mock import patch

import anyio

from tests.unit import anyio as anyio_mock

from wpa_supplicant.client import interfaces, master


class InterfaceMethodsTests(unittest.TestCase):
	"""
	Tests for the *_interface(s?) methods
	"""

	def setUp(self):
		async def _make_client():
			return master.MasterClient()
		self.client = client = anyio.run(_make_client)
		client.sock = anyio_mock.AsyncMock()
		client.sock.send.side_effect = len

	@anyio_mock.with_anyio()
	async def test_connect(self):
		"""
		Check that connect sets ctrl_dir
		"""
		client1 = master.MasterClient()
		client2 = master.MasterClient()

		with patch("wpa_supplicant.client.base.BaseClient.connect",
				new_callable=anyio_mock.AsyncMock):
			await client1.connect("/tmp/foo/bar")
			await client2.connect(pathlib.Path("/tmp/foo/bar"))

		self.assertIsInstance(client1.ctrl_dir, pathlib.Path)
		self.assertIsInstance(client2.ctrl_dir, pathlib.Path)

		assert client1.ctrl_dir == pathlib.Path("/tmp/foo")
		assert client2.ctrl_dir == pathlib.Path("/tmp/foo")

	@anyio_mock.with_anyio()
	async def test_managed_interfaces(self):
		"""
		Check managed_interfaces() processes lines of names in a list
		"""
		async with self.client as client:
			client.sock.recv.return_value = (
				b"enp0s0\n"
				b"enp1s0\n"
				b"wlp2s0\n"
			)

			self.assertListEqual(
				await client.managed_interfaces(),
				["enp0s0", "enp1s0", "wlp2s0"]
			)

			client.sock.send.assert_called_once_with(b"INTERFACES")

	@anyio_mock.with_anyio()
	async def test_add_interface(self):
		"""
		Check add_interface() sends the correct arguments
		"""
		async with self.client as client:
			client.ctrl_dir = pathlib.Path('/tmp')
			client.sock.recv.return_value = b"OK"

			assert await client.add_interface('enp1s0', driver='wired') is None

			client.sock.send.assert_called_once()
			args = client.sock.send.call_args[0]
			assert args[0].startswith(b"INTERFACE_ADD enp1s0\t\twired\t/tmp\t")

	@patch('wpa_supplicant.client.interfaces.InterfaceClient.connect',
			new_callable=anyio_mock.AsyncMock)
	@anyio_mock.with_anyio()
	async def test_connect_interface(self, connect_mock):
		"""
		Check connect_interface() returns a connected InterfaceClient
		"""
		async with self.client as client:
			client.ctrl_dir = pathlib.Path('/tmp')
			client.sock.recv.side_effect = [
				b"enp1s0\n",  # Response to INTERFACES
			]

			ifclient = await client.connect_interface('enp1s0')

			self.assertIsInstance(ifclient, interfaces.InterfaceClient)
			connect_mock.assert_called_once_with("/tmp/enp1s0")

			# Check only INTERFACES was sent
			client.sock.send.assert_called_once_with(b"INTERFACES")

	@patch('wpa_supplicant.client.interfaces.InterfaceClient.connect',
			new_callable=anyio_mock.AsyncMock)
	@anyio_mock.with_anyio()
	async def test_connect_interface_with_add(self, connect_mock):
		"""
		Check connect_interface() adds the interface when not already managed
		"""
		async with self.client as client:
			client.ctrl_dir = pathlib.Path('/tmp')
			client.sock.recv.side_effect = [
				b"",    # Response to INTERFACES
				b"OK",  # Response to INTERFACE_ADD
			]

			ifclient = await client.connect_interface('enp1s0')

			self.assertIsInstance(ifclient, interfaces.InterfaceClient)
			connect_mock.assert_called_once_with("/tmp/enp1s0")

			# Check INTERFACE_ADD sent after INTERFACES
			args = client.sock.send.call_args_list
			self.assertTupleEqual(args[0][0], (b"INTERFACES",))
			assert args[1][0][0].startswith(b"INTERFACE_ADD enp1s0\t")
Loading