Commit 0478f3b8 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Expose Broadcast instance of a Session as public attr

parent 8f1e9065
Loading
Loading
Loading
Loading
+16 −16
Original line number Diff line number Diff line
@@ -224,7 +224,7 @@ class Session:
		self.port = connmsg.port

		self._editor = sender
		self._broadcast = broadcast or util.Broadcast[EventMessage]()
		self.broadcast = broadcast or util.Broadcast[EventMessage]()

		self.macros = dict[str, str]()
		self.headers = HeadersAccessor(self, sender)
@@ -237,13 +237,13 @@ class Session:
		self.phase = Phase.CONNECT

	async def __aenter__(self: Self) -> Self:
		await self._broadcast.__aenter__()
		await self.broadcast.__aenter__()
		return self

	async def __aexit__(self, *_: object) -> None:
		await self._broadcast.__aexit__(None, None, None)
		await self.broadcast.__aexit__(None, None, None)
		# on session close, wake up any remaining deliver() awaitables
		await self._broadcast.shutdown_hook()
		await self.broadcast.shutdown_hook()

	async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]:
		"""
@@ -256,9 +256,9 @@ class Session:
				self.macros.update(message.macros)
				return Continue  # not strictly necessary, but type checker needs something
			case Abort():
				async with self._broadcast:
				async with self.broadcast:
					self.phase = Phase.CONNECT
				await self._broadcast.abort(Aborted)
				await self.broadcast.abort(Aborted)
				return Continue
			case Helo():
				phase = Phase.MAIL
@@ -270,9 +270,9 @@ class Session:
				phase = Phase.BODY
			case EndOfMessage():  # pragma: no-branch
				phase = Phase.POST
		async with self._broadcast:
		async with self.broadcast:
			self.phase = phase  # phase attribute must be modified in locked context
		await self._broadcast.send(message)
		await self.broadcast.send(message)
		return Skip if self.phase == Phase.BODY and self.skip else Continue

	async def helo(self) -> str:
@@ -285,7 +285,7 @@ class Session:
				"Session",
			)
		while self.phase <= Phase.CONNECT:
			message = await self._broadcast.receive()
			message = await self.broadcast.receive()
			if isinstance(message, Helo):
				return message.hostname
		raise RuntimeError("HELO/EHLO event not received")
@@ -302,7 +302,7 @@ class Session:
				"Session.envelope_from() may only be awaited before the ENVELOPE phase",
			)
		while self.phase <= Phase.MAIL:
			message = await self._broadcast.receive()
			message = await self.broadcast.receive()
			if isinstance(message, EnvelopeFrom):
				return bytes(message.sender).decode()
		raise RuntimeError("MAIL event not received")
@@ -319,7 +319,7 @@ class Session:
				"Session.envelope_from() may only be awaited before the HEADERS phase",
			)
		while self.phase <= Phase.ENVELOPE:
			message = await self._broadcast.receive()
			message = await self.broadcast.receive()
			if isinstance(message, EnvelopeRecipient):
				yield bytes(message.recipient).decode()

@@ -333,7 +333,7 @@ class Session:
			)
		bname = name.encode("utf-8")
		while self.phase <= Phase.ENVELOPE:
			message = await self._broadcast.receive()
			message = await self.broadcast.receive()
			match message:
				case Unknown():
					if message.content[:len(bname)] == bname:
@@ -399,7 +399,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		for header in self._table:
			yield header
		while self.session.phase <= Phase.HEADERS:
			match (await self.session._broadcast.receive()):
			match (await self.session.broadcast.receive()):
				case Header() as header:
					self._table.append(header)
					try:
@@ -420,7 +420,7 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		# note the similarities between this and __aiter; the difference is no mutex or
		# yields
		while self.session.phase <= Phase.HEADERS:
			match (await self.session._broadcast.receive()):
			match (await self.session.broadcast.receive()):
				case Header() as header:
					self._table.append(header)
				case EndOfHeaders():
@@ -537,7 +537,7 @@ class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):

	async def __aiter(self) -> AsyncGenerator[memoryview, None]:
		while self.session.phase <= Phase.BODY:
			match (await self.session._broadcast.receive()):
			match (await self.session.broadcast.receive()):
				case Body() as body:
					try:
						yield body.content
@@ -560,4 +560,4 @@ async def _until_editable(session: Session) -> None:
	if session.phase == Phase.POST:
		return
	while session.phase < Phase.POST:
		await session._broadcast.receive()
		await session.broadcast.receive()