Commit 0ac76efc authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Reduce encoding/decoding of exec arguments

- Prefer constant bytes strings to constant unicode strings
- Allow both byte- & unicode- strings & path-like objects as arguments
- Pass path-like values and byte strings unchanged

Prefer passing raw byte strings to process spawner
parent f524e5a5
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -91,4 +91,5 @@ repos:
    additional_dependencies:
    - types-orjson
    - types-requests
    - trio-typing ~=0.6
    - git+https://code.kodo.org.uk/dom/type-stubs.git#type-stubs[jsonpath,parse]
+45 −34
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@ import ipaddress
import json
import logging
from contextlib import contextmanager
from os import PathLike
from os import fspath
from pathlib import Path
from secrets import token_hex
from subprocess import DEVNULL
@@ -33,6 +35,7 @@ from typing import SupportsBytes
from typing import Tuple
from typing import TypeVar
from typing import Union
from typing import cast
from typing import overload

from .binaries import DownloadableDocker
@@ -43,23 +46,22 @@ from .proc import Arguments
from .proc import Deserialiser
from .proc import Environ
from .proc import MutableArguments
from .proc import PathLike
from .proc import coerce_args
from .proc import exec_io

HostMount = tuple[PathLike, PathLike]
NamedMount = tuple[str, PathLike]
AnonMount = PathLike
Mount = Union[HostMount, NamedMount, AnonMount]
MountPath = Union[PathLike[bytes], PathLike[str]]
HostMount = tuple[MountPath, MountPath]
NamedMount = tuple[str, MountPath]
Mount = Union[HostMount, NamedMount, MountPath]
Volumes = Iterable[Mount]


try:
	run(["docker", "version"], stdout=DEVNULL)
	run([b"docker", b"version"], stdout=DEVNULL)
except FileNotFoundError:
	DOCKER = DownloadableDocker().get_binary().as_posix()
	DOCKER: Argument = DownloadableDocker().get_binary()
else:
	DOCKER = "docker"
	DOCKER = b"docker"


def utf8_decode(buffer: bytes) -> str:
@@ -73,14 +75,14 @@ def docker(*args: Argument, **env: str) -> None:
	"""
	Run a Docker command, with output going to stdout
	"""
	run([DOCKER, *coerce_args(args)], env=env, check=True)
	run([DOCKER, *args], env=env, check=True)


def docker_output(*args: Argument, **env: str) -> str:
	"""
	Run a Docker command, capturing and returning its stdout
	"""
	proc = run([DOCKER, *coerce_args(args)], env=env, check=True, stdout=PIPE, text=True)
	proc = run([DOCKER, *args], env=env, check=True, stdout=PIPE, text=True)
	return proc.stdout.strip()


@@ -88,7 +90,7 @@ def docker_quiet(*args: Argument, **env: str) -> None:
	"""
	Run a Docker command, directing its stdout to /dev/null
	"""
	run([DOCKER, *coerce_args(args)], env=env, check=True, stdout=DEVNULL)
	run([DOCKER, *args], env=env, check=True, stdout=DEVNULL)


