Verified Commit 6c4557b1 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Fix GH-1 by sending the correct index value for ChangeHeader

`ChangeHeader` index values are for differeniating between headers with
the same name. They are not the absolute index position of the header
within the full header list. They are also 1-based.

https://pythonhosted.org/pymilter/milter_api/smfi_chgheader.html
parent 42ff069e
Loading
Loading
Loading
Loading
+17 −4
Original line number Diff line number Diff line
# Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2022-2024 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -12,6 +12,7 @@ from __future__ import annotations

from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from ipaddress import IPv4Address
@@ -404,9 +405,9 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		"""
		await self.collect()
		await _until_editable(self.session)
		index = self._table.index(header)
		index = _index_by_name(self._table, header)
		await self._editor.asend(ChangeHeader(index, header.name, b""))
		del self._table[index]
		self._table.remove(header)

	async def update(self, header: Header, value: bytes) -> None:
		"""
@@ -414,8 +415,9 @@ class HeadersAccessor(AsyncContextManager["HeaderIterator"]):
		"""
		await self.collect()
		await _until_editable(self.session)
		index = self._table.index(header)
		index = _index_by_name(self._table, header)
		await self._editor.asend(ChangeHeader(index, header.name, value))
		index = self._table.index(header)
		self._table[index].value = value

	async def insert(self, header: Header, position: Position) -> None:
@@ -546,3 +548,14 @@ async def _until_editable(session: Session) -> None:
	session.body.skip = True
	while session.phase < Phase.POST:
		await session.broadcast.receive()


def _index_by_name(table: Sequence[Header], needle: Header) -> int:
	index = 0
	name = needle.name.lower()
	for header in table:
		if header == needle:
			return index + 1
		if header.name.lower() == name:
			index += 1
	raise ValueError(f"header not found: {needle}")
+2 −2
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ChangeHeader(1, "Spam", b""))
		sender._asend.assert_awaited_with(ChangeHeader(2, "Spam", b""))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Eggs", b"and spam"),
@@ -220,7 +220,7 @@ class HeaderAccessorTests(AsyncTestCase):
			await session.deliver(EndOfHeaders())
			await session.deliver(EndOfMessage(b""))

		sender._asend.assert_awaited_with(ChangeHeader(1, "Spam", b"no spam!"))
		sender._asend.assert_awaited_with(ChangeHeader(2, "Spam", b"no spam!"))
		assert result == [
			Header("Spam", b"spam spam spam"),
			Header("Spam", b"no spam!"),