Coverage for kilter/service/runner.py: 95.62%

187 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-07 13:26 +0000

1# Copyright 2022-2023, 2025 Dominik Sekotill <dom.sekotill@kodo.org.uk> 

2# 

3# This Source Code Form is subject to the terms of the Mozilla Public 

4# License, v. 2.0. If a copy of the MPL was not distributed with this 

5# file, You can obtain one at http://mozilla.org/MPL/2.0/. 

6 

7""" 

8Coordinate receiving and sending raw messages with a filter and Session object 

9 

10The primary class in this module (`Runner`) is intended to be used with an 

11`anyio.abc.Listener`, which can be obtained, for instance, from 

12`anyio.create_tcp_listener()`. 

13""" 

14 

15from __future__ import annotations 

16 

17import enum 

18import logging 

19from collections import defaultdict 

20from collections.abc import Iterable 

21from typing import Final 

22from typing import TypeAlias 

23from warnings import warn 

24 

25import anyio.abc 

26from async_generator import aclosing 

27 

28from kilter.protocol.buffer import SimpleBuffer 

29from kilter.protocol.core import EventMessage 

30from kilter.protocol.core import FilterMessage 

31from kilter.protocol.core import FilterProtocol 

32from kilter.protocol.core import ResponseMessage 

33from kilter.protocol.messages import * 

34 

35from .options import get_flags 

36from .options import get_macros 

37from .session import Aborted 

38from .session import Filter 

39from .session import FilterResponse 

40from .session import Session 

41from .util import Broadcast 

42from .util import qualname 

43 

__all__ = [
    "Runner",
    "NegotiationError",
]

# The kinds of final response a filter is allowed to return for a message
FinalResponse: TypeAlias = FilterResponse | TemporaryFailure

kiB: Final = 1 << 10  # 1024 bytes
MiB: Final = 1 << 20  # 1024 kiB

_logger = logging.getLogger(__package__)

55 

56 

class NegotiationError(Exception):
    """
    Raised when an MTA's negotiation offer cannot satisfy the filters' requirements

    This happens when the filters need actions the MTA does not accept, or protocol
    options the MTA does not offer.
    """

61 

62 

class State(enum.Enum):
    """
    Progress markers for a single MTA connection, as tracked by `Runner.__call__`
    """

    # Initial state; also entered after an unexpected Abort
    CONNECTED = enum.auto()
    # Entered on Helo
    SESSION = enum.auto()
    # Abort received while in SESSION; filters restart on the next non-Abort/Close message
    SESSION_ABORTED = enum.auto()
    # Entered on EnvelopeFrom
    MESSAGE = enum.auto()
    # Abort received while in MESSAGE; filters restart on the next non-Abort/Close message
    MESSAGE_ABORTED = enum.auto()

70 

71 

class _Broadcast(Broadcast[EventMessage]):
    """
    Event-message broadcast that also releases pending filter start notifications

    `task_status` holds `TaskStatus` objects queued by filter tasks.  Both hooks below
    call `started()` on every queued status, most recently queued first, emptying the
    list.
    NOTE(review): presumably the `Broadcast` base calls `pre_receive_hook` before a
    receiver waits and `shutdown_hook` when shutting down; confirm in `.util`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Start notifications still waiting to be released (drained LIFO)
        self.task_status: list[anyio.abc.TaskStatus[None]] = []

    async def shutdown_hook(self) -> None:
        # Release any remaining waiters exactly as a receive would
        await self.pre_receive_hook()

    async def pre_receive_hook(self) -> None:
        pending = self.task_status
        while pending:
            status = pending.pop()
            status.started()

84 

85 

class Sender:
    """
    Concrete implementation of `kilter.service.session.Sender`

    Serialises filter messages with a `FilterProtocol` and writes the bytes to a
    send stream.
    """

    def __init__(self, client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> None:
        self.client = client  # stream the encoded messages are written to
        self.proto = proto  # protocol object used to encode each message

    async def send(self, message: FilterMessage) -> None:
        """
        Encode *message* and write the resulting bytes to the client stream
        """
        # NOTE(review): the encode buffer is created with 1 KiB; whether SimpleBuffer
        # grows for larger messages (e.g. body replacements) is not visible here; confirm.
        outgoing = SimpleBuffer(1*kiB)
        self.proto.write_to(outgoing, message)
        payload = outgoing[:]
        await self.client.send(payload)
        if __debug__:
            _logger.debug(f"sent: {message}")

104 

105 

106class Runner: 

107 """ 

