Coverage for kilter/service/runner.py: 95.62%
187 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 13:26 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 13:26 +0000
1# Copyright 2022-2023, 2025 Dominik Sekotill <dom.sekotill@kodo.org.uk>
2#
3# This Source Code Form is subject to the terms of the Mozilla Public
4# License, v. 2.0. If a copy of the MPL was not distributed with this
5# file, You can obtain one at http://mozilla.org/MPL/2.0/.
7"""
8Coordinate receiving and sending raw messages with a filter and Session object
10The primary class in this module (`Runner`) is intended to be used with an
11`anyio.abc.Listener`, which can be obtained, for instance, from
12`anyio.create_tcp_listener()`.
13"""
15from __future__ import annotations
17import enum
18import logging
19from collections import defaultdict
20from collections.abc import Iterable
21from typing import Final
22from typing import TypeAlias
23from warnings import warn
25import anyio.abc
26from async_generator import aclosing
28from kilter.protocol.buffer import SimpleBuffer
29from kilter.protocol.core import EventMessage
30from kilter.protocol.core import FilterMessage
31from kilter.protocol.core import FilterProtocol
32from kilter.protocol.core import ResponseMessage
33from kilter.protocol.messages import *
35from .options import get_flags
36from .options import get_macros
37from .session import Aborted
38from .session import Filter
39from .session import FilterResponse
40from .session import Session
41from .util import Broadcast
42from .util import qualname
44__all__ = [
45 "Runner",
46 "NegotiationError",
47]
# Any value a filter may hand back as its verdict, plus the failure response
# substituted by the runner when a filter misbehaves.
FinalResponse: TypeAlias = FilterResponse | TemporaryFailure

# Binary size units used when allocating protocol buffers
kiB: Final = 1 << 10
MiB: Final = 1 << 20

# Logger shared by everything in this module, named after the package
_logger = logging.getLogger(__package__)
class NegotiationError(Exception):
    """
    Raised during option negotiation when an MTA is not compatible with the filter
    """
class State(enum.Enum):
    """
    Connection phases tracked by `Runner` while it reads messages from a client
    """

    # Initial state, also entered after an unexpected Abort
    CONNECTED = enum.auto()

    # Entered when a Helo message is received
    SESSION = enum.auto()

    # An Abort was received while in SESSION
    SESSION_ABORTED = enum.auto()

    # Entered when an EnvelopeFrom message is received
    MESSAGE = enum.auto()

    # An Abort was received while in MESSAGE
    MESSAGE_ABORTED = enum.auto()
class _Broadcast(Broadcast[EventMessage]):
    """
    A `Broadcast` of event messages that also signals task start-up

    Task status objects appended to `task_status` have their `started()` method called
    from the receive and shutdown hooks.
    """

    def __init__(self) -> None:
        super().__init__()
        # Status objects of filter tasks that have not yet been marked as started
        self.task_status = list[anyio.abc.TaskStatus[None]]()

    async def shutdown_hook(self) -> None:
        # Shutting down flushes pending notifications exactly as a receive does
        await self.pre_receive_hook()

    async def pre_receive_hook(self) -> None:
        waiting = self.task_status
        while waiting:
            notifier = waiting.pop()
            notifier.started()
