Commit 30f6a85d authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Wrap test filters to run in a Session context

parent a5ee4472
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from kilter.service import *

from . import AsyncTestCase
from .mock_editor import MockEditor
from .util_session import with_session

LOCALHOST = IPv4Address("127.0.0.1")

@@ -23,6 +24,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = b""

		@with_session(session)
		async def test_filter() -> None:
			nonlocal result
			async with session.body as body:
@@ -46,6 +48,7 @@ class HeaderAccessorTests(AsyncTestCase):
		result1 = b""
		result2 = b""

		@with_session(session)
		async def test_filter() -> None:
			nonlocal result1
			nonlocal result2
@@ -86,6 +89,7 @@ class HeaderAccessorTests(AsyncTestCase):
			return s.content == o.content
		ReplaceBody.__eq__ = _eq  # type: ignore

		@with_session(session)
		async def test_filter() -> None:
			await session.body.write(b"A new message")

+18 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from kilter.service.session import Phase

from . import AsyncTestCase
from .mock_editor import MockEditor
from .util_session import with_session

LOCALHOST = IPv4Address("127.0.0.1")

@@ -25,6 +26,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				async for header in headers:
@@ -51,6 +53,7 @@ class HeaderAccessorTests(AsyncTestCase):
		result1 = []
		result2 = []

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				async for header in headers:
@@ -88,6 +91,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()

@@ -114,6 +118,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()

@@ -140,6 +145,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				async for header in headers.restrict("Spam", "Ham"):
@@ -164,6 +170,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()
			await session.headers.delete(Header("Spam", b"spam?"))
@@ -196,6 +203,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.update(Header("Spam", b"spam?"), b"no spam!")
			async with session.headers as headers:
@@ -228,6 +236,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(Header("Ham", b"and eggs"), START)
			async with session.headers as headers:
@@ -257,6 +266,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(Header("Ham", b"and eggs"), END)
			async with session.headers as headers:
@@ -286,6 +296,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(
				Header("Ham", b"and eggs"),
@@ -318,6 +329,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(
				Header("Ham", b"and eggs"),
@@ -350,6 +362,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(
				Header("Ham", b"and eggs"),
@@ -382,6 +395,7 @@ class HeaderAccessorTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)
		result = []

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.insert(
				Header("Ham", b"and eggs"),
@@ -420,6 +434,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				assert Header("From", b"test@example.com") == await headers.asend()
@@ -438,6 +453,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				await headers.asend()
@@ -459,6 +475,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				await headers.asend()
@@ -481,6 +498,7 @@ class HeaderAccessorTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			async with session.headers as headers:
				await headers.asend()
+18 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from kilter.service.session import Session

from . import AsyncTestCase
from .mock_editor import MockEditor
from .util_session import with_session

LOCALHOST = IPv4Address("127.0.0.1")

@@ -63,6 +64,7 @@ class SessionTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = "spam"

		@with_session(session)
		async def test_filter() -> None:
			nonlocal result
			result = await session.envelope_from()
@@ -83,6 +85,7 @@ class SessionTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = "spam"

		@with_session(session)
		async def test_filter() -> None:
			nonlocal result
			result = await session.helo()
@@ -100,6 +103,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			await session.envelope_from()
			with self.assertRaises(RuntimeError) as acm:
@@ -118,6 +122,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			with self.assertRaises(RuntimeError) as acm:
				await session.helo()
@@ -135,6 +140,7 @@ class SessionTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = "spam"

		@with_session(session)
		async def test_filter() -> None:
			nonlocal result
			result = await session.envelope_from()
@@ -152,6 +158,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()
			with self.assertRaises(RuntimeError) as acm:
@@ -170,6 +177,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			with self.assertRaises(RuntimeError) as acm:
				await session.envelope_from()
@@ -187,6 +195,7 @@ class SessionTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			async for rcpt in session.envelope_recipients():
				result.append(rcpt)
@@ -207,6 +216,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()
			with self.assertRaises(RuntimeError) as acm:
@@ -227,6 +237,7 @@ class SessionTests(AsyncTestCase):
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())
		result = []

		@with_session(session)
		async def test_filter() -> None:
			result.append(await session.extension("SPAM"))
			result.append(await session.extension("MAIL"))
@@ -256,6 +267,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			await session.headers.collect()
			with self.assertRaises(RuntimeError) as acm:
@@ -274,6 +286,7 @@ class SessionTests(AsyncTestCase):
		"""
		session = Session(Connect("example.com", LOCALHOST, 1025), MockEditor())

		@with_session(session)
		async def test_filter() -> None:
			with self.assertRaises(RuntimeError) as acm:
				await session.extension("TEST")
@@ -291,6 +304,7 @@ class SessionTests(AsyncTestCase):
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			await session.change_sender("test@example.com")
@@ -314,6 +328,7 @@ class SessionTests(AsyncTestCase):
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			await session.add_recipient("test@example.com")
@@ -337,6 +352,7 @@ class SessionTests(AsyncTestCase):
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			assert session.phase == Phase.CONNECT
			await session.remove_recipient("test@example.com")
@@ -358,6 +374,7 @@ class SessionTests(AsyncTestCase):
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			self.assertDictEqual(session.macros, {})
			await session.helo()
@@ -381,6 +398,7 @@ class SessionTests(AsyncTestCase):
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		@with_session(session)
		async def test_filter() -> None:
			await trio.sleep(0.1)
			assert await session.helo() == "test.example.com"

tests/util_session.py

0 → 100644
+30 −0
Original line number Diff line number Diff line
from __future__ import annotations

from collections.abc import Coroutine as _Coroutine
from functools import wraps
from typing import Any
from typing import Callable
from typing import TypeAlias
from typing import TypeVar

from kilter.service import Session

T = TypeVar("T")
Decorator: TypeAlias = Callable[[T], T]
Coroutine: TypeAlias = _Coroutine[Any, Any, T]


def with_session(session: Session) -> Decorator[Callable[[], Coroutine[None]]]:
	"""
	Run an async filter function within a Session context

	This is for testing only; the Session instance is assumed to be available to the filter
	function as a closure variable.
	"""
	def deco(func: Callable[[], Coroutine[None]]) -> Callable[[], Coroutine[None]]:
		@wraps(func)
		async def wrapper() -> None:
			async with session:
				await func()
		return wrapper
	return deco