Verified Commit 2e28c259 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Improve default filter options for negotiation

parent 5f09425f
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -31,7 +31,12 @@ SIZES = Literal[ProtocolFlags.NONE, ProtocolFlags.MDS_256K, ProtocolFlags.MDS_1M

FLAGS_ATTRIBUTE = "filter_flags"

NR_FLAGS = \
DEFAULT_UNSET = \
	ProtocolFlags.NO_CONNECT | ProtocolFlags.NO_HELO | \
	ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | \
	ProtocolFlags.NO_DATA | ProtocolFlags.NO_BODY | \
	ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_END_OF_HEADERS | \
	ProtocolFlags.NO_UNKNOWN | \
	ProtocolFlags.NR_CONNECT | ProtocolFlags.NR_HELO | \
	ProtocolFlags.NR_SENDER | ProtocolFlags.NR_RECIPIENT | \
	ProtocolFlags.NR_DATA | ProtocolFlags.NR_BODY | \
@@ -70,7 +75,7 @@ def get_flags(filtr: Filter) -> FlagsTuple:
	"""
	Return the flags attached to a filter
	"""
	default = FlagsTuple(unset_options=NR_FLAGS, set_actions=ActionFlags.ALL)
	default = FlagsTuple(unset_options=DEFAULT_UNSET, set_actions=ActionFlags.ALL)
	return _get_flags(filtr, default)


+34 −7
Original line number Diff line number Diff line
@@ -133,7 +133,17 @@ class Runner:
							await self._prepare_filters(message, sender, runner)
							if macro:
								await runner.set_macros(macro)
							await sender.asend(await runner.start(True, self.use_skip))
							needs_response = proto.needs_response(message)
							match await runner.start(needs_response, True, self.use_skip):
								case None:
									assert not needs_response
								case _CloseFilter() as notif:
									self.filters.remove(notif.filter)
								case c_resp if needs_response:
									assert c_resp is not None and not isinstance(c_resp, _CloseFilter)
									await sender.asend(c_resp)
								case c_resp:
									raise RuntimeError(f"unexpected response: {c_resp}")
						case Abort():
							aborted = True
							await runner.abort(message)
@@ -143,7 +153,7 @@ class Runner:
						case _:
							if aborted:
								aborted = False
								await runner.start(False, self.use_skip)
								await runner.start(True, False, self.use_skip)
							needs_response = proto.needs_response(message)
							match await runner.message_events(message, needs_response):
								case None:
@@ -161,7 +171,12 @@ class Runner:

		optmask = ProtocolFlags.NONE
		options = \
			ProtocolFlags.SKIP | ProtocolFlags.NR_HELO | \
			ProtocolFlags.SKIP | \
			ProtocolFlags.NO_HELO | \
			ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | \
			ProtocolFlags.NO_DATA | ProtocolFlags.NO_BODY | \
			ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_END_OF_HEADERS | \
			ProtocolFlags.NR_CONNECT | ProtocolFlags.NR_HELO | \
			ProtocolFlags.NR_SENDER | ProtocolFlags.NR_RECIPIENT | \
			ProtocolFlags.NR_DATA | ProtocolFlags.NR_BODY | \
			ProtocolFlags.NR_HEADER | ProtocolFlags.NR_END_OF_HEADERS
@@ -215,7 +230,12 @@ class _TaskRunner:
	def add_filter(self, flter: Filter, session: Session, /) -> None:
		self.filters.append((flter, session))

	async def start(self, first_connect: bool, use_skip: bool) -> ResponseMessage:
	async def start(
		self,
		needs_response: bool,
		first_connect: bool,
		use_skip: bool,
	) -> ResponseMessage|_CloseFilter|None:
		if self.channels:
			raise RuntimeError(f"{self} is already running tasks")
		final: ResponseMessage = Accept()
@@ -236,9 +256,16 @@ class _TaskRunner:
							f"{qualname(flter)} -> {resp}",
						)
						continue
					if not needs_response:
						_logger.warning(
							f"Unexpected response from filter {flter}",
						)
						return _CloseFilter(flter)
					return resp
				case _ as arg:  # pragma: no-cover
					raise TypeError(f"task_status.started called with bad type: {arg!r}")
		if not needs_response:
			return None
		return final if len(self.channels) == 0 else Continue()

	async def set_macros(self, message: Macro) -> None:
@@ -267,11 +294,11 @@ class _TaskRunner:
				case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp:
					await self.close_channel(channel)
					if not needs_response:
						filtr = self.channels[channel]
						flter = self.channels[channel]
						_logger.warning(
							f"Unexpected response from filter {self.channels[channel]}",
							f"Unexpected response from filter {flter}",
						)
						return _CloseFilter(filtr)
						return _CloseFilter(flter)
					return resp
		if not needs_response:
			return None