class Sender:
    """
    Concrete implementation of `kilter.service.session.Sender`

    Messages are encoded with a `FilterProtocol` and written to a client byte stream.
    """

    def __init__(self, client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> None:
        self.proto = proto
        self.client = client

    async def send(self, message: FilterMessage) -> None:
        """
        Encode a message with the protocol and write the bytes to the client stream
        """
        # A fresh scratch buffer is used for every message
        scratch = SimpleBuffer(1*kiB)
        self.proto.write_to(scratch, message)
        encoded = scratch[:]
        await self.client.send(encoded)
        if __debug__:
            _logger.debug(f"sent: {message}")
106class Runner:
107 """
108 A filter runner that coordinates passing data between a stream and multiple filters
110 Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or
111 used with any `anyio.abc.ByteStream`.
112 """
114 def __init__(self, *filters: Filter):
115 if len(filters) == 0: # pragma: no-cover
116 raise TypeError("Runner requires at least one filter to run")
117 self.filters = set(filters)
118 if len(filters) != len(self.filters):
119 warn("Repeated filters will only be run once", stacklevel=2)
120 self.use_skip = True
122 async def __call__(self, client: anyio.abc.ByteStream) -> None:
123 """
124 Return an awaitable that starts and coordinates filters
125 """
126 buff = SimpleBuffer(1*MiB)
127 proto = FilterProtocol(abort_on_unknown=True)
128 sender = Sender(client, proto)
129 session = Session(sender, _Broadcast())
130 runner = SessionRunner(session)
131 state = State.CONNECTED
133 async with (
134 aclosing(client),
135 anyio.create_task_group() as tasks,
136 ):
137 while 1:
138 try:
139 buff[:] = await client.receive(buff.available)
140 except (
141 anyio.EndOfStream,
142 anyio.ClosedResourceError,
143 anyio.BrokenResourceError,
144 ):
145 return
146 for message in proto.read_from(buff):
147 if __debug__:
148 _logger.debug(f"received: {message}")
150 # If previous message was Abort, restart filters for any non-Abort/Close
151 # message
152 if state in (State.SESSION_ABORTED, State.MESSAGE_ABORTED):
153 if not isinstance(message, Abort|Close):
154 await runner.start(self.filters, tasks)
155 state = (
156 State.CONNECTED if state == State.SESSION_ABORTED else
157 State.SESSION
158 )
160 match message:
161 case Negotiate():
162 await sender.send(await self._negotiate(message))
163 continue
164 case Connect():
165 _logger.info(f"Client connected from {message.hostname}")
166 await session.deliver(message)
167 await runner.start(self.filters, tasks)
168 if proto.needs_response(message): 168 ↛ 170line 168 didn't jump to line 170 because the condition on line 168 was always true
169 await sender.send(await runner.check_response() or Continue())
170 continue
171 case Helo():
172 state = State.SESSION
173 case EnvelopeFrom():
174 state = State.MESSAGE
175 case Abort() if state in (State.SESSION, State.MESSAGE):
176 state = (
177 State.SESSION_ABORTED if state == State.SESSION else
178 State.MESSAGE_ABORTED
179 )
180 case Abort():
181 _logger.warning("Unexpected Abort received")
182 state = State.CONNECTED
183 case Close():
184 tasks.cancel_scope.cancel()
185 return
187 skip_or_cont = await session.deliver(message)
188 if not proto.needs_response(message):
189 continue
190 if (resp := await runner.check_response()):
191 await sender.send(resp)
192 elif self.use_skip:
193 await sender.send(skip_or_cont())
194 else:
195 await sender.send(Continue())
197 async def _negotiate(self, message: Negotiate) -> Negotiate:
198 _logger.info("Negotiating with MTA")
200 optmask = ProtocolFlags.NONE
201 options = \
202 ProtocolFlags.SKIP | \
203 ProtocolFlags.NO_HELO | \
204 ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | \
205 ProtocolFlags.NO_DATA | ProtocolFlags.NO_BODY | \
206 ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_END_OF_HEADERS | \
207 ProtocolFlags.NR_CONNECT | ProtocolFlags.NR_HELO | \
208 ProtocolFlags.NR_SENDER | ProtocolFlags.NR_RECIPIENT | \
209 ProtocolFlags.NR_DATA | ProtocolFlags.NR_BODY | \
210 ProtocolFlags.NR_HEADER | ProtocolFlags.NR_END_OF_HEADERS
211 actions = ActionFlags.NONE
212 macros = defaultdict(set)
214 options &= message.protocol_flags # Remove unoffered initial flags, they are not required
216 for filtr in self.filters:
217 flags = get_flags(filtr)
218 optmask |= flags.unset_options
219 options |= flags.set_options
220 actions |= flags.set_actions
222 for stage, names in get_macros(filtr).items(): 222 ↛ 223line 222 didn't jump to line 223 because the loop on line 222 never started
223 macros[stage].update(names)
225 options &= ~optmask
227 if (missing_actions := actions & ~message.action_flags):
228 raise NegotiationError(f"MTA does not accept {missing_actions}")
230 if (missing_options := options & ~message.protocol_flags):
231 raise NegotiationError(f"MTA does not offer {missing_options}")
233 self.use_skip = ProtocolFlags.SKIP in options
235 return Negotiate(6, actions, options, dict(macros))
class SessionRunner:
    """
    Run filters as subtasks for a single session and collect their final responses
    """

    def __init__(self, session: Session):
        # Maps each running filter to its final response, or None while it has not
        # returned one yet
        self.filters = dict[Filter, FinalResponse|None]()
        self.session = session

    async def start(self, filters: Iterable[Filter], task_group: anyio.abc.TaskGroup) -> None:
        """
        Start every given filter, one after another, as tasks of `task_group`

        The session MUST have been primed by the delivery of a Connect message beforehand or
        filters will be unable to access the connection details.
        """
        _logger.debug("Starting filters")
        for each in filters:
            await task_group.start(self.run_filter, each)

    async def run_filter(
        self,
        flter: Filter,
        task_status: anyio.abc.TaskStatus[None],
    ) -> None:
        """
        Run a filter as a subtask in a task group

        The filter is registered in the `SessionRunner.filters` dict with a value of None,
        which is replaced by the filter's response once it returns.  A filter that raises
        `Aborted` is removed from the dict; any other error, or an invalid return value,
        is recorded as a `TemporaryFailure`.

        Raises `RuntimeError` if the filter is already registered.
        """
        if flter in self.filters:
            raise RuntimeError
        self.filters[flter] = None

        async with self.session:
            broadcast = self.session.broadcast
            assert isinstance(broadcast, _Broadcast)
            # Queue the status object so the broadcast hooks can mark this task as started
            pending = broadcast.task_status
            pending.append(task_status)

            try:
                outcome: FinalResponse = await flter(self.session)
            except Aborted:
                _logger.debug(f"Aborted filter {qualname(flter)}")
                del self.filters[flter]
                return
            except Exception:
                _logger.exception(f"Error in filter {qualname(flter)}")
                outcome = TemporaryFailure()
            if not isinstance(outcome, FinalResponse):
                warn(f"expected a valid response from {qualname(flter)}, got {outcome}")  # type: ignore  # Don't fully trust users…
                outcome = TemporaryFailure()
            self.filters[flter] = outcome
            # If the filter returned before the broadcast hooks ran, signal the start here
            if task_status in pending:
                pending.remove(task_status)
                task_status.started()

    async def check_response(self) -> ResponseMessage|None:
        """
        Collect the responses of finished filters and pick one to return, if any

        Filters that have produced a result are removed from `filters`; a Reject, Discard
        or ReplyCode result removes every filter.  Returns the selected response, None if
        there is none and filters remain, or an `Accept` when no filters remain.
        """
        assert self.filters, "no filters when checking for a response"
        response: ResponseMessage|None = None
        finished = list[Filter]()
        for flter, result in self.filters.items():
            # Unfinished filters stay registered; accepting and failed filters are
            # dropped; a rejection ends the scan and drops every filter.
            if result is None:
                continue
            if isinstance(result, Accept):
                _logger.info("Accept from %s, waiting for remaining", qualname(flter))
            elif isinstance(result, TemporaryFailure):
                response = result
                _logger.warning("Filter failed: %s", flter)
            elif isinstance(result, (Reject, Discard, ReplyCode)):
                response = result
                _logger.info("Returning response %s from %s", type(response).__name__, qualname(flter))
                finished = list(self.filters)
                break
            else:
                raise AssertionError(f"unexpected filter result: {result}")
            finished.append(flter)
        for flter in finished:
            del self.filters[flter]
        if response:
            return response
        if self.filters:
            return None
        return Accept()