Loading kilter/service/runner.py +37 −5 Original line number Diff line number Diff line Loading @@ -57,6 +57,12 @@ class NegotiationError(Exception): """ class _CloseFilter: def __init__(self, filtr: Filter): self.filter = filtr class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: Loading @@ -83,7 +89,7 @@ class Runner: def __init__(self, *filters: Filter): if len(filters) == 0: # pragma: no-cover raise TypeError("Runner requires at least one filter to run") self.filters = filters self.filters = list(filters) self.use_skip = True async def __call__(self, client: anyio.abc.ByteStream) -> None: Loading Loading @@ -138,7 +144,17 @@ class Runner: if aborted: aborted = False await runner.start(False, self.use_skip) await sender.asend(await runner.message_events(message)) needs_response = proto.needs_response(message) match await runner.message_events(message, needs_response): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case resp if needs_response: assert resp is not None and not isinstance(resp, _CloseFilter) await sender.asend(resp) case resp: raise RuntimeError(f"unexpected response: {resp}") async def _negotiate(self, message: Negotiate) -> Negotiate: _logger.info("Negotiating with MTA") Loading Loading @@ -233,7 +249,11 @@ class _TaskRunner: for _, session in self.filters: await session.deliver(message) async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage|Skip: async def message_events( self, message: _VALID_EVENT_MESSAGE, needs_response: bool, ) -> ResponseMessage|Skip|_CloseFilter|None: skip = isinstance(message, Body) for channel in list(self.channels): await channel.send(message) Loading @@ -243,16 +263,28 @@ class _TaskRunner: case Continue(): skip = False case Accept(): await channel.aclose() del self.channels[channel] await self.close_channel(channel) case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: await self.close_channel(channel) if not needs_response: filtr = self.channels[channel] _logger.warning( f"Unexpected response from filter {self.channels[channel]}", ) return _CloseFilter(filtr) return resp if not needs_response: return None return ( Accept() if len(self.channels) == 0 else Skip() if skip else Continue() ) async def close_channel(self, channel: MessageChannel) -> None: await channel.aclose() del self.channels[channel] async def abort(self, abort: Abort) -> None: if self.channels: _logger.info("Aborting filters") Loading Loading
kilter/service/runner.py +37 −5 Original line number Diff line number Diff line Loading @@ -57,6 +57,12 @@ class NegotiationError(Exception): """ class _CloseFilter: def __init__(self, filtr: Filter): self.filter = filtr class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: Loading @@ -83,7 +89,7 @@ class Runner: def __init__(self, *filters: Filter): if len(filters) == 0: # pragma: no-cover raise TypeError("Runner requires at least one filter to run") self.filters = filters self.filters = list(filters) self.use_skip = True async def __call__(self, client: anyio.abc.ByteStream) -> None: Loading Loading @@ -138,7 +144,17 @@ class Runner: if aborted: aborted = False await runner.start(False, self.use_skip) await sender.asend(await runner.message_events(message)) needs_response = proto.needs_response(message) match await runner.message_events(message, needs_response): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case resp if needs_response: assert resp is not None and not isinstance(resp, _CloseFilter) await sender.asend(resp) case resp: raise RuntimeError(f"unexpected response: {resp}") async def _negotiate(self, message: Negotiate) -> Negotiate: _logger.info("Negotiating with MTA") Loading Loading @@ -233,7 +249,11 @@ class _TaskRunner: for _, session in self.filters: await session.deliver(message) async def message_events(self, message: _VALID_EVENT_MESSAGE) -> ResponseMessage|Skip: async def message_events( self, message: _VALID_EVENT_MESSAGE, needs_response: bool, ) -> ResponseMessage|Skip|_CloseFilter|None: skip = isinstance(message, Body) for channel in list(self.channels): await channel.send(message) Loading @@ -243,16 +263,28 @@ class _TaskRunner: case Continue(): skip = False case Accept(): await channel.aclose() del self.channels[channel] await self.close_channel(channel) case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: await self.close_channel(channel) if not needs_response: filtr = self.channels[channel] _logger.warning( f"Unexpected response from filter {self.channels[channel]}", ) return _CloseFilter(filtr) return resp if not needs_response: return None return ( Accept() if len(self.channels) == 0 else Skip() if skip else Continue() ) async def close_channel(self, channel: MessageChannel) -> None: await channel.aclose() del self.channels[channel] async def abort(self, abort: Abort) -> None: if self.channels: _logger.info("Aborting filters") Loading