Verified Commit 062206b9 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Replace sender async generator with plain class

parent 719090b4
Loading
Loading
Loading
Loading
+30 −22
Original line number Diff line number Diff line
@@ -16,7 +16,6 @@ from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import AsyncGenerator
from typing import Final
from typing import TypeAlias
from warnings import warn
@@ -27,20 +26,21 @@ from async_generator import aclosing
from typing_extensions import Self

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import EditMessage
from kilter.protocol.core import EventMessage
from kilter.protocol.core import FilterMessage
from kilter.protocol.core import FilterProtocol
from kilter.protocol.core import ResponseMessage
from kilter.protocol.messages import ProtocolFlags
from kilter.protocol.messages import *

from .options import get_flags
from .options import get_macros
from .session import *
from .session import Aborted
from .session import Filter
from .session import Session
from .util import Broadcast
from .util import qualname

MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message]
Sender: TypeAlias = AsyncGenerator[None, ResponseMessage|EditMessage|Negotiate|Skip]

kiB: Final = 2**10
MiB: Final = 2**20
@@ -80,6 +80,26 @@ class _Broadcast(Broadcast[EventMessage]):
			self.task_status = None


class Sender:
	"""
	Concrete implementation of `kilter.service.session.Sender`
	"""

	def __init__(self, client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> None:
		self.client = client
		self.proto = proto

	async def send(self, message: FilterMessage) -> None:
		"""
		Encode and send a message to the client stream
		"""
		buffer = SimpleBuffer(1*kiB)
		self.proto.write_to(buffer, message)
		await self.client.send(buffer[:])
		if __debug__:
			_logger.debug(f"sent: {message}")


class Runner:
	"""
	A filter runner that coordinates passing data between a stream and multiple filters
@@ -100,15 +120,13 @@ class Runner:
		"""
		buff = SimpleBuffer(1*MiB)
		proto = FilterProtocol(abort_on_unknown=True)
		sender = _sender(client, proto)
		sender = Sender(client, proto)
		macro: Macro|None = None
		aborted = False

		await sender.asend(None)  # type: ignore # initialise

		async with (
			aclosing(client),
			anyio.create_task_group() as tasks,
			aclosing(sender), aclosing(client),
			_TaskRunner(tasks) as runner,
		):
			while 1:
@@ -126,7 +144,7 @@ class Runner:
						_logger.debug(f"received: {message}")
					match message:
						case Negotiate():
							await sender.asend(await self._negotiate(message))
							await sender.send(await self._negotiate(message))
						case Macro() as macro:
							# Note that this Macro will hang around as "macro"; this is for
							# Connect messages.
@@ -143,7 +161,7 @@ class Runner:
									self.filters.remove(notif.filter)
								case c_resp if needs_response:
									assert c_resp is not None and not isinstance(c_resp, _CloseFilter)
									await sender.asend(c_resp)
									await sender.send(c_resp)
								case c_resp:
									raise RuntimeError(f"unexpected response: {c_resp}")
						case Abort():
@@ -164,7 +182,7 @@ class Runner:
									self.filters.remove(notif.filter)
								case resp if needs_response:
									assert resp is not None and not isinstance(resp, _CloseFilter)
									await sender.asend(resp)
									await sender.send(resp)
								case resp:
									raise RuntimeError(f"unexpected response: {resp}")

@@ -390,13 +408,3 @@ def _make_message_channel() -> tuple[MessageChannel, MessageChannel]:
	lsend, rrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	rsend, lrecv = anyio.create_memory_object_stream(1, Message)  # type: ignore
	return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv)


async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender:
	buff = SimpleBuffer(1*kiB)
	while 1:
		proto.write_to(buff, (message := (yield)))
		if __debug__:
			_logger.debug(f"sent: {message}")
		await client.send(buff[:])
		del buff[:]
+23 −15
Original line number Diff line number Diff line
# Copyright 2022-2024 Dominik Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2022-2025 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -47,6 +47,14 @@ class Filter(Protocol):
	async def __call__(self, session: Session, /) -> ResponseMessage: ...  # noqa: D102


class Sender(Protocol):
	"""
	Senders asynchronously handle sending messages with their "send" method
	"""

	async def send(self, message: EditMessage) -> None: ...  # noqa: D102


