Coverage for kilter/service/session.py: 96.96%
317 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 13:26 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 13:26 +0000
1# Copyright 2022-2025 Dominik Sekotill <dom.sekotill@kodo.org.uk>
2#
3# This Source Code Form is subject to the terms of the Mozilla Public
4# License, v. 2.0. If a copy of the MPL was not distributed with this
5# file, You can obtain one at http://mozilla.org/MPL/2.0/.
7"""
8Sessions are the kernel of a filter, providing it with an async API to access messages
9"""
11from __future__ import annotations
13from collections.abc import AsyncGenerator
14from collections.abc import AsyncIterator
15from collections.abc import Sequence
16from contextvars import ContextVar
17from dataclasses import dataclass
18from enum import Enum
19from ipaddress import IPv4Address
20from ipaddress import IPv6Address
21from pathlib import Path
22from types import TracebackType
23from typing import AsyncContextManager
24from typing import Literal
25from typing import Protocol
26from typing import TypeAlias
27from warnings import warn
29from typing_extensions import Self
31from ..protocol.core import EditMessage
32from ..protocol.core import EventMessage
33from ..protocol.messages import *
34from . import util
# The response types a `Filter` may return once it has finished with a message
FilterResponse: TypeAlias = (
	Accept | Reject | Discard | ReplyCode
)
class Aborted(BaseException):
	"""
	Raised to abort a running filter when the MTA sends an Abort message

	It derives from `BaseException` rather than `Exception`, so ordinary
	``except Exception`` handlers in filter code do not swallow it.
	"""
class Filter(Protocol):
	"""
	Structural type for filters: async callables taking a `Session` and returning a response

	The session is passed positionally, and the awaited result is one of the
	`FilterResponse` types.
	"""

	async def __call__(self, session: Session, /) -> FilterResponse: ...  # noqa: D102
class Sender(Protocol):
	"""
	Structural type for objects with an async "send" method for edit messages

	`Session` and its accessor objects hand their edit messages (for example a
	`ChangeSender` or a `ReplaceBody`) to a sender, which delivers them to the MTA.
	"""

	async def send(self, message: EditMessage) -> None: ...  # noqa: D102
class Phase(int, Enum):
	"""
	Stages of a session, determining which messages are expected next

	Messages received from the MTA advance a session through these phases in order.
	Users rarely need these values directly, but knowing the progression helps to make
	sense of some of the exceptions raised by `Session` methods.
	"""

	INIT = 0
	"""
	The phase before the client connection is known; it has always finished by the time
	users are given the session object.
	"""

	CONNECT = 1
	"""
	The first phase users see, during which a HELO/EHLO message may be awaited with
	`Session.helo()`.
	"""

	MAIL = 2
	"""
	Entered after HELO/EHLO; a MAIL message may be awaited with `Session.envelope_from()`.
	`Session.extension()` may also be used to obtain the raw MAIL command with any
	extension arguments, or any other extension command the MTA does not support (when
	the MTA passes such commands on to filters).
	"""

	ENVELOPE = 3
	"""
	Entered after MAIL; RCPT commands may be iterated with
	`Session.envelope_recipients()`. `Session.extension()` may also be used to obtain the
	raw RCPT command with any extension arguments, or any other extension command the
	MTA does not support (when the MTA passes such commands on to filters).
	"""

	HEADERS = 4
	"""
	Entered after a DATA command, while the message headers are processed. Headers may
	be iterated as they arrive, or collected for later use through `Session.headers`.
	"""

	BODY = 5
	"""
	Entered once the message headers have been processed. The raw message body may be
	iterated in chunks through `Session.body`.
	"""

	POST = 6
	"""
	Entered once the message body has been completed (or skipped). In this phase the
	message editing methods of `Session`, `Session.headers` and `Session.body` may be
	used.
	"""
@dataclass
class Position:
	"""
	Common base of `Before` and `After`; not meant to be instantiated directly by users

	The module-level `START` and `END` markers are the only intended direct instances.
	"""

	subject: Header|Literal["start"]|Literal["end"]


@dataclass
class Before(Position):
	"""
	A position immediately ahead of the subject `Header` in a header list

	Used with `HeadersAccessor.insert`.
	"""

	subject: Header


