Coverage for kilter/service/session.py: 96.96%

317 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-07 13:26 +0000

1# Copyright 2022-2025 Dominik Sekotill <dom.sekotill@kodo.org.uk> 

2# 

3# This Source Code Form is subject to the terms of the Mozilla Public 

4# License, v. 2.0. If a copy of the MPL was not distributed with this 

5# file, You can obtain one at http://mozilla.org/MPL/2.0/. 

6 

7""" 

8Sessions are the kernel of a filter, providing it with an async API to access messages 

9""" 

10 

11from __future__ import annotations 

12 

13from collections.abc import AsyncGenerator 

14from collections.abc import AsyncIterator 

15from collections.abc import Sequence 

16from contextvars import ContextVar 

17from dataclasses import dataclass 

18from enum import Enum 

19from ipaddress import IPv4Address 

20from ipaddress import IPv6Address 

21from pathlib import Path 

22from types import TracebackType 

23from typing import AsyncContextManager 

24from typing import Literal 

25from typing import Protocol 

26from typing import TypeAlias 

27from warnings import warn 

28 

29from typing_extensions import Self 

30 

31from ..protocol.core import EditMessage 

32from ..protocol.core import EventMessage 

33from ..protocol.messages import * 

34from . import util 

35 

# The final verdicts a `Filter` may return to conclude a session
FilterResponse: TypeAlias = (
    Accept | Reject | Discard | ReplyCode
)

37 

38 

class Aborted(BaseException):
    """
    Exception used to abort filters when an Abort message is received

    `Session.deliver` hands this class to the broadcaster's ``abort()`` on an Abort
    message.  Because it derives from `BaseException` rather than `Exception`, generic
    ``except Exception`` handlers in filters do not catch it.
    """

43 

44 

class Filter(Protocol):
    """
    Structural type for filters: async callables that take a `Session` and resolve to a
    `FilterResponse`
    """

    async def __call__(self, session: Session, /) -> FilterResponse:  # noqa: D102
        ...

51 

52 

class Sender(Protocol):
    """
    Structural type for objects that asynchronously transmit edit messages through a
    "send" method
    """

    async def send(self, message: EditMessage) -> None:  # noqa: D102
        ...

59 

60 

class Phase(int, Enum):
    """
    Session phases, which determine what messages may be expected next and which are
    advanced by the messages a session receives

    Users should not usually need these values directly.  Knowing the state flow they
    describe, however, helps explain some of the exceptions raised by `Session` methods.
    """

    INIT = 0
    """
    The pre-connection phase of a session; it is finished before users are given the
    session object.
    """

    CONNECT = 1
    """
    The initial phase of a session, in which a HELO/EHLO message may be awaited with
    `Session.helo()`.
    """

    MAIL = 2
    """
    Entered after HELO/EHLO; in this phase a MAIL message may be awaited with
    `Session.envelope_from()`.  `Session.extension()` can also be used to obtain the raw
    MAIL command including any extension arguments, or any other extension command the
    MTA does not itself support (where the MTA can pass such commands to a filter).
    """

    ENVELOPE = 3
    """
    Entered after MAIL; in this phase RCPT commands may be awaited with
    `Session.envelope_recipients()`.  `Session.extension()` can also be used to obtain
    the raw RCPT command including any extension arguments, or any other extension
    command the MTA does not itself support (where the MTA can pass such commands to
    a filter).
    """

    HEADERS = 4
    """
    Entered after a DATA command, while the message headers are being processed.
    Headers can be iterated as they arrive, or collected for later use, through the
    `Session.headers` object.
    """

    BODY = 5
    """
    Entered once a message's headers are processed.  The raw message body can be
    iterated in chunks through the `Session.body` object.
    """

    POST = 6
    """
    Entered when a message's body is complete (or was skipped).  In this phase the
    editing methods of a `Session`, and of the `Session.headers` and `Session.body`
    objects, may be used.
    """

119 

120 

@dataclass
class Position:
    """
    Common base of `Before` and `After` (and the type of `START` and `END`); not meant
    to be instantiated directly by users
    """

    subject: Header|Literal["start", "end"]

128 

129 

@dataclass
class Before(Position):
    """
    A position in a header list immediately ahead of the subject `Header`

    Intended for use with `HeadersAccessor.insert`.
    """

    subject: Header

140 

@dataclass
class After(Position):
    """
    A position in a header list immediately following the subject `Header`

    Intended for use with `HeadersAccessor.insert`.
    """

    subject: Header

150 

151 

START = Position(subject="start")
"""
Marks the beginning of a header list, ahead of its first (current) header
"""