class IPv4Address(ipaddress.IPv4Address):
@@ -152,7 +154,7 @@ class Image(Item):
		arguments from external lookups without complex argument composing.
		"""
		cmd: Arguments = [
			'build', context, f"--target={target}",
			b"build", context, f"--target={target}",
			*(f"--build-arg={arg}={val}" for arg, val in build_args.items() if val is not None),
		]
		docker(*cmd, DOCKER_BUILDKIT='1')
@@ -164,7 +166,7 @@ class Image(Item):
		"""
		Pull an image from a registry
		"""
		docker('pull', repository)
		docker(b"pull", repository)
		iid = Item(repository).inspect().path('$.Id', str)
		return cls(iid)

@@ -254,25 +256,34 @@ class Container(Item):
			return self.cid

		opts: MutableArguments = [
			"--network=none",
			*(
				(f"--volume={vol[0]}:{vol[1]}" if isinstance(vol, tuple) else f"--volume={vol}")
				for vol in self.volumes
			),
			b"--network=none",
			*(f"--env={name}={val}" for name, val in self.env.items()),
		]

		for vol in self.volumes:
			if isinstance(vol, tuple):
				src = fspath(vol[0])
				dst = fspath(vol[1])
				if isinstance(src, bytes):
					src = src.decode()
				if isinstance(dst, bytes):
					dst = dst.decode()
				arg: Argument = f"{src}:{dst}"
			else:
				arg = vol
			opts.extend((b"--volume", arg))

		if self.entrypoint:
			opts.append(f"--entrypoint={self.entrypoint}")
			opts.extend((b"--entrypoint", self.entrypoint))

		if self.privileged:
			opts.append("--privileged")
			opts.append(b"--privileged")

		self.cid = docker_output('container', 'create', *opts, self.image.iid, *self.cmd)
		self.cid = docker_output(b"container", b"create", *opts, self.image.iid, *self.cmd)
		assert self.cid

		# Disconnect the "none" network specified as the starting network
		docker_quiet("network", "disconnect", "none", self.cid)
		docker_quiet(b"network", b"disconnect", b"none", self.cid)

		return self.cid

@@ -282,7 +293,7 @@ class Container(Item):
		"""
		if self.is_running():
			return
		docker_quiet('container', 'start', self.get_id())
		docker_quiet(b"container", b"start", self.get_id())

	def stop(self, rm: bool = False) -> None:
		"""
@@ -292,10 +303,10 @@ class Container(Item):
			return
		try:
			if self.is_running():
				docker_quiet('container', 'stop', self.cid)
				docker_quiet(b"container", b"stop", self.cid)
		finally:
			if rm:
				docker_quiet('container', 'rm', self.cid)
				docker_quiet(b"container", b"rm", self.cid)
				self.cid = None

	def connect(
@@ -323,8 +334,8 @@ class Container(Item):
		if network in self.networks:
			if self.networks[network] == aliases:
				return
			docker('network', 'disconnect', str(network), cid)
		docker('network', 'connect', *opts, str(network), cid)
			docker(b"network", b"disconnect", str(network), cid)
		docker(b"network", b"connect", *opts, str(network), cid)
		self.networks[network] = aliases

	def show_logs(self) -> None:
@@ -332,13 +343,13 @@ class Container(Item):
		Print the container logs to stdout
		"""
		if self.cid:
			docker('logs', self.cid)
			docker(b"logs", self.cid)

	def get_exec_args(self, cmd: Arguments, interactive: bool = False) -> MutableArguments:
		"""
		Return a full argument list for running "cmd" inside the container
		"""
		return [DOCKER, "exec", *(("-i",) if interactive else ""), self.get_id(), *coerce_args(cmd)]
		return [DOCKER, b"exec", *((b"-i",) if interactive else []), self.get_id(), *cmd]

	def run(
		self,
@@ -436,12 +447,12 @@ class Network(Item):
		gateway = next(subnet.hosts())
		try:
			self._nid = docker_output(
				"network", "create", self._name,
				b"network", b"create", self._name,
				f"--subnet={subnet}", f"--gateway={gateway}",
			)
		except CalledProcessError:
			data = exec_io(
				[DOCKER, "network", "inspect", self._name],
				[DOCKER, b"network", b"inspect", self._name],
				deserialiser=JSONArray.from_string,
			)
			if len(data) == 0:
@@ -454,7 +465,7 @@ class Network(Item):
		Remove the network
		"""
		if self._nid:
			docker_quiet("network", "rm", self._nid)
			docker_quiet(b"network", b"rm", self._nid)

	@classmethod
	def get_free_subnet(cls) -> ipaddress.IPv4Network:
@@ -462,11 +473,11 @@ class Network(Item):
		Return a free private subnet
		"""
		networks = exec_io(
			[DOCKER, "network", "ls", "--format={{.ID}}"],
			[DOCKER, b"network", b"ls", b"--format={{.ID}}"],
			deserialiser=utf8_decode,
		).splitlines()
		subnets = exec_io(
			[DOCKER, "network", "inspect"] + networks,
			[DOCKER, b"network", b"inspect"] + cast(list[Argument], networks),
			deserialiser=JSONArray.from_string,
		).path(
			"$[*].IPAM.Config[*].Subnet", list[str],
+8 −5
Original line number Diff line number Diff line
@@ -11,8 +11,10 @@ Manage processes asynchronously with stdio capture
from __future__ import annotations

import io
import os
import sys
from os import PathLike
from os import fspath
from os import write as fdwrite
from subprocess import DEVNULL
from subprocess import PIPE
from typing import IO
@@ -31,8 +33,7 @@ import trio.abc
T = TypeVar('T')
Deserialiser = Callable[[memoryview], T]

PathLike = os.PathLike[str]
Argument = Union[PathLike, str]
Argument = Union[str, bytes, PathLike[str], PathLike[bytes]]
PathArg = Argument  # deprecated
Arguments = Sequence[Argument]
MutableArguments = MutableSequence[Argument]
@@ -43,7 +44,9 @@ def coerce_args(args: Arguments) -> Iterator[str]:
	"""
	Ensure path-like arguments are converted to strings
	"""
	return (os.fspath(a) for a in args)
	for arg in args:
		arg = fspath(arg)
		yield arg if isinstance(arg, str) else arg.decode()


@overload
@@ -144,7 +147,7 @@ async def _passthru(in_stream: trio.abc.ReceiveStream, out_stream: IO[str]|IO[by
			remaining = len(data)
			while remaining:
				await trio.lowlevel.wait_writable(out_stream)
				written = os.write(out_stream, data)
				written = fdwrite(out_stream, data)
				data = data[written:]
				remaining -= written