Verified Commit 5e3f8659 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add regression tests for #21

parent d1deed91
Loading
Loading
Loading
Loading
+52 −1
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from contextlib import asynccontextmanager
from functools import wraps
from types import TracebackType
from typing import AsyncContextManager
from typing import Literal

import anyio
from anyio.streams.buffered import BufferedByteReceiveStream
@@ -25,6 +26,8 @@ P = typing.ParamSpec("P")
SendT = typing.TypeVar("SendT")
YieldT = typing.TypeVar("YieldT")

DEFAULT_PEER = "test.example.com"


def _make_aclosing(
	func: Callable[P, AsyncGenerator[YieldT, SendT]],
@@ -87,16 +90,55 @@ class MockMessageStream:
			await anyio.wait_all_tasks_blocked()
			yield stream_mock

	@classmethod
	@asynccontextmanager
	async def connected(
		cls,
		runner: Runner,
		host: str = DEFAULT_PEER,
		/,
		helo: str|None|Literal[False] = DEFAULT_PEER,
	) -> AsyncIterator[Self]:
		"""
		Return a context manager that yields a prepared and connected stream mock

		Negotiate, Connect, and optionally Helo messages will have been sent over the stream
		once the context has been entered.
		"""
		async with cls.started(runner) as self:
			await self.send_and_expect(make_negotiate(), Negotiate)
			await self.send_and_expect(Connect(host), Continue)
			if helo:
				await self.send_msg(Helo(helo))
			yield self
			if helo:
				await self._abort()
			await self.close()

	@asynccontextmanager
	async def envelope(self, sender: bytes, *recipients: bytes) -> AsyncIterator[None]:
		"""
		Return a context manager that encapsulates a message envelope
		"""
		await self.send_and_expect(EnvelopeFrom(sender), Continue)
		for recipient in recipients:
			await self.send_and_expect(EnvelopeRecipient(recipient), Continue)
		yield
		await self._abort()

	async def abort(self) -> None:
		"""
		Send Abort and close the stream
		"""
		await self._abort()
		await self.close()

	async def _abort(self) -> None:
		try:
			resp = await self.send_msg(Abort())
		except anyio.BrokenResourceError:
			return
		assert len(resp) == 0, resp
		await self.close()

	async def close(self) -> None:
		"""
@@ -158,3 +200,12 @@ class MockMessageStream:
				assert isinstance(r, e), f"expected {e}, got {type(r)}"
			else:
				assert r == e, r


def make_negotiate(options: int = 0, actions: int = 0x1ff) -> Negotiate:
	"""
	Construct a Negotiate message from integer flags

	Defaults to all actions, and no protocol options.
	"""
	return Negotiate(6, ActionFlags(actions), ProtocolFlags(options))
+210 −9
Original line number Diff line number Diff line
@@ -11,15 +11,7 @@ from kilter.service.session import Aborted

from . import AsyncTestCase
from .mock_stream import MockMessageStream


def make_negotiate(options: int = 0, actions: int = 0x1ff) -> Negotiate:
	"""
	Construct a Negotiate message from integer flags

	Defaults to all actions, and no protocol options.
	"""
	return Negotiate(6, ActionFlags(actions), ProtocolFlags(options))
from .mock_stream import make_negotiate


class RunnerTests(AsyncTestCase):
@@ -418,3 +410,212 @@ class RunnerNegotiateTests(AsyncTestCase):
				await stream_mock.send_msg(
					make_negotiate(options=0),
				)


class SessionReuseTests(AsyncTestCase):
	"""
	Tests for sessions handling multiple messages

	Most of these added as regression tests for #21
	"""

	QUOTE1 = """
	Strange women lying in ponds, distributing swords, is no basis for a system of
	government!
	""".encode("utf-8")

	QUOTE2 = """
	Supreme executive power derives from a mandate from the masses,
	not from some farcical aquatic ceremony.
	""".encode("utf-8")

	QUOTE3 = """
	He’s not the Messiah; he’s a very naughty boy!
	""".encode("utf-8")

	async def test_mail(self) -> None:
		"""
		Check each message gets its own MAIL command and a copy of the HELO/EHLO message
		"""
		counter = 0

		@Runner
		async def test_filter(session: Session) -> Accept:
			nonlocal counter
			counter += 1
			assert await session.helo() == "test.example.com"
			assert await session.envelope_from() == "test@example.com"
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			await stream.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept)
			await stream.send_msg(Abort())
			await stream.send_and_expect(EnvelopeFrom(b"test@example.com"), Accept)

		assert counter == 2

	async def test_headers(self) -> None:
		"""
		Check that the header accessor is reset for each message
		"""
		results = list[bytes]()

		@Runner
		async def test_filter(session: Session) -> Accept:
			async with session.headers as headers:
				async for header in headers.restrict("X-Test"):
					results.append(header.value)
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			async with stream.envelope(b"test1@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"spam"), Continue)
				await stream.send_and_expect(Header("X-Test", b"ham"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Accept)
			async with stream.envelope(b"test2@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"eggs"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Accept)

		assert results == [b"spam", b"ham", b"eggs"], results

	async def test_body(self) -> None:
		"""
		Check that the body accessor is reset for each message
		"""
		results = list[bytes]()

		@Runner
		async def test_filter(session: Session) -> Accept:
			async with session.body as body:
				results.extend([cnk.tobytes() async for cnk in body])
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Body(self.QUOTE1), Continue)
				await stream.send_and_expect(Body(self.QUOTE2), Continue)
				await stream.send_and_expect(EndOfMessage(b""), Accept)
			assert results == [self.QUOTE1, self.QUOTE2, b""]

			del results[:]

			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Body(self.QUOTE3), Continue)
				await stream.send_and_expect(EndOfMessage(b""), Accept)
			assert results == [self.QUOTE3, b""]

	async def test_combined(self) -> None:
		"""
		Check that headers and body accessors are reset for each message
		"""
		header_list = list[bytes]()
		body_list = list[bytes]()

		@Runner
		async def test_filter(session: Session) -> Accept:
			async with session.headers as headers:
				async for header in headers.restrict("X-Test"):
					header_list.append(header.value)
			async with session.body as body:
				body_list.extend([cnk.tobytes() async for cnk in body])
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"spam"), Continue)
				await stream.send_and_expect(Header("X-Test", b"ham"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Continue)
				await stream.send_and_expect(Body(self.QUOTE1), Continue)
				await stream.send_and_expect(Body(self.QUOTE2), Continue)
				await stream.send_and_expect(EndOfMessage(b""), Accept)
			assert header_list == [b"spam", b"ham"]
			assert body_list == [self.QUOTE1, self.QUOTE2, b""]

			del header_list[:]
			del body_list[:]

			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"eggs"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Continue)
				await stream.send_and_expect(Body(self.QUOTE3), Continue)
				await stream.send_and_expect(EndOfMessage(b""), Accept)
			assert header_list == [b"eggs"]
			assert body_list == [self.QUOTE3, b""]

	async def test_abort_session(self) -> None:
		"""
		Check that aborting a session clears all session state
		"""
		results = list[bytes]()

		@Runner
		async def test_filter(session: Session) -> Accept:
			async with session.headers as headers:
				async for header in headers.restrict("X-Test"):
					results.append(header.value)
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			async with stream.envelope(b"test1@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"spam"), Continue)
				await stream.send_and_expect(Header("X-Test", b"ham"), Continue)
			await stream.send_msg(Abort())
			async with stream.envelope(b"test2@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"eggs"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Accept)

		assert results == [b"spam", b"ham", b"eggs"], results

	async def test_reject(self) -> None:
		"""
		Check that rejecting a message does not close a session
		"""
		header_list = list[bytes]()

		@Runner
		async def test_filter(session: Session) -> Accept|Reject:
			async with session.headers as headers:
				async for header in headers.restrict("X-Test"):
					header_list.append(header.value)
					if header.value == b"ham":
						return Reject()
			return Accept()

		async with MockMessageStream.connected(test_filter) as stream:
			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"spam"), Continue)
				await stream.send_and_expect(Header("X-Test", b"ham"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Reject)
			assert header_list == [b"spam", b"ham"]

			del header_list[:]

			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"eggs"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Accept)
			assert header_list == [b"eggs"]

	async def test_multi_reject(self) -> None:
		"""
		Check that a rejection from one filter of several rejects the current message
		"""
		async def test_filter1(session: Session) -> Accept|Reject:
			async with session.headers as headers:
				async for header in headers.restrict("X-Test"):
					if header.value == b"ham":
						return Reject()
			return Accept()

		async def test_filter2(session: Session) -> Accept|Reject:
			await session.headers.collect()
			return Accept()

		async with MockMessageStream.connected(Runner(test_filter1, test_filter2)) as stream:
			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"spam"), Continue)
				await stream.send_and_expect(Header("X-Test", b"ham"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Reject)

			async with stream.envelope(b"test@example.com", b"test@example.com"):
				await stream.send_and_expect(Header("X-Test", b"eggs"), Continue)
				await stream.send_and_expect(EndOfHeaders(), Accept)