END = Position(subject="end")
"""
Marks the end of a header list, following its last (current) header
"""

161 

162 

163class Session: 

164 """ 

165 The kernel of a filter, providing an API for filters to access messages from an MTA 

166 """ 

167 

168 host: str 

169 """ 

170 A hostname from a reverse address lookup performed when a client connects 

171 

172 If no name is found this value defaults to the standard presentation format for 

173 `Session.address` surrounded by "[" and "]", e.g. "[192.0.2.100]" 

174 """ 

175 

176 address: IPv4Address|IPv6Address|Path|None 

177 """ 

178 The address of the connected client, or None if unknown 

179 """ 

180 

181 port: int 

182 """ 

183 The port of the connected client if applicable, or 0 otherwise 

184 """ 

185 

186 macros: dict[str, str] 

187 """ 

188 A mapping of string replacements sent by the MTA 

189 

190 See `smfi_getsymval <https://pythonhosted.org/pymilter/milter_api/smfi_getsymval.html>`_ 

191 from `libmilter` for more information. 

192 

193 Warning: 

194 The current implementation is very naïve and does not behave exactly like 

195 `libmilter`, nor is it very robust. It will definitely change in the future. 

196 """ 

197 

198 headers: HeadersAccessor 

199 """ 

200 A `HeadersAccessor` object for accessing and modifying the message header fields 

201 """ 

202 

203 body: BodyAccessor 

204 """ 

205 A `BodyAccessor` object for accessing and modifying the message body 

206 """ 

207 

208 def __init__( 

209 self, 

210 sender: Sender, 

211 broadcast: util.Broadcast[EventMessage]|None = None, 

212 ): 

213 self.host = "" 

214 self.address = None 

215 self.port = 0 

216 

217 self.sender = sender 

218 self.broadcast = broadcast or util.Broadcast[EventMessage]() 

219 

220 self.macros = dict[str, str]() 

221 self.headers = HeadersAccessor(self, sender) 

222 self.body = BodyAccessor(self, sender) 

223 

224 # Phase checking is a bit fuzzy as a filter may not request every message, 

225 # so some phases will be skipped; checks should not try to exactly match a phase. 

226 self.phase = Phase.INIT 

227 

228 self._helo: Helo|None = None 

229 

230 async def __aenter__(self) -> Self: 

231 await self.broadcast.__aenter__() 

232 return self 

233 

234 async def __aexit__(self, *_: object) -> None: 

235 await self.broadcast.__aexit__(None, None, None) 

236 # on session close, wake up any remaining deliver() awaitables 

237 await self.broadcast.shutdown_hook() 

238 

239 def _reset(self) -> None: 

240 self.headers = HeadersAccessor(self, self.sender) 

241 self.body = BodyAccessor(self, self.sender) 

242 

243 async def deliver(self, message: EventMessage) -> type[Continue]|type[Skip]: 

244 """ 

245 Deliver a message (or its contents) to a task waiting for it 

246 """ 

247 match message: 

248 case Connect(): 

249 self.host = message.hostname 

250 self.address = message.address 

251 self.port = message.port 

252 async with self.broadcast: 

253 self.phase = Phase.CONNECT 

254 return Continue 

255 case Macro(): 

256 self.macros.update(message.macros) 

257 return Continue # not strictly necessary, but type checker needs something 

258 case Abort(): 

259 async with self.broadcast: 

260 self.phase = Phase.CONNECT 

261 await self.broadcast.abort(Aborted) 

262 self._reset() 

263 return Continue 

264 case Helo(): 

265 phase = Phase.MAIL 

266 case EnvelopeFrom() | EnvelopeRecipient() | Unknown(): 

267 phase = Phase.ENVELOPE 

268 case Data() | Header(): 

269 phase = Phase.HEADERS 

270 case EndOfHeaders() | Body(): 

271 phase = Phase.BODY 

272 case EndOfMessage(): # pragma: no-branch 

273 phase = Phase.POST 

274 async with self.broadcast: 

275 self.phase = phase # phase attribute must be modified in locked context 

276 await self.broadcast.send(message) 

277 return Skip if self.phase == Phase.BODY and self.body.should_skip() else Continue 

278 

279 async def helo(self) -> str: 

280 """ 

281 Wait for a HELO/EHLO message and return the client's claimed hostname 

282 """ 

283 if self.phase > Phase.CONNECT: 

284 raise RuntimeError( 

285 "Session.helo() must be awaited before any other async features of a " 

286 "Session", 

287 ) 

288 if self._helo: 

289 return self._helo.hostname 

290 while self.phase <= Phase.CONNECT: 

291 message = await self.broadcast.receive() 