class Phase(int, Enum):
	"""
	Session phases indicate what messages to expect and are impacted by received messages
@@ -191,14 +199,14 @@ class Session:
	def __init__(
		self,
		connmsg: Connect,
		sender: AsyncGenerator[None, EditMessage],
		sender: Sender,
		broadcast: util.Broadcast[EventMessage]|None = None,
	):
		self.host = connmsg.hostname
		self.address = connmsg.address
		self.port = connmsg.port

		self._editor = sender
		self.sender = sender
		self.broadcast = broadcast or util.Broadcast[EventMessage]()

		self.macros = dict[str, str]()
@@ -326,14 +334,14 @@ class Session:
		Move onto the `Phase.POST` phase and instruct the MTA to change the sender address
		"""
		await _until_editable(self)
		await self._editor.asend(ChangeSender(sender, args or None))
		await self.sender.send(ChangeSender(sender, args or None))

	async def add_recipient(self, recipient: str, args: str = "") -> None:
		"""
		Move onto the `Phase.POST` phase and instruct the MTA to add a new recipient address
		"""
		await _until_editable(self)
		await self._editor.asend(
		await self.sender.send(
			AddRecipientPar(recipient, args) if args else AddRecipient(recipient),
		)

@@ -342,7 +350,7 @@ class Session:
		Move onto the `Phase.POST` phase and instruct the MTA to remove a recipient address
		"""
		await _until_editable(self)
		await self._editor.asend(RemoveRecipient(recipient))
		await self.sender.send(RemoveRecipient(recipient))


class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
@@ -354,9 +362,9 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
	entered.
	"""

	def __init__(self, session: Session, sender: AsyncGenerator[None, EditMessage]):
	def __init__(self, session: Session, sender: Sender):
		self.session = session
		self._editor = sender
		self.sender = sender
		self._table = list[Header]()
		self._aiter: AsyncGenerator[Header, None]

@@ -417,7 +425,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		await self.collect()
		await _until_editable(self.session)
		index = _index_by_name(self._table, header)
		await self._editor.asend(ChangeHeader(index, header.name, b""))
		await self.sender.send(ChangeHeader(index, header.name, b""))
		self._table.remove(header)

	async def update(self, header: Header, value: bytes) -> None:
@@ -427,7 +435,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		await self.collect()
		await _until_editable(self.session)
		index = _index_by_name(self._table, header)
		await self._editor.asend(ChangeHeader(index, header.name, value))
		await self.sender.send(ChangeHeader(index, header.name, value))
		index = self._table.index(header)
		self._table[index].value = value

@@ -452,10 +460,10 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
			case _:
				raise TypeError("Expect a Position")
		if index >= len(self._table):
			await self._editor.asend(AddHeader(header.name, header.value))
			await self.sender.send(AddHeader(header.name, header.value))
			self._table.append(header)
		else:
			await self._editor.asend(InsertHeader(index + 1, header.name, header.value))
			await self.sender.send(InsertHeader(index + 1, header.name, header.value))
			self._table.insert(index, header)


@@ -508,9 +516,9 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
	entered.
	"""

	def __init__(self, session: Session, sender: AsyncGenerator[None, EditMessage]):
	def __init__(self, session: Session, sender: Sender):
		self.session = session
		self._editor = sender
		self.sender = sender
		self.skip = False
		self._aiter: AsyncGenerator[memoryview, None] | None = None

@@ -554,7 +562,7 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
				stacklevel=2,
			)
		await _until_editable(self.session)
		await self._editor.asend(ReplaceBody(chunk))
		await self.sender.send(ReplaceBody(chunk))


async def _until_editable(session: Session) -> None:
+4 −15
Original line number Diff line number Diff line
from collections.abc import AsyncGenerator
from types import TracebackType
from unittest.mock import AsyncMock

from kilter.protocol import Message


class MockEditor(AsyncGenerator[None, Message]):
class MockEditor:
	"""
	A mock of the interface used for sending update messages to the MTA
	"""

	def __init__(self) -> None:
		self._asend = AsyncMock()
		self._athrow = AsyncMock()
		self.mock_send = AsyncMock()

	async def asend(self, value: Message) -> None:  # noqa: D102
		await self._asend(value)

	async def athrow(  # noqa: D102
		self,
		e: type[BaseException]|BaseException,
		m: object = ...,
		t: TracebackType|None = None, /,
	) -> None:
		await self._athrow(self, e, m, t)
	async def send(self, value: Message) -> None:  # noqa: D102
		await self.mock_send(value)
+2 −2
Original line number Diff line number Diff line
@@ -93,7 +93,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await trio.testing.wait_all_tasks_blocked()
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ReplaceBody(b"A new message"))
		sender.mock_send.assert_awaited_with(ReplaceBody(b"A new message"))

	async def test_write_in_iter_context(self) -> None:
		"""
@@ -114,4 +114,4 @@ class HeaderAccessorTests(AsyncTestCase):
			await trio.testing.wait_all_tasks_blocked()
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ReplaceBody(b"A new message"))
		sender.mock_send.assert_awaited_with(ReplaceBody(b"A new message"))
+8 −8
Original line number Diff line number Diff line
@@ -186,7 +186,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ChangeHeader(2, "Spam", b""))
		sender.mock_send.assert_awaited_with(ChangeHeader(2, "Spam", b""))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Eggs", b"and spam"),
@@ -218,7 +218,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ChangeHeader(2, "Spam", b"no spam!"))
		sender.mock_send.assert_awaited_with(ChangeHeader(2, "Spam", b"no spam!"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Spam", b"no spam!"),
@@ -249,7 +249,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(InsertHeader(1, "Ham", b"and eggs"))
		sender.mock_send.assert_awaited_with(InsertHeader(1, "Ham", b"and eggs"))
		assert result == [
			Header("Ham", b"and eggs"),
			Header("Spam", b"spam spam spam"),
@@ -279,7 +279,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(AddHeader("Ham", b"and eggs"))
		sender.mock_send.assert_awaited_with(AddHeader("Ham", b"and eggs"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Eggs", b"and spam"),
@@ -312,7 +312,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(InsertHeader(2, "Ham", b"and eggs"))
		sender.mock_send.assert_awaited_with(InsertHeader(2, "Ham", b"and eggs"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Ham", b"and eggs"),
@@ -345,7 +345,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(InsertHeader(2, "Ham", b"and eggs"))
		sender.mock_send.assert_awaited_with(InsertHeader(2, "Ham", b"and eggs"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Ham", b"and eggs"),
@@ -378,7 +378,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(AddHeader("Ham", b"and eggs"))
		sender.mock_send.assert_awaited_with(AddHeader("Ham", b"and eggs"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Eggs", b"and spam"),
@@ -415,7 +415,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_has_awaits([
		sender.mock_send.assert_has_awaits([
			call(InsertHeader(2, "Ham", b"and eggs")),
			call(InsertHeader(3, "Ham", b"and spam")),
		])
Loading