Commit 36471d3d authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Invert responsibility of locking Broadcast in Session

This adds the requirement that filters are run in a Session context.

Closes #5
parent 30f6a85d
Loading
Loading
Loading
Loading
+63 −70
Original line number Diff line number Diff line
@@ -231,9 +231,11 @@ class Session:
		self.phase = Phase.CONNECT

	async def __aenter__(self: Self) -> Self:
		await self._broadcast.__aenter__()
		return self

	async def __aexit__(self, *_: object) -> None:
		await self._broadcast.__aexit__(None, None, None)
		# on session close, wake up any remaining deliver() awaitables
		await self._broadcast.aclose()

@@ -248,15 +250,17 @@ class Session:
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
			case Helo():
				self.phase = Phase.MAIL
				phase = Phase.MAIL
			case EnvelopeFrom() | EnvelopeRecipient() | Unknown():
				self.phase = Phase.ENVELOPE
				phase = Phase.ENVELOPE
			case Data() | Header():
				self.phase = Phase.HEADERS
				phase = Phase.HEADERS
			case EndOfHeaders() | Body():
				self.phase = Phase.BODY
				phase = Phase.BODY
			case EndOfMessage():  # pragma: no-branch
				self.phase = Phase.POST
				phase = Phase.POST
		async with self._broadcast:
			self.phase = phase  # phase attribute must be modified in locked context
		await self._broadcast.send(message)
		return Skip if self.phase == Phase.BODY and self.skip else Continue

@@ -269,7 +273,6 @@ class Session:
				"Session.helo() must be awaited before any other async features of a "
				"Session",
			)
		async with self._broadcast:
		while self.phase <= Phase.CONNECT:
			message = await self._broadcast.receive()
			if isinstance(message, Helo):
@@ -287,7 +290,6 @@ class Session:
			raise RuntimeError(
				"Session.envelope_from() may only be awaited before the ENVELOPE phase",
			)
		async with self._broadcast:
		while self.phase <= Phase.MAIL:
			message = await self._broadcast.receive()
			if isinstance(message, EnvelopeFrom):
@@ -305,7 +307,6 @@ class Session:
			raise RuntimeError(
				"Session.envelope_from() may only be awaited before the HEADERS phase",
			)
		async with self._broadcast:
		while self.phase <= Phase.ENVELOPE:
			message = await self._broadcast.receive()
			if isinstance(message, EnvelopeRecipient):
@@ -319,7 +320,6 @@ class Session:
			raise RuntimeError(
				"Session.extension() may only be awaited before the HEADERS phase",
			)
		async with self._broadcast:
		while self.phase <= Phase.ENVELOPE:
			message = await self._broadcast.receive()
			match message:
@@ -383,7 +383,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		await self._aiter.aclose()

	async def __aiter(self) -> AsyncGenerator[Header, None]:
		async with self.session._broadcast:
		# yield from cached headers first; allows multiple tasks to access the headers
		# in an uncoordinated manner; note the broadcaster is locked at this point
		for header in self._table:
@@ -395,7 +394,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
					try:
						yield header
					except GeneratorExit:
							await self._collect()
						await self.collect()
						raise
				case EndOfHeaders():
					return
@@ -407,10 +406,6 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		Calling this method before the `Phase.BODY` phase allows later processing of headers
		(after the HEADER phase) without the need for an empty loop.
		"""
		async with self.session._broadcast:
			await self._collect()

	async def _collect(self) -> None:
		# note the similarities between this and __aiter; the difference is no mutex or
		# yields
		while self.session.phase <= Phase.HEADERS:
@@ -530,7 +525,6 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
		await self._aiter.aclose()

	async def __aiter(self) -> AsyncGenerator[memoryview, None]:
		async with self.session._broadcast:
		while self.session.phase <= Phase.BODY:
			match (await self.session._broadcast.receive()):
				case Body() as body:
@@ -554,6 +548,5 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
async def _until_editable(session: Session) -> None:
	if session.phase == Phase.POST:
		return
	async with session._broadcast:
	while session.phase < Phase.POST:
		await session._broadcast.receive()