292 if isinstance(message, Helo): 

293 self._helo = message 

294 return message.hostname 

295 raise RuntimeError("HELO/EHLO event not received") 

296 

297 async def envelope_from(self) -> str: 

298 """ 

299 Wait for a MAIL command message and return the sender identity 

300 

301 Note that if extensions arguments are wanted, users should use `Session.extension()` 

302 instead with a name of ``"MAIL"``. 

303 """ 

304 if self.phase > Phase.MAIL: 

305 raise RuntimeError( 

306 "Session.envelope_from() may only be awaited before the ENVELOPE phase", 

307 ) 

308 while self.phase <= Phase.MAIL: 

309 message = await self.broadcast.receive() 

310 if isinstance(message, EnvelopeFrom): 

311 return bytes(message.sender).decode() 

312 raise RuntimeError("MAIL event not received") 

313 

314 async def envelope_recipients(self) -> AsyncIterator[str]: 

315 """ 

316 Wait for RCPT command messages and iteratively yield the recipients' identities 

317 

318 Note that if extensions arguments are wanted, users should use `Session.extension()` 

319 instead with a name of ``"RCPT"``. 

320 """ 

321 if self.phase > Phase.ENVELOPE: 

322 raise RuntimeError( 

323 "Session.envelope_from() may only be awaited before the HEADERS phase", 

324 ) 

325 while self.phase <= Phase.ENVELOPE: 

326 message = await self.broadcast.receive() 

327 if isinstance(message, EnvelopeRecipient): 

328 yield bytes(message.recipient).decode() 

329 

330 async def extension(self, name: str) -> memoryview: 

331 """ 

332 Wait for the named command extension and return the raw command for processing 

333 """ 

334 if self.phase > Phase.ENVELOPE: 

335 raise RuntimeError( 

336 "Session.extension() may only be awaited before the HEADERS phase", 

337 ) 

338 bname = name.encode("utf-8") 

339 while self.phase <= Phase.ENVELOPE: 

340 message = await self.broadcast.receive() 

341 match message: 

342 case Unknown(): 

343 if message.content[:len(bname)] == bname: 

344 assert isinstance(message.content, memoryview) 

345 return message.content 

346 # fake buffers for MAIL and RCPT commands 

347 case EnvelopeFrom() if name == "MAIL": 

348 vals = [b"MAIL FROM", message.sender, *message.arguments] 

349 return memoryview(b" ".join(vals)) 

350 case EnvelopeRecipient() if name == "RCPT": 

351 vals = [b"RCPT TO", message.recipient, *message.arguments] 

352 return memoryview(b" ".join(vals)) 

353 raise RuntimeError(f"{name} event not received") 

354 

355 async def change_sender(self, sender: str, args: str = "") -> None: 

356 """ 

357 Move onto the `Phase.POST` phase and instruct the MTA to change the sender address 

358 """ 

359 await _until_editable(self) 

360 await self.sender.send(ChangeSender(sender, args or None)) 

361 

362 async def add_recipient(self, recipient: str, args: str = "") -> None: 

363 """ 

364 Move onto the `Phase.POST` phase and instruct the MTA to add a new recipient address 

365 """ 

366 await _until_editable(self) 

367 await self.sender.send( 

368 AddRecipientPar(recipient, args) if args else AddRecipient(recipient), 

369 ) 

370 

371 async def remove_recipient(self, recipient: str) -> None: 

372 """ 

373 Move onto the `Phase.POST` phase and instruct the MTA to remove a recipient address 

374 """ 

375 await _until_editable(self) 

376 await self.sender.send(RemoveRecipient(recipient)) 

377 

378 

379class HeadersAccessor(AsyncContextManager["HeaderIterator"]): 

380 """ 

381 A class that allows access and modification of the message headers sent from an MTA 

382 

383 To access headers (which are only available iteratively), use an instance as an 

384 asynchronous context manager; a `HeaderIterator` is returned when the context is 

385 entered. 

386 """ 

387 

388 def __init__(self, session: Session, sender: Sender): 

389 self.session = session 

390 self.sender = sender 

391 self._table = list[Header]() 

392 self._aiter = ContextVar[HeaderIterator|None]("header-iter") 

393 

394 async def __aenter__(self) -> HeaderIterator: 

395 if not (aiter := self._aiter.get(None)): 395 ↛ 398line 395 didn't jump to line 398 because the condition on line 395 was always true

396 aiter = HeaderIterator(self.__aiter()) 

397 self._aiter.set(aiter) 

398 return aiter 

399 

400 async def __aexit__(self, *_: object) -> None: 

