Loading kilter/service/session.py +63 −70 Original line number Diff line number Diff line Loading @@ -231,9 +231,11 @@ class Session: self.phase = Phase.CONNECT async def __aenter__(self: Self) -> Self: await self._broadcast.__aenter__() return self async def __aexit__(self, *_: object) -> None: await self._broadcast.__aexit__(None, None, None) # on session close, wake up any remaining deliver() awaitables await self._broadcast.aclose() Loading @@ -248,15 +250,17 @@ class Session: self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Helo(): self.phase = Phase.MAIL phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): self.phase = Phase.ENVELOPE phase = Phase.ENVELOPE case Data() | Header(): self.phase = Phase.HEADERS phase = Phase.HEADERS case EndOfHeaders() | Body(): self.phase = Phase.BODY phase = Phase.BODY case EndOfMessage(): # pragma: no-branch self.phase = Phase.POST phase = Phase.POST async with self._broadcast: self.phase = phase # phase attribute must be modified in locked context await self._broadcast.send(message) return Skip if self.phase == Phase.BODY and self.skip else Continue Loading @@ -269,7 +273,6 @@ class Session: "Session.helo() must be awaited before any other async features of a " "Session", ) async with self._broadcast: while self.phase <= Phase.CONNECT: message = await self._broadcast.receive() if isinstance(message, Helo): Loading @@ -287,7 +290,6 @@ class Session: raise RuntimeError( "Session.envelope_from() may only be awaited before the ENVELOPE phase", ) async with self._broadcast: while self.phase <= Phase.MAIL: message = await self._broadcast.receive() if isinstance(message, EnvelopeFrom): Loading @@ -305,7 +307,6 @@ class Session: raise RuntimeError( "Session.envelope_from() may only be awaited before the HEADERS phase", ) async with self._broadcast: while self.phase <= Phase.ENVELOPE: message = await self._broadcast.receive() if isinstance(message, EnvelopeRecipient): Loading @@ -319,7 +320,6 @@ class Session: raise RuntimeError( "Session.extension() may only be awaited before the HEADERS phase", ) async with self._broadcast: while self.phase <= Phase.ENVELOPE: message = await self._broadcast.receive() match message: Loading Loading @@ -383,7 +383,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): await self._aiter.aclose() async def __aiter(self) -> AsyncGenerator[Header, None]: async with self.session._broadcast: # yield from cached headers first; allows multiple tasks to access the headers # in an uncoordinated manner; note the broadcaster is locked at this point for header in self._table: Loading @@ -395,7 +394,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): try: yield header except GeneratorExit: await self._collect() await self.collect() raise case EndOfHeaders(): return Loading @@ -407,10 +406,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): Calling this method before the `Phase.BODY` phase allows later processing of headers (after the HEADER phase) without the need for an empty loop. """ async with self.session._broadcast: await self._collect() async def _collect(self) -> None: # note the similarities between this and __aiter; the difference is no mutex or # yields while self.session.phase <= Phase.HEADERS: Loading Loading @@ -530,7 +525,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): await self._aiter.aclose() async def __aiter(self) -> AsyncGenerator[memoryview, None]: async with self.session._broadcast: while self.session.phase <= Phase.BODY: match (await self.session._broadcast.receive()): case Body() as body: Loading @@ -554,6 +548,5 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): async def _until_editable(session: Session) -> None: if session.phase == Phase.POST: return async with session._broadcast: while session.phase < Phase.POST: await session._broadcast.receive() Loading
kilter/service/session.py +63 −70 Original line number Diff line number Diff line Loading @@ -231,9 +231,11 @@ class Session: self.phase = Phase.CONNECT async def __aenter__(self: Self) -> Self: await self._broadcast.__aenter__() return self async def __aexit__(self, *_: object) -> None: await self._broadcast.__aexit__(None, None, None) # on session close, wake up any remaining deliver() awaitables await self._broadcast.aclose() Loading @@ -248,15 +250,17 @@ class Session: self.macros.update(message.macros) return Continue # not strictly necessary, but type checker needs something case Helo(): self.phase = Phase.MAIL phase = Phase.MAIL case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): self.phase = Phase.ENVELOPE phase = Phase.ENVELOPE case Data() | Header(): self.phase = Phase.HEADERS phase = Phase.HEADERS case EndOfHeaders() | Body(): self.phase = Phase.BODY phase = Phase.BODY case EndOfMessage(): # pragma: no-branch self.phase = Phase.POST phase = Phase.POST async with self._broadcast: self.phase = phase # phase attribute must be modified in locked context await self._broadcast.send(message) return Skip if self.phase == Phase.BODY and self.skip else Continue Loading @@ -269,7 +273,6 @@ class Session: "Session.helo() must be awaited before any other async features of a " "Session", ) async with self._broadcast: while self.phase <= Phase.CONNECT: message = await self._broadcast.receive() if isinstance(message, Helo): Loading @@ -287,7 +290,6 @@ class Session: raise RuntimeError( "Session.envelope_from() may only be awaited before the ENVELOPE phase", ) async with self._broadcast: while self.phase <= Phase.MAIL: message = await self._broadcast.receive() if isinstance(message, EnvelopeFrom): Loading @@ -305,7 +307,6 @@ class Session: raise RuntimeError( "Session.envelope_from() may only be awaited before the HEADERS phase", ) async with self._broadcast: while self.phase <= Phase.ENVELOPE: message = await self._broadcast.receive() if isinstance(message, EnvelopeRecipient): Loading @@ -319,7 +320,6 @@ class Session: raise RuntimeError( "Session.extension() may only be awaited before the HEADERS phase", ) async with self._broadcast: while self.phase <= Phase.ENVELOPE: message = await self._broadcast.receive() match message: Loading Loading @@ -383,7 +383,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): await self._aiter.aclose() async def __aiter(self) -> AsyncGenerator[Header, None]: async with self.session._broadcast: # yield from cached headers first; allows multiple tasks to access the headers # in an uncoordinated manner; note the broadcaster is locked at this point for header in self._table: Loading @@ -395,7 +394,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): try: yield header except GeneratorExit: await self._collect() await self.collect() raise case EndOfHeaders(): return Loading @@ -407,10 +406,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]): Calling this method before the `Phase.BODY` phase allows later processing of headers (after the HEADER phase) without the need for an empty loop. """ async with self.session._broadcast: await self._collect() async def _collect(self) -> None: # note the similarities between this and __aiter; the difference is no mutex or # yields while self.session.phase <= Phase.HEADERS: Loading Loading @@ -530,7 +525,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): await self._aiter.aclose() async def __aiter(self) -> AsyncGenerator[memoryview, None]: async with self.session._broadcast: while self.session.phase <= Phase.BODY: match (await self.session._broadcast.receive()): case Body() as body: Loading @@ -554,6 +548,5 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): async def _until_editable(session: Session) -> None: if session.phase == Phase.POST: return async with session._broadcast: while session.phase < Phase.POST: await session._broadcast.receive()