Verified Commit a39e66e2 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Add a `Message.freeze()` method as an alternative to `Message.release()`

parent 26684f97
Loading
Loading
Loading
Loading
+38 −0
Original line number Diff line number Diff line
@@ -302,6 +302,17 @@ class Message(metaclass=ABCMeta):
		value that could be a memoryview.
		"""

	def freeze(self) -> None:
		"""
		Similar to `release()`, but memoryviews are replaced with byte object copies

		Users may call this to copy data from a buffer for later reading, for instance if
		they intend to store messages and process them at a later stage.

		Concrete classes MUST implement this if they store references to memoryviews, or any
		value that could be a memoryview.
		"""


class NoDataMessage(Message):
	"""
@@ -350,6 +361,9 @@ class BytesMessage(Message):
	def release(self) -> None:
		self.content.release()

	def freeze(self) -> None:
		self.content = memoryview(self.content.tobytes())


# MTA Setup Commands

@@ -527,6 +541,14 @@ class EnvelopeFrom(Message, ident=b"M"):
			if isinstance(arg, memoryview):
				arg.release()

	def freeze(self) -> None:
		if isinstance(self.sender, memoryview):
			self.sender = self.sender.tobytes()
		self.arguments[:] = (
			arg.tobytes() if isinstance(arg, memoryview) else arg
			for arg in self.arguments
		)


@dataclass
class EnvelopeRecipient(Message, ident=b"R"):
@@ -556,6 +578,14 @@ class EnvelopeRecipient(Message, ident=b"R"):
			if isinstance(arg, memoryview):
				arg.release()

	def freeze(self) -> None:
		if isinstance(self.recipient, memoryview):
			self.recipient = self.recipient.tobytes()
		self.arguments[:] = (
			arg.tobytes() if isinstance(arg, memoryview) else arg
			for arg in self.arguments
		)


class Data(NoDataMessage, ident=b"T"):
	"""
@@ -598,6 +628,10 @@ class Header(Message, ident=b"L"):
		if isinstance(self.value, memoryview):
			self.value.release()

	def freeze(self) -> None:
		if isinstance(self.value, memoryview):
			self.value = self.value.tobytes()


class EndOfHeaders(NoDataMessage, ident=b"N"):
	"""
@@ -758,6 +792,10 @@ class ChangeHeader(Message, ident=b"m"):
		if isinstance(self.value, memoryview):
			self.value.release()

	def freeze(self) -> None:
		if isinstance(self.value, memoryview):
			self.value = self.value.tobytes()


class InsertHeader(ChangeHeader, ident=b"i"):
	"""
+28 −0
Original line number Diff line number Diff line
@@ -265,6 +265,34 @@ class GenericTests(TestCaseMixin, Protocol[T]):
			with self.subTest(args=args, kwargs=kwargs):
				m.release()

	def test_buffer_freeze(self) -> None:
		"""
		Check that freezing a message that holds memoryviews works
		"""
		buf = SimpleBuffer(100)
		for *_, attr, example in self.get_test_values():
			buf[0:] = struct.pack("!lc", len(example) + 1, self.message_ident)
			buf[:] = example
			buf[:] = b"spam"  # needed to make the deletion do a resize
			m, s = messages.Message.unpack(buf)

			with self.subTest(attr=attr):
				m.freeze()
				del buf[:s-5]

				for name, val in attr.items():
					assert getattr(m, name) == val

	def test_buffer_freeze_noop(self) -> None:
		"""
		Check that releasing a message that holds no memoryviews works
		"""
		for args, kwargs, *_ in self.get_test_values():
			m = self.message_class(*args, **kwargs)

			with self.subTest(args=args, kwargs=kwargs):
				m.freeze()


class GenericNoDataTest:
	"""