401 if aiter := self._aiter.get(): 401 ↛ 403line 401 didn't jump to line 403 because the condition on line 401 was always true

402 await aiter.aclose() 

403 self._aiter.set(None) 

404 

405 async def __aiter(self) -> AsyncGenerator[Header, None]: 

406 # yield from cached headers first; allows multiple tasks to access the headers 

407 # in an uncoordinated manner; note the broadcaster is locked at this point 

408 for header in self._table: 

409 yield header 

410 seen = set(id(header) for header in self._table) 

411 while self.session.phase <= Phase.HEADERS: 

412 match (await self.session.broadcast.receive()): 

413 case Header() as header: 

414 header.freeze() 

415 self._table.append(header) 

416 seen.add(id(header)) 

417 try: 

418 yield header 

419 except GeneratorExit: 

420 await self.collect() 

421 raise 

422 case EndOfHeaders(): 

423 return 

424 # It's possible for collect() to have been called while yielded, in which case the 

425 # loop will end. Yield any headers that were stored by collect() but not yet 

426 # yielded. 

427 for header in self._table: 

428 if id(header) not in seen: 428 ↛ 429line 428 didn't jump to line 429 because the condition on line 428 was never true

429 yield header 

430 

431 async def collect(self) -> None: 

432 """ 

433 Collect all headers without producing an iterator 

434 

435 Calling this method before the `Phase.BODY` phase allows later processing of headers 

436 (after the HEADER phase) without the need for an empty loop. 

437 """ 

438 # note the similarities between this and __aiter; the difference is no mutex or 

439 # yields 

440 while self.session.phase <= Phase.HEADERS: 

441 match (await self.session.broadcast.receive()): 

442 case Header() as header: 

443 header.freeze() 

444 self._table.append(header) 

445 case _: 

446 return 

447 

448 async def delete(self, header: Header) -> None: 

449 """ 

450 Move onto the `Phase.POST` phase and Instruct the MTA to delete the given header 

451 """ 

452 await self.collect() 

453 await _until_editable(self.session) 

454 index = _index_by_name(self._table, header) 

455 await self.sender.send(ChangeHeader(index, header.name, b"")) 

456 self._table.remove(header) 

457 

458 async def update(self, header: Header, value: bytes) -> None: 

459 """ 

460 Move onto the `Phase.POST` phase and Instruct the MTA to modify the value of a header 

461 """ 

462 await self.collect() 

463 await _until_editable(self.session) 

464 index = _index_by_name(self._table, header) 

465 await self.sender.send(ChangeHeader(index, header.name, value)) 

466 index = self._table.index(header) 

467 self._table[index].value = value 

468 

469 async def insert(self, header: Header, position: Position) -> None: 

470 """ 

471 Move onto the `Phase.POST` phase and instruct the MTA to insert a new header 

472 

473 The header is inserted at `START`, `END`, or a relative position with `Before` and 

474 `After`; for example ``Before(Header("To", "test@example.com"))``. 

475 """ 

476 await self.collect() 

477 await _until_editable(self.session) 

478 match position: 

479 case Position(subject="start"): 

480 index = 0 

481 case Position(subject="end"): 

482 index = len(self._table) 

483 case Before(): 

484 index = self._table.index(position.subject) 

485 case After(): # pragma: no-branch 

486 index = self._table.index(position.subject) + 1 

487 case _: 

488 raise TypeError("Expect a Position") 

489 if index >= len(self._table): 

490 await self.sender.send(AddHeader(header.name, header.value)) 

491 self._table.append(header) 

492 else: 

493 await self.sender.send(InsertHeader(index + 1, header.name, header.value)) 

494 self._table.insert(index, header) 

495 

496 

class HeaderIterator(AsyncGenerator[Header, None]):
    """
    Iterator for headers obtained by using a `HeadersAccessor` as a context manager

    Wraps the accessor's internal async generator, forwarding the async generator
    protocol to it.
    """

    def __init__(self, aiter: AsyncGenerator[Header, None]):
        self._aiter = aiter

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> Header:  # noqa: D102
        return await self._aiter.__anext__()

    async def asend(self, value: None = None) -> Header:  # noqa: D102
        # Only None may be sent, which is equivalent to advancing the iterator
        return await self._aiter.__anext__()

    async def athrow(
        self,
        e: type[BaseException]|BaseException,
        m: object = None,
        t: TracebackType|None = None, /,
    ) -> Header:
        """
        Raise an exception inside the wrapped generator at its current yield point

        Both the single-argument form and the legacy ``(type, value, traceback)`` form
        are accepted; the single-argument form is forwarded whenever possible because
        the three-argument signature is deprecated since Python 3.12.

        Raises:
            TypeError: if *e* is an exception instance and a separate value is given
                (mirroring the built-in ``athrow``).
        """
        if isinstance(e, type):
            if m is None and t is None:
                return await self._aiter.athrow(e)
            return await self._aiter.athrow(e, m, t)
        if m is not None:
            # Previously an assert, which vanishes under "python -O"
            raise TypeError("instance exception may not have a separate value")
        if t is not None:
            e = e.with_traceback(t)
        return await self._aiter.athrow(e)

    async def aclose(self) -> None:  # noqa: D102
        await self._aiter.aclose()

    async def restrict(self, *names: str) -> AsyncIterator[Header]:
        """
        Return an asynchronous generator that filters headers by name

        Names are compared exactly (case-sensitively) against `Header.name`.
        """
        async for header in self._aiter:
            if header.name in names:
                yield header

