Loading kilter/service/session.py +43 −27 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator from collections.abc import AsyncIterator from collections.abc import Sequence from contextvars import ContextVar from dataclasses import dataclass from enum import Enum from ipaddress import IPv4Address Loading Loading @@ -231,8 +232,6 @@ class Session: Deliver a message (or its contents) to a task waiting for it """ match message: case Body() if self.body.skip: return Skip case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something Loading @@ -254,7 +253,7 @@ class Session: async with self.broadcast: self.phase = phase # phase attribute must be modified in locked context await self.broadcast.send(message) return Skip if self.phase == Phase.BODY and self.body.skip else Continue return Skip if self.phase == Phase.BODY and self.body.should_skip() else Continue async def helo(self) -> str: """ Loading Loading @@ -366,14 +365,18 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): self.session = session self.sender = sender self._table = list[Header]() self._aiter: AsyncGenerator[Header, None] self._aiter = ContextVar[HeaderIterator|None]("header-iter") async def __aenter__(self) -> HeaderIterator: self._aiter = HeaderIterator(self.__aiter()) return self._aiter if not (aiter := self._aiter.get(None)): aiter = HeaderIterator(self.__aiter()) self._aiter.set(aiter) return aiter async def __aexit__(self, *_: object) -> None: await self._aiter.aclose() if aiter := self._aiter.get(): await aiter.aclose() self._aiter.set(None) async def __aiter(self) -> AsyncGenerator[Header, None]: # yield from cached headers first; allows multiple tasks to access the headers Loading Loading @@ -415,7 +418,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): case Header() as header: header.freeze() self._table.append(header) case EndOfHeaders(): case _: return async def delete(self, header: Header) -> None: Loading Loading @@ -519,34 +522,48 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): def __init__(self, session: Session, sender: Sender): self.session = session self.sender = sender self.skip = False self._aiter: AsyncGenerator[memoryview, None] | None = None self._entered = 0 self._skip = False self._aiter = ContextVar[AsyncGenerator[memoryview, None] | None]("body-iter") async def __aenter__(self) -> AsyncIterator[memoryview]: if self._aiter is None: self._aiter = self.__aiter() return self._aiter if not (aiter := self._aiter.get(None)): aiter = self.__aiter() self._aiter.set(aiter) self._entered += 1 return aiter async def __aexit__(self, *_: object) -> None: assert self._aiter is not None await self._aiter.aclose() self._aiter = None if aiter := self._aiter.get(None): await aiter.aclose() self._aiter.set(None) self._entered -= 1 async def __aiter(self) -> AsyncGenerator[memoryview, None]: while self.session.phase <= Phase.BODY: match (await self.session.broadcast.receive()): case Body() as body: try: assert isinstance(body.content, memoryview) yield body.content except GeneratorExit: self.skip = True raise case EndOfMessage() as eom: if not self.skip: assert isinstance(eom.content, memoryview) yield eom.content def should_skip(self) -> bool: """ Return whether the message body should be skipped The body should be skipped when there are no active contexts. All correctly implemented filters should have started a context before the first `Body` message. Once this method returns `True` it becomes "locked in" and will always return `True` after. """ if self._skip: return True self._skip = self._entered == 0 return self._skip async def write(self, chunk: bytes) -> None: """ Request that chunks of a new message body are sent to the MTA Loading @@ -555,7 +572,7 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): instance as an async context (`async with`); doing so may cause a warning to be issued and the rest of the message body to be skipped. """ if self._aiter is not None and not self.skip: if self._aiter.get(None): warn( "it looks as if BodyAccessor.write() was called on an instance from within " "it's own async context", Loading @@ -568,7 +585,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): async def _until_editable(session: Session) -> None: if session.phase == Phase.POST: return session.body.skip = True while session.phase < Phase.POST: if session.phase == Phase.HEADERS: await session.headers.collect() Loading tests/test_body_accessor.py +1 −5 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ LOCALHOST = IPv4Address("127.0.0.1") THIS_MODULE = Path(__file__) class HeaderAccessorTests(AsyncTestCase): class BodyAccessorTests(AsyncTestCase): """ Tests for the kilter.service.session.HeaderAccessor class """ Loading Loading @@ -62,10 +62,6 @@ class HeaderAccessorTests(AsyncTestCase): break result1 += chunk async with session.body as body: async for chunk in body: result2 += chunk async with trio.open_nursery() as tg: tg.start_soon(test_filter) await trio.testing.wait_all_tasks_blocked() Loading Loading
kilter/service/session.py +43 −27 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator from collections.abc import AsyncIterator from collections.abc import Sequence from contextvars import ContextVar from dataclasses import dataclass from enum import Enum from ipaddress import IPv4Address Loading Loading @@ -231,8 +232,6 @@ class Session: Deliver a message (or its contents) to a task waiting for it """ match message: case Body() if self.body.skip: return Skip case Macro(): self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something Loading @@ -254,7 +253,7 @@ class Session: async with self.broadcast: self.phase = phase # phase attribute must be modified in locked context await self.broadcast.send(message) return Skip if self.phase == Phase.BODY and self.body.skip else Continue return Skip if self.phase == Phase.BODY and self.body.should_skip() else Continue async def helo(self) -> str: """ Loading Loading @@ -366,14 +365,18 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): self.session = session self.sender = sender self._table = list[Header]() self._aiter: AsyncGenerator[Header, None] self._aiter = ContextVar[HeaderIterator|None]("header-iter") async def __aenter__(self) -> HeaderIterator: self._aiter = HeaderIterator(self.__aiter()) return self._aiter if not (aiter := self._aiter.get(None)): aiter = HeaderIterator(self.__aiter()) self._aiter.set(aiter) return aiter async def __aexit__(self, *_: object) -> None: await self._aiter.aclose() if aiter := self._aiter.get(): await aiter.aclose() self._aiter.set(None) async def __aiter(self) -> AsyncGenerator[Header, None]: # yield from cached headers first; allows multiple tasks to access the headers Loading Loading @@ -415,7 +418,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): case Header() as header: header.freeze() self._table.append(header) case EndOfHeaders(): case _: return async def delete(self, header: Header) -> None: Loading Loading @@ -519,34 +522,48 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): def __init__(self, session: Session, sender: Sender): self.session = session self.sender = sender self.skip = False self._aiter: AsyncGenerator[memoryview, None] | None = None self._entered = 0 self._skip = False self._aiter = ContextVar[AsyncGenerator[memoryview, None] | None]("body-iter") async def __aenter__(self) -> AsyncIterator[memoryview]: if self._aiter is None: self._aiter = self.__aiter() return self._aiter if not (aiter := self._aiter.get(None)): aiter = self.__aiter() self._aiter.set(aiter) self._entered += 1 return aiter async def __aexit__(self, *_: object) -> None: assert self._aiter is not None await self._aiter.aclose() self._aiter = None if aiter := self._aiter.get(None): await aiter.aclose() self._aiter.set(None) self._entered -= 1 async def __aiter(self) -> AsyncGenerator[memoryview, None]: while self.session.phase <= Phase.BODY: match (await self.session.broadcast.receive()): case Body() as body: try: assert isinstance(body.content, memoryview) yield body.content except GeneratorExit: self.skip = True raise case EndOfMessage() as eom: if not self.skip: assert isinstance(eom.content, memoryview) yield eom.content def should_skip(self) -> bool: """ Return whether the message body should be skipped The body should be skipped when there are no active contexts. All correctly implemented filters should have started a context before the first `Body` message. Once this method returns `True` it becomes "locked in" and will always return `True` after. """ if self._skip: return True self._skip = self._entered == 0 return self._skip async def write(self, chunk: bytes) -> None: """ Request that chunks of a new message body are sent to the MTA Loading @@ -555,7 +572,7 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): instance as an async context (`async with`); doing so may cause a warning to be issued and the rest of the message body to be skipped. """ if self._aiter is not None and not self.skip: if self._aiter.get(None): warn( "it looks as if BodyAccessor.write() was called on an instance from within " "it's own async context", Loading @@ -568,7 +585,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): async def _until_editable(session: Session) -> None: if session.phase == Phase.POST: return session.body.skip = True while session.phase < Phase.POST: if session.phase == Phase.HEADERS: await session.headers.collect() Loading
tests/test_body_accessor.py +1 −5 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ LOCALHOST = IPv4Address("127.0.0.1") THIS_MODULE = Path(__file__) class HeaderAccessorTests(AsyncTestCase): class BodyAccessorTests(AsyncTestCase): """ Tests for the kilter.service.session.HeaderAccessor class """ Loading Loading @@ -62,10 +62,6 @@ class HeaderAccessorTests(AsyncTestCase): break result1 += chunk async with session.body as body: async for chunk in body: result2 += chunk async with trio.open_nursery() as tg: tg.start_soon(test_filter) await trio.testing.wait_all_tasks_blocked() Loading