Commit ba0b5eb0 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Warn if Session.body.write called in async iter context

A warning will be issued in the following situation:

```
>>> def filter(session: Session):
...   async with session.body as body_iter:
...     await session.body.write(b"spam")

>>> test_runner(filter)
UserWarning: it looks as if BodyAccessor.write() was called on an
instance from within it's own async context
  await session.body.write(b"spam")
```
parent c3bf8f3c
Loading
Loading
Loading
Loading
+18 −2
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ from typing import Protocol
from typing import TypeAlias
from typing import TypeVar
from typing import Union
from warnings import warn

from ..protocol.messages import *
from . import util
@@ -526,13 +527,17 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
		self.session = session
		self._editor = sender
		self.skip = False
		self._aiter: AsyncGenerator[memoryview, None] | None = None

	async def __aenter__(self) -> AsyncIterator[memoryview]:
		if self._aiter is None:
			self._aiter = self.__aiter()
		return self._aiter

	async def __aexit__(self, *_: object) -> None:
		assert self._aiter is not None
		await self._aiter.aclose()
		self._aiter = None

	async def __aiter(self) -> AsyncGenerator[memoryview, None]:
		while self.session.phase <= Phase.BODY:
@@ -550,7 +555,17 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
	async def write(self, chunk: bytes) -> None:
		"""
		Request that chunks of a new message body are sent to the MTA

		This method should not be called from within the scope created by using it's
		instance as an async context (`async with`); doing so may cause a warning to be
		issued and the rest of the message body to be skipped.
		"""
		if self._aiter is not None and not self.skip:
			warn(
				"it looks as if BodyAccessor.write() was called on an instance from within "
				"it's own async context",
				stacklevel=2,
			)
		await _until_editable(self.session)
		await self._editor.asend(ReplaceBody(chunk))

@@ -558,5 +573,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
async def _until_editable(session: Session) -> None:
	if session.phase == Phase.POST:
		return
	session.body.skip = True
	while session.phase < Phase.POST:
		await session.broadcast.receive()
+30 −0
Original line number Diff line number Diff line
from ipaddress import IPv4Address
from pathlib import Path

import trio.testing

@@ -10,6 +11,7 @@ from .mock_editor import MockEditor
from .util_session import with_session

LOCALHOST = IPv4Address("127.0.0.1")
THIS_MODULE = Path(__file__)


class HeaderAccessorTests(AsyncTestCase):
@@ -99,3 +101,31 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ReplaceBody(b"A new message"))

	async def test_write_in_iter_context(self) -> None:
		"""
		Check that `write()` in an async with context issues a warning
		"""
		sender = MockEditor()
		session = Session(Connect("example.com", LOCALHOST, 1025), sender)

		# Temporary hack for missing equality check in kilter.protocol
		def _eq(s: ReplaceBody, o: object) -> bool:
			if not isinstance(o, type(s)):
				return NotImplemented
			return s.content == o.content
		ReplaceBody.__eq__ = _eq  # type: ignore

		@with_session(session)
		async def test_filter() -> None:
			async with session.body:
				with self.assertWarns(UserWarning) as cm:
					await session.body.write(b"A new message")
				assert THIS_MODULE.samefile(cm.filename)

		async with trio.open_nursery() as tg:
			tg.start_soon(test_filter)
			await trio.testing.wait_all_tasks_blocked()
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ReplaceBody(b"A new message"))