535 

536 

537class BodyAccessor(AsyncContextManager[AsyncIterator[memoryview]]): 

538 """ 

539 A class that allows access and modification of the message body sent from an MTA 

540 

541 To access chunks of a body (which are only available iteratively), use an instance as an 

542 asynchronous context manager; an asynchronous iterator is returned when the context is 

543 entered. 

544 """ 

545 

546 def __init__(self, session: Session, sender: Sender): 

547 self.session = session 

548 self.sender = sender 

549 self._entered = 0 

550 self._skip = False 

551 self._aiter = ContextVar[AsyncGenerator[memoryview, None] | None]("body-iter") 

552 

553 async def __aenter__(self) -> AsyncIterator[memoryview]: 

554 if not (aiter := self._aiter.get(None)): 554 ↛ 557line 554 didn't jump to line 557 because the condition on line 554 was always true

555 aiter = self.__aiter() 

556 self._aiter.set(aiter) 

557 self._entered += 1 

558 return aiter 

559 

560 async def __aexit__(self, *_: object) -> None: 

561 if aiter := self._aiter.get(None): 561 ↛ 563line 561 didn't jump to line 563 because the condition on line 561 was always true

562 await aiter.aclose() 

563 self._aiter.set(None) 

564 self._entered -= 1 

565 

566 async def __aiter(self) -> AsyncGenerator[memoryview, None]: 

567 while self.session.phase <= Phase.BODY: 

568 match (await self.session.broadcast.receive()): 

569 case Body() as body: 

570 assert isinstance(body.content, memoryview) 

571 yield body.content 

572 case EndOfMessage() as eom: 

573 assert isinstance(eom.content, memoryview) 

574 yield eom.content 

575 

576 def should_skip(self) -> bool: 

577 """ 

578 Return whether the message body should be skipped 

579 

580 The body should be skipped when there are no active contexts. All correctly 

581 implemented filters should have started a context before the first `Body` message. 

582 

583 Once this method returns `True` it becomes "locked in" and will always return `True` 

584 after. 

585 """ 

586 if self._skip: 

587 return True 

588 self._skip = self._entered == 0 

589 return self._skip 

590 

591 async def write(self, chunk: bytes) -> None: 

592 """ 

593 Request that chunks of a new message body are sent to the MTA 

594 

595 This method should not be called from within the scope created by using it's 

596 instance as an async context (`async with`); doing so may cause a warning to be 

597 issued and the rest of the message body to be skipped. 

598 """ 

599 if self._aiter.get(None): 

600 warn( 

601 "it looks as if BodyAccessor.write() was called on an instance from within " 

602 "it's own async context", 

603 stacklevel=2, 

604 ) 

605 await _until_editable(self.session) 

606 await self.sender.send(ReplaceBody(chunk)) 

607 

608 

async def _until_editable(session: Session) -> None:
    """
    Wait until *session* reaches `Phase.POST`, the phase in which edits may be sent

    While headers are still arriving they are collected into the session's
    `HeadersAccessor`, so they remain available; other messages received while
    waiting are discarded.
    """
    while session.phase < Phase.POST:
        if session.phase != Phase.HEADERS:
            await session.broadcast.receive()
            continue
        await session.headers.collect()

617 

618 

619def _index_by_name(table: Sequence[Header], needle: Header) -> int: 

620 index = 0 

621 name = needle.name.lower() 

622 for header in table: 622 ↛ 627line 622 didn't jump to line 627 because the loop on line 622 didn't complete

623 if header == needle: 

624 return index + 1 

625 if header.name.lower() == name: 625 ↛ 622line 625 didn't jump to line 622 because the condition on line 625 was always true

626 index += 1 

627 raise ValueError(f"header not found: {needle}")