Commit 932a1d45 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Merge branch '26-global-mysql-server' into 'release/0.4.x'

Global MySQL server

See merge request !19
parents d471e259 a19f77c0
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ repos:
    - types-orjson
    - types-requests
    - types-urllib3
    - typing-extensions ~=4.0; python_version < "3.10"
    - typing-extensions ~=4.0; python_version < "3.11"
      # https://github.com/python-trio/trio-typing/pull/72
    - trio-typing[mypy] @git+https://github.com/gschaffner/trio-typing.git@fix-takes_callable_and_args-TypeVar-binding
    - xdg ~=5.1
+12 −1
Original line number Diff line number Diff line
@@ -447,6 +447,16 @@ class Container:
		)
		docker(b"network", b"connect", *opts, str(network), contrid)

	def disconnect(self, network: Network) -> None:
		"""
		Disconnect the container from a Docker network

		Raises `KeyError` if the network was not connected to with `Container.connect()`.
		"""
		del self.networks[network]
		if self.cid is not None:
			docker(b"network", b"disconnect", str(network), self.cid)

	def show_logs(self) -> None:
		"""
		Print the container logs to stdout
@@ -476,8 +486,9 @@ class Container:
		Run "cmd" to completion inside the container and return the result
		"""
		self.is_running(raise_on_exit=True)
		interactive = input is not None or stdin is not None
		return run(
			self.get_exec_args(cmd),
			self.get_exec_args(cmd, interactive),
			stdin=stdin, stdout=stdout, stderr=stderr,
			capture_output=capture_output,
			check=check, timeout=timeout, input=input,

behave_utils/init.sql

0 → 100644
+3 −0
Original line number Diff line number Diff line
INSTALL PLUGIN auth_socket SONAME 'auth_socket.so';

ALTER USER 'root'@'localhost' IDENTIFIED WITH auth_socket;
+108 −38
Original line number Diff line number Diff line
@@ -10,13 +10,15 @@ Management and control for MySQL database fixtures

from __future__ import annotations

from contextlib import contextmanager
import atexit
from importlib import resources
from os import environ
from pathlib import Path
from time import sleep
from typing import TYPE_CHECKING
from typing import ClassVar
from typing import Iterator
from typing import Sequence
from typing import TypeVar

from behave import fixture

@@ -31,96 +33,164 @@ from .utils import wait

if TYPE_CHECKING:
	from behave.runner import FeatureContext
	from typing_extensions import Self


INIT_DIRECTORY = Path("/docker-entrypoint-initdb.d")


class Mysql(Container):
class MysqlContainer(Container):
	"""
	Container subclass for a database container
	"""

	if TYPE_CHECKING:
		T = TypeVar('T', bound='Mysql')
	_inst: ClassVar[Self|None] = None

	def __init__(
		self,
		version: str = "latest",
		init_files: Sequence[Path] = [],
		network: Network|None = None,
		name: str = "test-db",
		user: str = "test-db-user",
		password: str|None = None,
	):
		self.name = name
		self.user = user
		self.password = password or make_secret(20)
		volumes: list[Mount] = [(path, INIT_DIRECTORY / path.name) for path in init_files]
		volumes.append(Path("/var/lib/mysql"))
		env = dict(
			MYSQL_DATABASE=name,
			MYSQL_USER=user,
			MYSQL_PASSWORD=self.password,
		)
		Container.__init__(
			self,
			Image.pull(f"mysql/mysql-server:{version}"),
			volumes=volumes,
			env=env,
			network=network,
		)

	@classmethod
	def get_running(cls, version: str = "latest") -> MysqlContainer:
		"""
		Return a running instance of MysqlContainer

		Depending on what is currently running the container may have to be started, which
		is a long operation.
		"""
		if (inst := cls._inst or cls.get_labeled(version)):
			return inst
		with resources.path(__package__, "init.sql") as init:
			cls._inst = self = cls(version, [init])
			self.start()
			sleep(20)
			wait(lambda: self.run(['/healthcheck.sh']).returncode == 0)
			if environ.get("BEHAVE_UTILS_MYSQL_KEEP", "0") == "0":
				atexit.register(self.stop, rm=True)
			return self

	@classmethod
	def get_labeled(cls, version: str) -> Self|None:
		"""
		Return any existing running container matching the given version

		This method will clean up stopped or out-of-date containers. A container is
		considered out-of-date if it is labeled with the requested version but has
		a different SHA-ID to the image available on Docker Hub tagged with that version.
		"""
		# cls._inst = ...


class Mysql:
	"""
	A database instance for test fixtures' use

	If created with a non-`None` 'server' it MUST be a *running* Container instance or
	`ValueError` will be raised.

	If 'server' is `None` a running MysqlContainer instance will be retrieved or created
	using the value of 'version' as a MySQL image tag.
	"""

	def __init__(
		self, *,
		version: str = "latest",
		network: Network|None = None,
		server: Container|None = None,
	):
		if server and not server.is_running():
			raise ValueError(f"{server} is not running")
		self._server = server or MysqlContainer.get_running(version)
		self._network = network

		self.name = f"behave-{make_secret(5)}"
		self.user = f"behave-user-{make_secret(5)}"
		self.password = make_secret(20)

	def __enter__(self) -> Self:
		if self._network:
			self._server.connect(self._network)
		self._server.run(
			["mysql"],
			input=f"""
			CREATE DATABASE IF NOT EXISTS `{self.name}`;
			CREATE USER IF NOT EXISTS '{self.user}'@'%'
			  IDENTIFIED BY '{self.password}';
			GRANT ALL ON TABLE `{self.name}`.* TO '{self.user}'@'%';
			""".encode("utf-8"),
			check=True,
		)
		return self

	def __exit__(self, *exc_info: object) -> None:
		if self._network:
			self._server.disconnect(self._network)
		self._server.run(
			["mysql"],
			input=f"""
			DROP USER '{self.user}'@'%';
			DROP DATABASE `{self.name}`;
			""".encode("utf-8"),
		)

	def get_location(self) -> str:
		"""
		Return a "host:port" string for connecting to the database from other containers
		"""
		host = inspect(self).path("$.Config.Hostname", str)
		host = inspect(self._server).path("$.Config.Hostname", str)
		return f"{host}:3306"

	def run_commands(self, sql: str|Path) -> None:
		"""
		Run SQL commands as the superuser on the database, from either strings or files

		This is mostly intended for initialising database fixtures with data.
		"""
		if isinstance(sql, str):
			self.mysql(input=sql, check=True)
			return
		with sql.open("rb") as fh:
			self.mysql(stdin=fh, check=True)

	@property
	def mysql(self) -> Cli:
		"""
		Run "mysql" commands
		"""
		return Cli(self, "mysql")
		return Cli(self._server, "mysql", self.name)

	@property
	def mysqladmin(self) -> Cli:
		"""
		Run "mysqladmin" commands
		"""
		return Cli(self, "mysqladmin")
		return Cli(self._server, "mysqladmin", self.name)

	@property
	def mysqldump(self) -> Cli:
		"""
		Run "mysqldump" commands
		"""
		return Cli(self, "mysqldump")

	@contextmanager
	def started(self: T) -> Iterator[T]:
		"""
		Return a context manager that only enters once the database is initialised
		"""
		with self:
			self.start()
			sleep(20)
			wait(lambda: self.run(['/healthcheck.sh']).returncode == 0)
			yield self
		return Cli(self._server, "mysqldump", self.name)


@fixture
def snapshot_rollback(context: FeatureContext, /, database: Mysql|None = None) -> Iterator[None]:
def snapshot_rollback(context: FeatureContext, /, database: Mysql) -> Iterator[None]:
	"""
	Manage the state of a database as a revertible fixture

	At the end of the fixture's lifetime it's state at the beginning is restored.  This
	allows for faster fixture turn-around than restarting the database.
	"""
	assert database is not None, \
		"'database' is required for snapshot_rollback"
	snapshot = database.mysqldump("--all-databases", deserialiser=bytes)
	snapshot = database.mysqldump(deserialiser=bytes)
	yield
	database.mysql(input=snapshot)
+7 −0
Original line number Diff line number Diff line
@@ -331,6 +331,13 @@ class Executor(_ExecutorBase):
		"""
		assert not deserialiser or not query

		# Check interferes with query, simulate it not being accepted
		if "check" in kwargs:
			raise TypeError(
				f"{self.__class__.__name__}.__call__() got an unexpected keyword "
				"argument 'check'",
			)

		data = (
			b"" if input is None else
			input.encode() if isinstance(input, str) else