@dataclass
class After(Position):
	"""
	A position immediately behind the subject `Header` in a header list

	Used with `HeadersAccessor.insert`.
	"""

	subject: Header


START = Position("start")
"""
The beginning of a header list, ahead of the first (current) header
"""

END = Position("end")
"""
The end of a header list, behind the last (current) header
"""
163class Session:
164 """
165 The kernel of a filter, providing an API for filters to access messages from an MTA
166 """
168 host: str
169 """
170 A hostname from a reverse address lookup performed when a client connects
172 If no name is found this value defaults to the standard presentation format for
173 `Session.address` surrounded by "[" and "]", e.g. "[192.0.2.100]"
174 """
176 address: IPv4Address|IPv6Address|Path|None
177 """
178 The address of the connected client, or None if unknown
179 """
181 port: int
182 """
183 The port of the connected client if applicable, or 0 otherwise
184 """
186 macros: dict[str, str]
187 """
188 A mapping of string replacements sent by the MTA
190 See `smfi_getsymval <https://pythonhosted.org/pymilter/milter_api/smfi_getsymval.html>`_
191 from `libmilter` for more information.
193 Warning:
194 The current implementation is very naïve and does not behave exactly like
195 `libmilter`, nor is it very robust. It will definitely change in the future.
196 """
198 headers: HeadersAccessor
199 """
200 A `HeadersAccessor` object for accessing and modifying the message header fields
201 """
203 body: BodyAccessor
204 """
205 A `BodyAccessor` object for accessing and modifying the message body
206 """
208 def __init__(
209 self,
210 sender: Sender,
211 broadcast: util.Broadcast[EventMessage]|None = None,
212 ):
213 self.host = ""
214 self.address = None
215 self.port = 0
217 self.sender = sender
218 self.broadcast = broadcast or util.Broadcast[EventMessage]()
220 self.macros = dict[str, str]()
221 self.headers = HeadersAccessor(self, sender)
222 self.body = BodyAccessor(self, sender)
224 # Phase checking is a bit fuzzy as a filter may not request every message,
225 # so some phases will be skipped; checks should not try to exactly match a phase.
226 self.phase = Phase.INIT
228 self._helo: Helo|None = None
230 async def __aenter__(self) -> Self:
231 await self.broadcast.__aenter__()
232 return self
234 async def __aexit__(self, *_: object) -> None:
235 await self.broadcast.__aexit__(None, None, None)
236 # on session close, wake up any remaining deliver() awaitables
237 await self.broadcast.shutdown_hook()
239 def _reset(self) -> None:
240 self.headers = HeadersAccessor(self, self.sender)
241 self.body = BodyAccessor(self, self.sender)
243 async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]:
244 """
245 Deliver a message (or its contents) to a task waiting for it
246 """
247 match message:
248 case Connect():
249 self.host = message.hostname
250 self.address = message.address
251 self.port = message.port
252 async with self.broadcast:
253 self.phase = Phase.CONNECT
254 return Continue
255 case Macro():
256 self.macros.update(message.macros)
257 return Continue # not strictly necessary, but type checker needs something
258 case Abort():
259 async with self.broadcast:
260 self.phase = Phase.CONNECT
261 await self.broadcast.abort(Aborted)
262 self._reset()
263 return Continue
264 case Helo():
265 phase = Phase.MAIL
266 case EnvelopeFrom() | EnvelopeRecipient() | Unknown():
267 phase = Phase.ENVELOPE
268 case Data() | Header():
269 phase = Phase.HEADERS
270 case EndOfHeaders() | Body():
271 phase = Phase.BODY
272 case EndOfMessage(): # pragma: no-branch
273 phase = Phase.POST
274 async with self.broadcast:
275 self.phase = phase # phase attribute must be modified in locked context
276 await self.broadcast.send(message)
277 return Skip if self.phase == Phase.BODY and self.body.should_skip() else Continue
279 async def helo(self) -> str:
280 """
281 Wait for a HELO/EHLO message and return the client's claimed hostname
282 """
283 if self.phase > Phase.CONNECT:
284 raise RuntimeError(
285 "Session.helo() must be awaited before any other async features of a "
286 "Session",
287 )
288 if self._helo:
289 return self._helo.hostname
290 while self.phase <= Phase.CONNECT:
291 message = await self.broadcast.receive()
292 if isinstance(message, Helo):
293 self._helo = message
294 return message.hostname
295 raise RuntimeError("HELO/EHLO event not received")
297 async def envelope_from(self) -> str:
298 """
299 Wait for a MAIL command message and return the sender identity
301 Note that if extensions arguments are wanted, users should use `Session.extension()`
302 instead with a name of ``"MAIL"``.
303 """
304 if self.phase > Phase.MAIL:
305 raise RuntimeError(
306 "Session.envelope_from() may only be awaited before the ENVELOPE phase",
307 )
308 while self.phase <= Phase.MAIL:
309 message = await self.broadcast.receive()
310 if isinstance(message, EnvelopeFrom):
311 return bytes(message.sender).decode()
312 raise RuntimeError("MAIL event not received")
314 async def envelope_recipients(self) -> AsyncIterator[str]:
315 """
316 Wait for RCPT command messages and iteratively yield the recipients' identities
318 Note that if extensions arguments are wanted, users should use `Session.extension()`
319 instead with a name of ``"RCPT"``.
320 """
321 if self.phase > Phase.ENVELOPE:
322 raise RuntimeError(
323 "Session.envelope_from() may only be awaited before the HEADERS phase",
324 )
325 while self.phase <= Phase.ENVELOPE:
326 message = await self.broadcast.receive()
327 if isinstance(message, EnvelopeRecipient):
328 yield bytes(message.recipient).decode()
330 async def extension(self, name: str) -> memoryview:
331 """
332 Wait for the named command extension and return the raw command for processing
333 """
334 if self.phase > Phase.ENVELOPE:
335 raise RuntimeError(
336 "Session.extension() may only be awaited before the HEADERS phase",
337 )
338 bname = name.encode("utf-8")
339 while self.phase <= Phase.ENVELOPE:
340 message = await self.broadcast.receive()
341 match message:
342 case Unknown():
343 if message.content[:len(bname)] == bname:
344 assert isinstance(message.content, memoryview)
345 return message.content
346 # fake buffers for MAIL and RCPT commands
347 case EnvelopeFrom() if name == "MAIL":
348 vals = [b"MAIL FROM", message.sender, *message.arguments]
349 return memoryview(b" ".join(vals))
350 case EnvelopeRecipient() if name == "RCPT":
351 vals = [b"RCPT TO", message.recipient, *message.arguments]
352 return memoryview(b" ".join(vals))
353 raise RuntimeError(f"{name} event not received")
355 async def change_sender(self, sender: str, args: str = "") -> None:
356 """
357 Move onto the `Phase.POST` phase and instruct the MTA to change the sender address
358 """
359 await _until_editable(self)
360 await self.sender.send(ChangeSender(sender, args or None))
362 async def add_recipient(self, recipient: str, args: str = "") -> None:
363 """
364 Move onto the `Phase.POST` phase and instruct the MTA to add a new recipient address
365 """
366 await _until_editable(self)
367 await self.sender.send(
368 AddRecipientPar(recipient, args) if args else AddRecipient(recipient),
369 )
371 async def remove_recipient(self, recipient: str) -> None:
372 """
373 Move onto the `Phase.POST` phase and instruct the MTA to remove a recipient address
374 """
375 await _until_editable(self)
376 await self.sender.send(RemoveRecipient(recipient))
379class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
380 """
381 A class that allows access and modification of the message headers sent from an MTA
383 To access headers (which are only available iteratively), use an instance as an
384 asynchronous context manager; a `HeaderIterator` is returned when the context is
385 entered.
386 """
388 def __init__(self, session: Session, sender: Sender):
389 self.session = session
390 self.sender = sender
391 self._table = list[Header]()
392 self._aiter = ContextVar[HeaderIterator|None]("header-iter")
394 async def __aenter__(self) -> HeaderIterator:
395 if not (aiter := self._aiter.get(None)): 395 ↛ 398line 395 didn't jump to line 398 because the condition on line 395 was always true
396 aiter = HeaderIterator(self.__aiter())
397 self._aiter.set(aiter)
398 return aiter
400 async def __aexit__(self, *_: object) -> None:
401 if aiter := self._aiter.get(): 401 ↛ 403line 401 didn't jump to line 403 because the condition on line 401 was always true
402 await aiter.aclose()
403 self._aiter.set(None)
405 async def __aiter(self) -> AsyncGenerator[Header, None]:
406 # yield from cached headers first; allows multiple tasks to access the headers
407 # in an uncoordinated manner; note the broadcaster is locked at this point
408 for header in self._table:
409 yield header
410 seen = set(id(header) for header in self._table)
411 while self.session.phase <= Phase.HEADERS:
412 match (await self.session.broadcast.receive()):
413 case Header() as header:
414 header.freeze()
415 self._table.append(header)
416 seen.add(id(header))
417 try:
418 yield header
419 except GeneratorExit:
420 await self.collect()
421 raise
422 case EndOfHeaders():
423 return
424 # It's possible for collect() to have been called while yielded, in which case the
425 # loop will end. Yield any headers that were stored by collect() but not yet
426 # yielded.
427 for header in self._table:
428 if id(header) not in seen: 428 ↛ 429line 428 didn't jump to line 429 because the condition on line 428 was never true
429 yield header
431 async def collect(self) -> None:
432 """
433 Collect all headers without producing an iterator
435 Calling this method before the `Phase.BODY` phase allows later processing of headers
436 (after the HEADER phase) without the need for an empty loop.
437 """
438 # note the similarities between this and __aiter; the difference is no mutex or
439 # yields
440 while self.session.phase <= Phase.HEADERS:
441 match (await self.session.broadcast.receive()):
442 case Header() as header:
443 header.freeze()
444 self._table.append(header)
445 case _:
446 return
448 async def delete(self, header: Header) -> None:
449 """
450 Move onto the `Phase.POST` phase and Instruct the MTA to delete the given header
451 """
452 await self.collect()
453 await _until_editable(self.session)
454 index = _index_by_name(self._table, header)
455 await self.sender.send(ChangeHeader(index, header.name, b""))
456 self._table.remove(header)
458 async def update(self, header: Header, value: bytes) -> None:
459 """
460 Move onto the `Phase.POST` phase and Instruct the MTA to modify the value of a header
461 """
462 await self.collect()
463 await _until_editable(self.session)
464 index = _index_by_name(self._table, header)
465 await self.sender.send(ChangeHeader(index, header.name, value))
466 index = self._table.index(header)
467 self._table[index].value = value
469 async def insert(self, header: Header, position: Position) -> None:
470 """
471 Move onto the `Phase.POST` phase and instruct the MTA to insert a new header
473 The header is inserted at `START`, `END`, or a relative position with `Before` and
474 `After`; for example ``Before(Header("To", "test@example.com"))``.
475 """
476 await self.collect()
477 await _until_editable(self.session)
478 match position:
479 case Position(subject="start"):
480 index = 0
481 case Position(subject="end"):
482 index = len(self._table)
483 case Before():
484 index = self._table.index(position.subject)
485 case After(): # pragma: no-branch
486 index = self._table.index(position.subject) + 1
487 case _:
488 raise TypeError("Expect a Position")
489 if index >= len(self._table):
490 await self.sender.send(AddHeader(header.name, header.value))
491 self._table.append(header)
492 else:
493 await self.sender.send(InsertHeader(index + 1, header.name, header.value))
494 self._table.insert(index, header)
class HeaderIterator(AsyncGenerator[Header, None]):
	"""
	Iterator for headers obtained by using a `HeadersAccessor` as a context manager

	Generator operations are forwarded to the wrapped asynchronous generator.
	"""

	def __init__(self, aiter: AsyncGenerator[Header, None]):
		self._aiter = aiter

	def __aiter__(self) -> Self:
		return self

	async def __anext__(self) -> Header:  # noqa: D102
		return await self._aiter.__anext__()

	async def asend(self, value: None = None) -> Header:  # noqa: D102
		# the only value that can be sent is None, which is equivalent to advancing
		return await self._aiter.__anext__()

	async def athrow(  # noqa: D102
		self,
		e: type[BaseException]|BaseException,
		m: object = None,
		t: TracebackType|None = None, /,
	) -> Header:
		# Forward only the arguments actually supplied: the (type, value, traceback)
		# form of athrow() is deprecated since Python 3.12 and warns whenever more than
		# one argument is passed, even if the extras are None.
		if isinstance(e, type):
			if m is None and t is None:
				return await self._aiter.athrow(e)
			return await self._aiter.athrow(e, m, t)
		assert m is None, "an exception instance cannot be thrown with a separate value"
		if t is None:
			return await self._aiter.athrow(e)
		return await self._aiter.athrow(e, None, t)

	async def aclose(self) -> None:  # noqa: D102
		await self._aiter.aclose()

	async def restrict(self, *names: str) -> AsyncIterator[Header]:
		"""
		Return an asynchronous generator that filters headers by name

		Names are compared exactly (case-sensitively) with `Header.name`.
		"""
		async for header in self._aiter:
			if header.name in names:
				yield header
