Verified Commit d1deed91 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Improve multi task Session use with ContextVars

parent 46aae465
Loading
Loading
Loading
Loading
+43 −27
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from collections.abc import Sequence
from contextvars import ContextVar
from dataclasses import dataclass
from enum import Enum
from ipaddress import IPv4Address
@@ -231,8 +232,6 @@ class Session:
		Deliver a message (or its contents) to a task waiting for it
		"""
		match message:
			case Body() if self.body.skip:
				return Skip
			case Macro():
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
@@ -254,7 +253,7 @@ class Session:
		async with self.broadcast:
			self.phase = phase  # phase attribute must be modified in locked context
		await self.broadcast.send(message)
		return Skip if self.phase == Phase.BODY and self.body.skip else Continue
		return Skip if self.phase == Phase.BODY and self.body.should_skip() else Continue

	async def helo(self) -> str:
		"""
@@ -366,14 +365,18 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		self.session = session
		self.sender = sender
		self._table = list[Header]()
		self._aiter: AsyncGenerator[Header, None]
		self._aiter = ContextVar[HeaderIterator|None]("header-iter")

	async def __aenter__(self) -> HeaderIterator:
		self._aiter = HeaderIterator(self.__aiter())
		return self._aiter
		if not (aiter := self._aiter.get(None)):
			aiter = HeaderIterator(self.__aiter())
			self._aiter.set(aiter)
		return aiter

	async def __aexit__(self, *_: object) -> None:
		await self._aiter.aclose()
		if aiter := self._aiter.get():
			await aiter.aclose()
		self._aiter.set(None)

	async def __aiter(self) -> AsyncGenerator[Header, None]:
		# yield from cached headers first; allows multiple tasks to access the headers
@@ -415,7 +418,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
				case Header() as header:
					header.freeze()
					self._table.append(header)
				case EndOfHeaders():
				case _:
					return

	async def delete(self, header: Header) -> None:
@@ -519,34 +522,48 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
	def __init__(self, session: Session, sender: Sender):
		self.session = session
		self.sender = sender
		self.skip = False
		self._aiter: AsyncGenerator[memoryview, None] | None = None
		self._entered = 0
		self._skip = False
		self._aiter = ContextVar[AsyncGenerator[memoryview, None] | None]("body-iter")

	async def __aenter__(self) -> AsyncIterator[memoryview]:
		if self._aiter is None:
			self._aiter = self.__aiter()
		return self._aiter
		if not (aiter := self._aiter.get(None)):
			aiter = self.__aiter()
			self._aiter.set(aiter)
		self._entered += 1
		return aiter

	async def __aexit__(self, *_: object) -> None:
		assert self._aiter is not None
		await self._aiter.aclose()
		self._aiter = None
		if aiter := self._aiter.get(None):
			await aiter.aclose()
		self._aiter.set(None)
		self._entered -= 1

	async def __aiter(self) -> AsyncGenerator[memoryview, None]:
		while self.session.phase <= Phase.BODY:
			match (await self.session.broadcast.receive()):
				case Body() as body:
					try:
					assert isinstance(body.content, memoryview)
					yield body.content
					except GeneratorExit:
						self.skip = True
						raise
				case EndOfMessage() as eom:
					if not self.skip:
					assert isinstance(eom.content, memoryview)
					yield eom.content

	def should_skip(self) -> bool:
		"""
		Return whether the message body should be skipped

		The body should be skipped when there are no active contexts.  All correctly
		implemented filters should have started a context before the first `Body` message.

		Once this method returns `True` it becomes "locked in" and will always return `True`
		after.
		"""
		if self._skip:
			return True
		self._skip = self._entered == 0
		return self._skip

	async def write(self, chunk: bytes) -> None:
		"""
		Request that chunks of a new message body are sent to the MTA
@@ -555,7 +572,7 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
		instance as an async context (`async with`); doing so may cause a warning to be
		issued and the rest of the message body to be skipped.
		"""
		if self._aiter is not None and not self.skip:
		if self._aiter.get(None):
			warn(
				"it looks as if BodyAccessor.write() was called on an instance from within "
				"it's own async context",
@@ -568,7 +585,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
async def _until_editable(session: Session) -> None:
	if session.phase == Phase.POST:
		return
	session.body.skip = True
	while session.phase < Phase.POST:
		if session.phase == Phase.HEADERS:
			await session.headers.collect()
+1 −5
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ LOCALHOST = IPv4Address("127.0.0.1")
THIS_MODULE = Path(__file__)


class HeaderAccessorTests(AsyncTestCase):
class BodyAccessorTests(AsyncTestCase):
	"""
	Tests for the kilter.service.session.HeaderAccessor class
	"""
@@ -62,10 +62,6 @@ class HeaderAccessorTests(AsyncTestCase):
						break
					result1 += chunk

			async with session.body as body:
				async for chunk in body:
					result2 += chunk

		async with trio.open_nursery() as tg:
			tg.start_soon(test_filter)
			await trio.testing.wait_all_tasks_blocked()