108 A filter runner that coordinates passing data between a stream and multiple filters 

109 

110 Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or 

111 used with any `anyio.abc.ByteStream`. 

112 """ 

113 

114 def __init__(self, *filters: Filter): 

115 if len(filters) == 0: # pragma: no-cover 

116 raise TypeError("Runner requires at least one filter to run") 

117 self.filters = set(filters) 

118 if len(filters) != len(self.filters): 

119 warn("Repeated filters will only be run once", stacklevel=2) 

120 self.use_skip = True 

121 

122 async def __call__(self, client: anyio.abc.ByteStream) -> None: 

123 """ 

124 Return an awaitable that starts and coordinates filters 

125 """ 

126 buff = SimpleBuffer(1*MiB) 

127 proto = FilterProtocol(abort_on_unknown=True) 

128 sender = Sender(client, proto) 

129 session = Session(sender, _Broadcast()) 

130 runner = SessionRunner(session) 

131 state = State.CONNECTED 

132 

133 async with ( 

134 aclosing(client), 

135 anyio.create_task_group() as tasks, 

136 ): 

137 while 1: 

138 try: 

139 buff[:] = await client.receive(buff.available) 

140 except ( 

141 anyio.EndOfStream, 

142 anyio.ClosedResourceError, 

143 anyio.BrokenResourceError, 

144 ): 

145 return 

146 for message in proto.read_from(buff): 

147 if __debug__: 

148 _logger.debug(f"received: {message}") 

149 

150 # If previous message was Abort, restart filters for any non-Abort/Close 

151 # message 

152 if state in (State.SESSION_ABORTED, State.MESSAGE_ABORTED): 

153 if not isinstance(message, Abort|Close): 

154 await runner.start(self.filters, tasks) 

155 state = ( 

156 State.CONNECTED if state == State.SESSION_ABORTED else 

157 State.SESSION 

158 ) 

159 

160 match message: 

161 case Negotiate(): 

162 await sender.send(await self._negotiate(message)) 

163 continue 

164 case Connect(): 

165 _logger.info(f"Client connected from {message.hostname}") 

166 await session.deliver(message) 

167 await runner.start(self.filters, tasks) 

168 if proto.needs_response(message): 168 ↛ 170line 168 didn't jump to line 170 because the condition on line 168 was always true

169 await sender.send(await runner.check_response() or Continue()) 

170 continue 

171 case Helo(): 

172 state = State.SESSION 

173 case EnvelopeFrom(): 

174 state = State.MESSAGE 

175 case Abort() if state in (State.SESSION, State.MESSAGE): 

176 state = ( 

177 State.SESSION_ABORTED if state == State.SESSION else 

178 State.MESSAGE_ABORTED 

179 ) 

180 case Abort(): 

181 _logger.warning("Unexpected Abort received") 

182 state = State.CONNECTED 

183 case Close(): 

184 tasks.cancel_scope.cancel() 

185 return 

186 

187 skip_or_cont = await session.deliver(message) 

188 if not proto.needs_response(message): 

189 continue 

190 if (resp := await runner.check_response()): 

191 await sender.send(resp) 

192 elif self.use_skip: 

193 await sender.send(skip_or_cont()) 

194 else: 

195 await sender.send(Continue()) 

196 

197 async def _negotiate(self, message: Negotiate) -> Negotiate: 

198 _logger.info("Negotiating with MTA") 

199 

200 optmask = ProtocolFlags.NONE 

201 options = \ 

202 ProtocolFlags.SKIP | \ 

203 ProtocolFlags.NO_HELO | \ 

204 ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | \ 

205 ProtocolFlags.NO_DATA | ProtocolFlags.NO_BODY | \ 

206 ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_END_OF_HEADERS | \ 

207 ProtocolFlags.NR_CONNECT | ProtocolFlags.NR_HELO | \ 

208 ProtocolFlags.NR_SENDER | ProtocolFlags.NR_RECIPIENT | \ 

209 ProtocolFlags.NR_DATA | ProtocolFlags.NR_BODY | \ 

210 ProtocolFlags.NR_HEADER | ProtocolFlags.NR_END_OF_HEADERS 

211 actions = ActionFlags.NONE 

212 macros = defaultdict(set) 

213 

214 options &= message.protocol_flags # Remove unoffered initial flags, they are not required 

215 

216 for filtr in self.filters: 

217 flags = get_flags(filtr) 

218 optmask |= flags.unset_options 

219 options |= flags.set_options 

220 actions |= flags.set_actions 

221 

222 for stage, names in get_macros(filtr).items(): 222 ↛ 223line 222 didn't jump to line 223 because the loop on line 222 never started

223 macros[stage].update(names) 

224 

225 options &= ~optmask 

226 

227 if (missing_actions := actions & ~message.action_flags): 

228 raise NegotiationError(f"MTA does not accept {missing_actions}") 

229 

230 if (missing_options := options & ~message.protocol_flags): 

231 raise NegotiationError(f"MTA does not offer {missing_options}") 

232 

233 self.use_skip = ProtocolFlags.SKIP in options 

234 

235 return Negotiate(6, actions, options, dict(macros)) 

236 

237 

238class SessionRunner: 

239 

240 def __init__(self, session: Session): 

241 self.session = session 

242 self.filters = dict[Filter, FinalResponse|None]() 

243 

244 async def start(self, filters: Iterable[Filter], task_group: anyio.abc.TaskGroup) -> None: 

245 """ 