537class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]):
538 """
539 A class that allows access and modification of the message body sent from an MTA
541 To access chunks of a body (which are only available iteratively), use an instance as an
542 asynchronous context manager; an asynchronous iterator is returned when the context is
543 entered.
544 """
546 def __init__(self, session: Session, sender: Sender):
547 self.session = session
548 self.sender = sender
549 self._entered = 0
550 self._skip = False
551 self._aiter = ContextVar[AsyncGenerator[memoryview, None] | None]("body-iter")
553 async def __aenter__(self) -> AsyncIterator[memoryview]:
554 if not (aiter := self._aiter.get(None)): 554 ↛ 557line 554 didn't jump to line 557 because the condition on line 554 was always true
555 aiter = self.__aiter()
556 self._aiter.set(aiter)
557 self._entered += 1
558 return aiter
560 async def __aexit__(self, *_: object) -> None:
561 if aiter := self._aiter.get(None): 561 ↛ 563line 561 didn't jump to line 563 because the condition on line 561 was always true
562 await aiter.aclose()
563 self._aiter.set(None)
564 self._entered -= 1
566 async def __aiter(self) -> AsyncGenerator[memoryview, None]:
567 while self.session.phase <= Phase.BODY:
568 match (await self.session.broadcast.receive()):
569 case Body() as body:
570 assert isinstance(body.content, memoryview)
571 yield body.content
572 case EndOfMessage() as eom:
573 assert isinstance(eom.content, memoryview)
574 yield eom.content
576 def should_skip(self) -> bool:
577 """
578 Return whether the message body should be skipped
580 The body should be skipped when there are no active contexts. All correctly
581 implemented filters should have started a context before the first `Body` message.
583 Once this method returns `True` it becomes "locked in" and will always return `True`
584 after.
585 """
586 if self._skip:
587 return True
588 self._skip = self._entered == 0
589 return self._skip
591 async def write(self, chunk: bytes) -> None:
592 """
593 Request that chunks of a new message body are sent to the MTA
595 This method should not be called from within the scope created by using it's
596 instance as an async context (`async with`); doing so may cause a warning to be
597 issued and the rest of the message body to be skipped.
598 """
599 if self._aiter.get(None):
600 warn(
601 "it looks as if BodyAccessor.write() was called on an instance from within "
602 "it's own async context",
603 stacklevel=2,
604 )
605 await _until_editable(self.session)
606 await self.sender.send(ReplaceBody(chunk))
609async def _until_editable(session: Session) -> None:
610 if session.phase == Phase.POST:
611 return
612 while session.phase < Phase.POST:
613 if session.phase == Phase.HEADERS: 613 ↛ 614line 613 didn't jump to line 614 because the condition on line 613 was never true
614 await session.headers.collect()
615 else:
616 await session.broadcast.receive()
619def _index_by_name(table: Sequence[Header], needle: Header) -> int:
620 index = 0
621 name = needle.name.lower()
622 for header in table: 622 ↛ 627line 622 didn't jump to line 627 because the loop on line 622 didn't complete
623 if header == needle:
624 return index + 1
625 if header.name.lower() == name: 625 ↛ 622line 625 didn't jump to line 622 because the condition on line 625 was always true
626 index += 1
627 raise ValueError(f"header not found: {needle}")