246 Run all the given filters in a task group 

247 

248 The session MUST have been primed by the delivery of a Connect message beforehand or 

249 filters will be unable to access the connection details. 

250 """ 

251 _logger.debug("Starting filters") 

252 for flter in filters: 

253 await task_group.start(self.run_filter, flter) 

254 

255 async def run_filter( 

256 self, 

257 flter: Filter, 

258 task_status: anyio.abc.TaskStatus[None], 

259 ) -> None: 

260 """ 

261 Run a filter as a subtask in a task group 

262 

263 A `Future` for returning the filter's response is added to the 

264 `SessionRunner.filter` dict. 

265 """ 

266 if flter in self.filters: 266 ↛ 267line 266 didn't jump to line 267 because the condition on line 266 was never true

267 raise RuntimeError 

268 self.filters[flter] = None 

269 

270 async with self.session: 

271 assert isinstance(self.session.broadcast, _Broadcast) 

272 status_notifiers = self.session.broadcast.task_status 

273 status_notifiers.append(task_status) 

274 

275 try: 

276 resp: FinalResponse = await flter(self.session) 

277 except Aborted: 

278 _logger.debug(f"Aborted filter {qualname(flter)}") 

279 del self.filters[flter] 

280 return 

281 except Exception: 

282 _logger.exception(f"Error in filter {qualname(flter)}") 

283 resp = TemporaryFailure() 

284 if not isinstance(resp, FinalResponse): 

285 warn(f"expected a valid response from {qualname(flter)}, got {resp}") # type: ignore # Don't fully trust users… 

286 resp = TemporaryFailure() 

287 self.filters[flter] = resp 

288 if task_status in status_notifiers: 

289 status_notifiers.remove(task_status) 

290 task_status.started() 

291 

292 async def check_response(self) -> ResponseMessage|None: 

293 assert self.filters, "no filters when checking for a response" 

294 response: ResponseMessage|None = None 

295 complete = list[Filter]() 

296 for flter, result in self.filters.items(): 

297 # If a filter has not finished or no response is expected, continue without 

298 # removing from filter container; remove failed filters and filters that have 

299 # accepted; return a response for rejections; 

300 match result: 

301 case None: 

302 continue 

303 case Accept(): 

304 _logger.info("Accept from %s, waiting for remaining", qualname(flter)) 

305 case TemporaryFailure() as response: 

306 _logger.warning("Filter failed: %s", flter) 

307 case Reject()|Discard()|ReplyCode() as response: 307 ↛ 311line 307 didn't jump to line 311 because the pattern on line 307 always matched

308 _logger.info("Returning response %s from %s", type(response).__name__, qualname(flter)) 

309 complete[:] = self.filters 

310 break 

311 case msg: 

312 raise AssertionError(f"unexpected filter result: {msg}") 

313 complete.append(flter) 

314 for flter in complete: 

315 del self.filters[flter] 

316 return response if response else None if self.filters else Accept()