diff --git a/.gitignore b/.gitignore index 5118d796dadbd0c65de0bfbd5d0acf8f9acb9831..b1b3d1d4378ea5e9f282e85431a1f5b4dd62f0d9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ # Python *.py[co] dist/ + +# Testing +/results/ diff --git a/.gitlab-ci.pre-commit-run.bash b/.gitlab-ci.pre-commit-run.bash new file mode 100644 index 0000000000000000000000000000000000000000..704e716956ec58c44775b2b11d696e71560a6650 --- /dev/null +++ b/.gitlab-ci.pre-commit-run.bash @@ -0,0 +1,58 @@ +# Find a suitable commit for determining changed files +# +# +# Copyright 2022 Dom Sekotill +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +pre_commit_run() ( + set -eu + declare -a PRE_COMMIT_ARGS + + find_lca() { + local repo=$CI_REPOSITORY_URL + local current_branch=$1 other_branch=$2 + + # See https://stackoverflow.com/questions/63878612/git-fatal-error-in-object-unshallow-sha-1 + # and https://stackoverflow.com/questions/4698759/converting-git-repository-to-shallow/53245223#53245223 + # for background on what `git repack -d` is doing here. + git repack -qd + + git fetch -q $repo --shallow-exclude=$other_branch $current_branch + git fetch -q $repo --deepen=1 $current_branch + + FROM_REF=$(git rev-parse -q --revs-only --verify shallow) || unset FROM_REF + } + + fetch_ref() { + git fetch -q $CI_REPOSITORY_URL --depth=1 $1 + FROM_REF=$1 + } + + if [[ -v CI_COMMIT_BEFORE_SHA ]] && [[ ! 
$CI_COMMIT_BEFORE_SHA =~ ^0{40}$ ]]; then + fetch_ref $CI_COMMIT_BEFORE_SHA + elif [[ -v CI_MERGE_REQUEST_TARGET_BRANCH_NAME ]]; then + find_lca $CI_MERGE_REQUEST_SOURCE_BRANCH_NAME $CI_MERGE_REQUEST_TARGET_BRANCH_NAME + elif [[ $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH ]]; then + find_lca $CI_COMMIT_BRANCH $CI_DEFAULT_BRANCH + fi + + if [[ -v FROM_REF ]]; then + PRE_COMMIT_ARGS=( --from-ref=$FROM_REF --to-ref=$CI_COMMIT_SHA ) + else + PRE_COMMIT_ARGS=( --all-files ) + fi + + pre-commit run "$@" "${PRE_COMMIT_ARGS[@]}" +) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d8ef51e0ce28e40e3181d3412473d4237d0986a0..b89551972f6aeb6505ee60cf9f273a86143e7886 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,86 +1,130 @@ +# Optional project CI variables to set: +# +# SAFETY_API_KEY: +# Set to your API key for accessing up-to-date package security information + stages: -- test - build +- test - publish -image: python:3.9 -variables: - PIP_CACHE_DIR: $CI_PROJECT_DIR/cache/pkg - PRE_COMMIT_HOME: $CI_PROJECT_DIR/cache/pre-commit +workflow: + rules: + - if: $CI_PIPELINE_SOURCE == "merge_request_event" + - if: $CI_OPEN_MERGE_REQUESTS + when: never + - when: always + -.cached: +.python: + image: python:3.9 + variables: + PIP_CACHE_DIR: $CI_PROJECT_DIR/cache/pkg + PIP_NO_COMPILE: "true" + PIP_NO_CLEAN: "true" cache: key: $CI_JOB_IMAGE paths: [cache] + before_script: + - pip install "pip>=21.3" + + +Build Package: + stage: build + extends: [.python] + script: + - pip install build + - python -m build + artifacts: + paths: [dist] + + +Pin: + # Pin dependencies in requirements.txt for reproducing pipeline results + stage: test + extends: [.python] + needs: [] + script: + - pip install --prefer-binary -e . + - pip freeze --exclude-editable | tee requirements.txt + artifacts: + paths: [requirements.txt] + + +Dependency Check: + stage: test + image: pyupio/safety:latest + needs: [Pin] + allow_failure: true + script: + - safety check -r requirements.txt Code Checks: stage: test - extends: [.cached] image: docker.kodo.org.uk/ci-images/pre-commit:2.15.0-1 + needs: [] variables: - HOOK_STAGE: commit - FROM_REF: $CI_DEFAULT_BRANCH + PRE_COMMIT_HOME: $CI_PROJECT_DIR/cache/pre-commit + cache: + key: $CI_JOB_IMAGE + paths: [cache] rules: - - if: $CI_PIPELINE_SOURCE == "push" && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - variables: - FROM_REF: $CI_COMMIT_BEFORE_SHA - if: $CI_PIPELINE_SOURCE == "push" - if: $CI_PIPELINE_SOURCE == "merge_request_event" - variables: - HOOK_STAGE: merge-commit script: - - git fetch $CI_REPOSITORY_URL $FROM_REF:FROM_REF -f - - pre-commit run - --hook-stage=$HOOK_STAGE - --from-ref=FROM_REF - --to-ref=HEAD + - source .gitlab-ci.pre-commit-run.bash + - pre_commit_run --hook-stage=commit + - pre_commit_run --hook-stage=push -Commit Graph Check: - extends: ["Code Checks"] - variables: - HOOK_STAGE: push - rules: - - if: $CI_PIPELINE_SOURCE == "push" && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH - variables: - FROM_REF: $CI_COMMIT_BEFORE_SHA - - if: $CI_PIPELINE_SOURCE == "merge_request_event" +Unit Tests: + stage: test + extends: [.python] + needs: [Pin] + script: + - pip install -r requirements.txt -e . 
coverage[toml] nose + - coverage run -m nose tests + --verbose + --with-xunit --xunit-file=results/xunit.xml + after_script: + - coverage report + - coverage json + - coverage xml + - coverage html + coverage: '/^TOTAL .* (\d{1,3}\.\d{2})%$/' + artifacts: + paths: [results] + reports: + cobertura: results/coverage.xml + junit: results/xunit.xml Check Tag: stage: test - extends: [.cached] + extends: [.python] + needs: ["Build Package"] rules: - if: $CI_COMMIT_TAG =~ /^v[0-9]/ script: - - pip install tomli packaging + - pip install packaging pkginfo - | python <<-END - import tomli + from glob import glob from packaging.version import Version + from pkginfo import Wheel - with open("pyproject.toml", "rb") as f: - proj = tomli.load(f) - - assert Version("$CI_COMMIT_TAG") == Version(proj["tool"]["poetry"]["version"]) + wheel_path = glob("dist/*.whl")[0] + wheel = Wheel(wheel_path) + assert Version("$CI_COMMIT_TAG") == Version(wheel.version) END -Build Package: - stage: build - extends: [.cached] - script: - - pip install build - - python -m build - artifacts: - paths: [dist] - - Upload Package: stage: publish - extends: [.cached] + extends: [.python] + needs: ["Build Package"] rules: - if: $CI_COMMIT_TAG =~ /^v[0-9]/ script: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 557c8dd3434839159d2d72fb3ef5feab05cab320..c293559c7bffb8cb5e8e2d188f150fa6033fe714 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,11 +35,13 @@ repos: - id: gitlint - repo: https://code.kodo.org.uk/dom/pre-commit-hooks - rev: v0.6 + rev: v0.6.1 hooks: - id: check-executable-modes - id: check-for-squash - id: copyright-notice + args: [--min-size=100] + stages: [commit, manual] - id: protect-first-parent - repo: https://github.com/pre-commit/pygrep-hooks @@ -73,10 +75,10 @@ repos: types_or: [python, pyi] stages: [commit, manual] -- repo: https://github.com/domsekotill/flakehell - rev: 5a7ecdc +- repo: https://github.com/flakeheaven/flakeheaven + rev: 0.11.0 hooks: - - id: flakehell + - id: flakeheaven additional_dependencies: - flake8-bugbear - flake8-docstrings @@ -85,10 +87,15 @@ repos: - flake8-sfs - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v0.942 hooks: - id: mypy + args: [--follow-imports=silent] additional_dependencies: + - packaging >=21 - types-orjson - types-requests + - types-urllib3 + - trio-typing[mypy] ~=0.6 + - xdg ~=5.1 - git+https://code.kodo.org.uk/dom/type-stubs.git#type-stubs[jsonpath,parse] diff --git a/behave_utils/binaries.py b/behave_utils/binaries.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc2dfdbf3a3ebf0e14b10644e2d395d501031ad --- /dev/null +++ b/behave_utils/binaries.py @@ -0,0 +1,300 @@ +# Copyright 2021,2022 Dominik Sekotill +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +""" +Start and manage a test Kubernetes cluster with Kubernetes-in-Docker (kind) +""" + +from __future__ import annotations + +import platform +import re +from abc import ABC +from abc import abstractmethod +from io import BytesIO +from pathlib import Path +from shutil import copyfileobj +from tarfile import TarFile +from typing import IO +from typing import Iterator + +import requests +import xdg +from packaging.version import Version + +from behave_utils.json import JSONObject +from behave_utils.url import URL + +CACHE_DIR: Path = xdg.xdg_cache_home() / "behave-testing" + + +class DownloadableExecutable(ABC): + """ + Base class for downloading static binaries to local paths + + Subclasses should implement `get_latest` and `get_stream` methods. They may use the + "kernel", "arch" and "goarch" attributes to select the correct source for the current + platform. + + Subclasses must also provide the "name" attribute, either as a class or instance + attribute. It is used to generate a cache path. + + Users of the subclasses SHOULD ONLY call the `get_binary` method to get a path pointing + to a locally cached copy of the downloaded binary. + """ + + # Map of `uname -m` output to architecture values accepted by Go + # Many Go binaries include the architecture value accepted by `go` in their names, so + # the "goarch" class attribute is added for convenience, generated from this map. + # This map may not be fully complete. Only non-equal values need to be added. + GOARCH_MAP = { + "i386": "386", + "i686": "386", + "x86": "386", + + "x86_64": "amd64", + + "armv6l": "arm", + "armv7l": "arm", + + "aarch64": "arm64", + } + + kernel = platform.system().lower() + arch = platform.machine() + goarch = GOARCH_MAP.get(arch, arch) + + name: str + + def __init__(self, version: str = "latest"): + self.version = version + + @abstractmethod + def get_latest(self, session: requests.Session) -> str: + """ + Return the latest release string for a supported binary + + Implementations must discover and return the latest release or tag string + + `session` is provided for performing HTTP requests. Although its use is not + required, it has and automatic code check hook so there is no need to manually check + the return code and handle errors. + """ + raise NotImplementedError + + @abstractmethod + def get_stream(self, session: requests.Session, version: str) -> IO[bytes]: + """ + Return a stream that emits the requested version of a supported binary + + Implementations must perform a request for the binary and return a file-like reader + + The return object must be a readable FileIO like instance, returning bytes. If the + source is uncompressed the "raw" attribute of a `requests.Response` object opened + with `stream=True` will suffice. See examples below. + + `version` specifies the wanted version of the binary, which MAY be different from + the "version" instance attribute. Other attributes such as "kernel" and "arch" (or + "goarch" if appropriate) MUST be honoured when selecting a source. + + `session` is provided for performing HTTP requests. Although its use is not + required, it has and automatic code check hook so there is no need to manually check + the return code and handle errors. + + + Examples: + + 1) Get an uncompressed binary: + + >>> def get_stream(session: requests.Session, version: str) -> IO[bytes]: + ... url = "https://example.com/binary" + ... 
return session.get(url, stream=True).raw + + + 2) Get a binary from a GZip compressed tar archive, storing the tar file in memory: + + Note: Avoid this for very large downloads. Unfortunately the Python tarfile + implementation cannot handle non-seekable streams. + + >>> from tarfile import TarFile + + >>> def get_stream(session: requests.Session, version: str) -> IO[bytes]: + ... url = "https://example.com/binary.tar.gz" + ... buf = BytesIO(session.get(url).content) + ... tar = TarFile.gzopen("buffer", fileobj=buf) + ... return tar.extractfile(self.name) + + + 3) Get a binary from a GZip compressed tar archive, storing the tar file in the file + system: + + >>> from tarfile import TarFile + >>> from tempfile import TemporaryFile + >>> from shutil import copyfileobj + + >>> def get_stream(session: requests.Session, version: str) -> IO[bytes]: + ... url = "https://example.com/binary.tar.gz" + ... resp = session.get(url, stream=True) + ... temp = TemporaryFile() + ... copyfileobj(resp.raw, temp) + ... tar = TarFile.gzopen("buffer", fileobj=temp) + ... return tar.extractfile(self.name) + """ + raise NotImplementedError + + def get_binary(self) -> Path: + """ + Return a Path to a locally cached executable, downloading it if necessary + """ + CACHE_DIR.mkdir(0o775, True, True) + version = self.version + + with requests.Session() as session: + assert isinstance(session.hooks["response"], list) + session.hooks["response"].append(lambda r, *a, **k: r.raise_for_status()) + + if version == "latest": + version = self.get_latest(session) + + binary = CACHE_DIR / f"{self.name}-{version}-{self.kernel}-{self.arch}" + if binary.exists(): + return binary + + stream = self.get_stream(session, version) + + try: + with binary.open("wb") as f: + copyfileobj(stream, f) + except BaseException: + binary.unlink() + raise + binary.chmod(0o755) + + return binary + + +class DownloadableDocker(DownloadableExecutable): + """ + Download class for the Docker client binary + """ + + URL = "https://download.docker.com/{kernel}/static/stable/{arch}/docker-{version}.tgz" + LATEST_URL = "https://download.docker.com/{kernel}/static/stable/{arch}/" + VERSION_RE = re.compile(rb'href="docker-(?P[0-9.]+).tgz"') + + name = "docker" + + def get_latest(self, session: requests.Session) -> str: + """ + Return latest Docker release + """ + url = self.LATEST_URL.format(kernel=self.kernel, arch=self.arch) + doc = session.get(url).content + latest = max(self._extract_versions(doc)) + return str(latest) + + def get_stream(self, session: requests.Session, version: str) -> IO[bytes]: + """ + Return a stream that emits theDocker CLI binary + """ + url = self.URL.format(version=version, kernel=self.kernel, arch=self.arch) + buf = BytesIO(session.get(url).content) + tar = TarFile.gzopen("buffer", fileobj=buf) + stream = tar.extractfile("docker/docker") + if stream is None: + raise FileNotFoundError(f"'docker/docker' in {url}") + return stream + + @classmethod + def _extract_versions(cls, doc: bytes) -> Iterator[Version]: + for match in cls.VERSION_RE.finditer(doc): + yield Version(match.group("release").decode()) + + +class DownloadableKubeTools(DownloadableExecutable): + """ + Download class for the kubernetes binaries "kubectl", "kubelet" and "kubeadm" + """ + + URL = "https://dl.k8s.io/release/{version}/bin/{kernel}/{arch}/{name}" + LATEST_URL = "https://dl.k8s.io/release/stable.txt" + + def __init__(self, name: str, version: str = "latest"): + DownloadableExecutable.__init__(self, version) + self.name = name + self._latest = "" + + def 
get_latest(self, session: requests.Session) -> str: + """ + Return that latest release of Kubernetes + """ + if not self._latest: + self._latest = session.get(self.LATEST_URL).content.decode().strip() + return self._latest + + def get_stream(self, session: requests.Session, version: str) -> IO[bytes]: + """ + Return a stream that emits the requested Kubernetes binary + """ + url = self.URL.format(version=version, kernel=self.kernel, arch=self.goarch, name=self.name) + stream: IO[bytes] = session.get(url, stream=True).raw + return stream + + +class DownloadableCrictl(DownloadableExecutable): + """ + Download class for the "crictl" binary + """ + + URL = "https://github.com/kubernetes-sigs/cri-tools/releases/download/{version}/crictl-{version}-{kernel}-{arch}.tar.gz" + LATEST_URL = "https://api.github.com/repos/kubernetes-sigs/cri-tools/releases/latest" + + name = "cri" + + def get_latest(self, session: requests.Session) -> str: + """ + Return the latest "crictl" release + """ + json = JSONObject.from_string(session.get(self.LATEST_URL).content) + return json.path("$.name", str).replace("cri-tools ", "") + + def get_stream(self, session: requests.Session, version: str) -> IO[bytes]: + """ + Return a stream that emits the requested "crictl" binary + """ + url = self.URL.format(version=version, kernel=self.kernel, arch=self.goarch) + buf = BytesIO(session.get(url).content) + tar = TarFile.gzopen("buffer", fileobj=buf) + stream = tar.extractfile("crictl") + if stream is None: + raise FileNotFoundError(f"'crictl' in {url}") + return stream + + +class DownloadableKind(DownloadableExecutable): + """ + Download class for the "kind" (Kubernetes-in-Docker) binary + """ + + URL = "https://kind.sigs.k8s.io/dl/{version}/kind-{kernel}-{arch}" + LATEST_URL = "https://api.github.com/repos/kubernetes-sigs/kind/releases/latest" + + name = "kind" + + def get_latest(self, session: requests.Session) -> str: + """ + Return the latest Kind binary + """ + json = JSONObject.from_string(session.get(self.LATEST_URL).content) + return json.path("$.name", str) + + def get_stream(self, session: requests.Session, version: str) -> IO[bytes]: + """ + Return a stream that emits the requested Kind binary + """ + url = self.URL.format(version=version, kernel=self.kernel, arch=self.goarch) + stream: IO[bytes] = session.get(url, stream=True).raw + return stream diff --git a/behave_utils/docker.py b/behave_utils/docker.py index 50d3169c88152c8cafb25566ea851b30c7150bbb..7f407683a636cca0652d5c0c53ba88e02f50b5cc 100644 --- a/behave_utils/docker.py +++ b/behave_utils/docker.py @@ -1,4 +1,4 @@ -# Copyright 2021 Dominik Sekotill +# Copyright 2021-2022 Dominik Sekotill # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
@@ -10,10 +10,13 @@ Commands for managing Docker for fixtures
 
 from __future__ import annotations
 
+import codecs
 import ipaddress
 import json
 import logging
 from contextlib import contextmanager
+from os import PathLike
+from os import fspath
 from pathlib import Path
 from secrets import token_hex
 from subprocess import DEVNULL
@@ -27,44 +30,57 @@ from typing import IO
 from typing import Any
 from typing import Iterable
 from typing import Iterator
-from typing import Literal
-from typing import SupportsBytes
+from typing import MutableMapping
 from typing import Tuple
 from typing import TypeVar
 from typing import Union
-from typing import overload
+from typing import cast
 
+from .binaries import DownloadableDocker
+from .json import JSONArray
 from .json import JSONObject
 from .proc import Argument
 from .proc import Arguments
 from .proc import Deserialiser
 from .proc import Environ
+from .proc import Executor
 from .proc import MutableArguments
-from .proc import PathLike
-from .proc import coerce_args
 from .proc import exec_io
 
-HostMount = tuple[PathLike, PathLike]
-NamedMount = tuple[str, PathLike]
-AnonMount = PathLike
-Mount = Union[HostMount, NamedMount, AnonMount]
+MountPath = Union[PathLike[bytes], PathLike[str]]
+HostMount = tuple[MountPath, MountPath]
+NamedMount = tuple[str, MountPath]
+Mount = Union[HostMount, NamedMount, MountPath]
 Volumes = Iterable[Mount]
 
-DOCKER = "docker"
+
+# Prefer a Docker client found on the host's PATH; fall back to downloading and
+# caching a static client binary if none is available.
+try:
+	run([b"docker", b"version"], stdout=DEVNULL)
+except FileNotFoundError:
+	DOCKER: Argument = DownloadableDocker().get_binary()
+else:
+	DOCKER = b"docker"
+
+
+def utf8_decode(buffer: bytes) -> str:
+	"""
+	Return a string decoded from a UTF-8 byte sequence
+	"""
+	return codecs.getdecoder("utf-8")(buffer)[0]
 
 
 def docker(*args: Argument, **env: str) -> None:
 	"""
 	Run a Docker command, with output going to stdout
 	"""
-	run([DOCKER, *coerce_args(args)], env=env, check=True)
+	run([DOCKER, *args], env=env, check=True)
 
 
 def docker_output(*args: Argument, **env: str) -> str:
 	"""
 	Run a Docker command, capturing and returning its stdout
 	"""
-	proc = run([DOCKER, *coerce_args(args)], env=env, check=True, stdout=PIPE, text=True)
+	proc = run([DOCKER, *args], env=env, check=True, stdout=PIPE, text=True)
 	return proc.stdout.strip()
 
 
@@ -72,7 +88,7 @@ def docker_quiet(*args: Argument, **env: str) -> None:
 	"""
 	Run a Docker command, directing its stdout to /dev/null
 	"""
-	run([DOCKER, *coerce_args(args)], env=env, check=True, stdout=DEVNULL)
+	run([DOCKER, *args], env=env, check=True, stdout=DEVNULL)
 
 
 class IPv4Address(ipaddress.IPv4Address):
@@ -136,7 +152,7 @@ class Image(Item):
 		arguments from external lookups without complex argument composing.
 		"""
 		cmd: Arguments = [
-			'build', context, f"--target={target}",
+			b"build", context, f"--target={target}",
 			*(f"--build-arg={arg}={val}" for arg, val in build_args.items() if val is not None),
 		]
 		docker(*cmd, DOCKER_BUILDKIT='1')
@@ -148,7 +164,7 @@ class Image(Item):
 		"""
 		Pull an image from a registry
 		"""
-		docker('pull', repository)
+		docker(b"pull", repository)
 		iid = Item(repository).inspect().path('$.Id', str)
 		return cls(iid)
 
@@ -167,6 +183,8 @@ class Container(Item):
 	exiting the context.
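+
+	For example, a minimal sketch of the context manager behaviour (assuming the image
+	and command line are the leading positional arguments; the names are illustrative)::
+
+		with Container(Image.pull("docker.io/library/alpine"), ["sleep", "60"]) as box:
+			box.start()
+			assert box.is_running()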
""" + T = TypeVar('T', bound='Container') + DEFAULT_ALIASES = tuple[str]() def __init__( @@ -177,6 +195,7 @@ class Container(Item): env: Environ = {}, network: Network|None = None, entrypoint: HostMount|Argument|None = None, + privileged: bool = False, ): if isinstance(entrypoint, tuple): volumes = [*volumes, entrypoint] @@ -187,13 +206,14 @@ class Container(Item): self.volumes = volumes self.env = env self.entrypoint = entrypoint + self.privileged = privileged self.networks = dict[Network, Tuple[str, ...]]() self.cid: str|None = None if network: self.connect(network, *self.DEFAULT_ALIASES) - def __enter__(self) -> Container: + def __enter__(self: T) -> T: return self def __exit__(self, etype: type[BaseException], exc: BaseException, tb: TracebackType) -> None: @@ -210,7 +230,7 @@ class Container(Item): self.start() yield self - def is_running(self) -> bool: + def is_running(self, raise_on_exit: bool = False) -> bool: """ Return whether the container is running """ @@ -223,6 +243,10 @@ class Container(Item): logging.getLogger(__name__).warning( f"container {self.cid} exited ({code})", ) + if raise_on_exit: + cmd = details.path("$.Config.Entrypoint", list[str]) + cmd.extend(details.path("$.Config.Cmd", list[str])) + raise CalledProcessError(code, cmd) return ( self.cid is not None and details.path('$.State.Running', bool) @@ -236,22 +260,34 @@ class Container(Item): return self.cid opts: MutableArguments = [ - "--network=none", - *( - (f"--volume={vol[0]}:{vol[1]}" if isinstance(vol, tuple) else f"--volume={vol}") - for vol in self.volumes - ), + b"--network=none", *(f"--env={name}={val}" for name, val in self.env.items()), ] + for vol in self.volumes: + if isinstance(vol, tuple): + src = fspath(vol[0]) + dst = fspath(vol[1]) + if isinstance(src, bytes): + src = src.decode() + if isinstance(dst, bytes): + dst = dst.decode() + arg: Argument = f"{src}:{dst}" + else: + arg = vol + opts.extend((b"--volume", arg)) + if self.entrypoint: - opts.append(f"--entrypoint={self.entrypoint}") + opts.extend((b"--entrypoint", self.entrypoint)) + + if self.privileged: + opts.append(b"--privileged") - self.cid = docker_output('container', 'create', *opts, self.image.iid, *self.cmd) + self.cid = docker_output(b"container", b"create", *opts, self.image.iid, *self.cmd) assert self.cid # Disconnect the "none" network specified as the starting network - docker_quiet("network", "disconnect", "none", self.cid) + docker_quiet(b"network", b"disconnect", b"none", self.cid) return self.cid @@ -261,7 +297,7 @@ class Container(Item): """ if self.is_running(): return - docker_quiet('container', 'start', self.get_id()) + docker_quiet(b"container", b"start", self.get_id()) def stop(self, rm: bool = False) -> None: """ @@ -271,13 +307,18 @@ class Container(Item): return try: if self.is_running(): - docker_quiet('container', 'stop', self.cid) + docker_quiet(b"container", b"stop", self.cid) finally: if rm: - docker_quiet('container', 'rm', self.cid) + docker_quiet(b"container", b"rm", self.cid) self.cid = None - def connect(self, network: Network, *aliases: str) -> None: + def connect( + self, + network: Network, + *aliases: str, + address: ipaddress.IPv4Address|ipaddress.IPv6Address|None = None, + ) -> None: """ Connect the container to a Docker network @@ -285,15 +326,20 @@ class Container(Item): network. 
""" cid = self.get_id() + opts = [f'--alias={a}' for a in aliases] + + if address is None: + address = network.reserve_address() + opts.append( + f"--ip={address}" if isinstance(address, ipaddress.IPv4Address) else + f"--ip6={address}", + ) + if network in self.networks: if self.networks[network] == aliases: return - docker('network', 'disconnect', str(network), cid) - docker( - 'network', 'connect', - *(f'--alias={a}' for a in aliases), - str(network), cid, - ) + docker(b"network", b"disconnect", str(network), cid) + docker(b"network", b"connect", *opts, str(network), cid) self.networks[network] = aliases def show_logs(self) -> None: @@ -301,13 +347,13 @@ class Container(Item): Print the container logs to stdout """ if self.cid: - docker('logs', self.cid) + docker(b"logs", self.cid) def get_exec_args(self, cmd: Arguments, interactive: bool = False) -> MutableArguments: """ Return a full argument list for running "cmd" inside the container """ - return [DOCKER, "exec", *(("-i",) if interactive else ""), self.get_id(), *coerce_args(cmd)] + return [DOCKER, b"exec", *((b"-i",) if interactive else []), self.get_id(), *cmd] def run( self, @@ -353,8 +399,12 @@ class Network(Item): A Docker network """ + DOCKER_SUBNET = ipaddress.IPv4Network("172.16.0.0/12") + def __init__(self, name: str|None = None) -> None: self._name = name or f"br{token_hex(6)}" + self._nid: str|None = None + self._assigned = set[ipaddress.IPv4Address]() def __str__(self) -> str: return self._name @@ -389,22 +439,103 @@ class Network(Item): """ Return an identifier for the Docker Network """ - return self._name + if self._nid is None: + self.create() + assert self._nid is not None + return self._nid def create(self) -> None: """ Create the network """ - docker_quiet("network", "create", self._name) + subnet = self.get_free_subnet() + gateway = next(subnet.hosts()) + try: + self._nid = docker_output( + b"network", b"create", self._name, + f"--subnet={subnet}", f"--gateway={gateway}", + ) + except CalledProcessError: + data = exec_io( + [DOCKER, b"network", b"inspect", self._name], + deserialiser=JSONArray.from_string, + ) + if len(data) == 0: + raise + self._nid = data.path("$[0].Id", str) + self._assigned.update( + data.path( + "$[0].IPAM.Config[*].Gateway", + list[str], lambda ls: (IPv4Address(s) for s in ls), + ), + ) + else: + self._assigned.add(gateway) + assert self._nid is not None + assert len(self._assigned) > 0, \ + "Expected gateways address(es) to be added to assigned addresses set" def destroy(self) -> None: """ Remove the network """ - docker_quiet("network", "rm", self._name) + if self._nid: + docker_quiet(b"network", b"rm", self._nid) + + @classmethod + def get_free_subnet(cls) -> ipaddress.IPv4Network: + """ + Return a free private subnet + """ + networks = exec_io( + [DOCKER, b"network", b"ls", b"--format={{.ID}}"], + deserialiser=utf8_decode, + ).splitlines() + subnets = exec_io( + [DOCKER, b"network", b"inspect"] + cast(list[Argument], networks), + deserialiser=JSONArray.from_string, + ).path( + "$[*].IPAM.Config[*].Subnet", list[str], + lambda subnets: {ipaddress.ip_network(net) for net in subnets}, + ) + for subnet in cls.DOCKER_SUBNET.subnets(8): + if not any(net.overlaps(subnet) for net in subnets): + return subnet + raise LookupError(f"No free subnets found in subnet {cls.DOCKER_SUBNET}") + + def reserve_address(self) -> ipaddress.IPv4Address: + """ + Return a free address in the network + + Note that the address is not reserved; any changes made to the network such as + adding a container may 
invalidate the assurance that the address is free. + """ + # TODO: support IPv6 + data = self.inspect() + # Considering only the first listed subnet + net = data.path("$.IPAM.Config[0].Subnet", str, ipaddress.IPv4Network) + + # Recycle some old code for an assertion about assigned addresses + if __debug__: + reserved: set[ipaddress.IPv4Address] = data.path( + "$.Containers.*.IPv4Address", list[str], + lambda addrs: {IPv4Address.with_suffix(a) for a in addrs}, + ) + reserved.add(data.path("$.IPAM.Config[0].Gateway", str, IPv4Address)) + missing = reserved - self._assigned + assert len(missing) == 0, f"Missing addresses from assigned set: {missing}" + + # Optimise for CPython 3.x without early binding + assigned = self._assigned + for addr in net.hosts(): + if addr not in assigned: + assigned.add(addr) + return addr + raise LookupError(f"No free addresses found in subnet {net}") -class Cli: + +class Cli(Executor): """ Manage calling executables in a container @@ -412,91 +543,19 @@ class Cli: is called. """ - T = TypeVar("T") - def __init__(self, container: Container, *cmd: Argument): + Executor.__init__(self, *cmd) self.container = container - self.cmd = cmd - @overload - def __call__( - self, - *args: Argument, - input: str|bytes|SupportsBytes|None = ..., - deserialiser: Deserialiser[T], - query: Literal[False] = False, - **kwargs: Any, - ) -> T: ... - - @overload - def __call__( - self, - *args: Argument, - input: str|bytes|SupportsBytes|None = ..., - deserialiser: None = None, - query: Literal[True], - **kwargs: Any, - ) -> int: ... - - @overload - def __call__( + def get_arguments( self, - *args: Argument, - input: str|bytes|SupportsBytes|None = ..., - deserialiser: None = None, - query: Literal[False] = False, - **kwargs: Any, - ) -> None: ... - - def __call__( - self, - *args: Argument, - input: str|bytes|SupportsBytes|None = None, - deserialiser: Deserialiser[Any]|None = None, - query: bool = False, - **kwargs: Any, - ) -> Any: + cmd: Arguments, + kwargs: MutableMapping[str, Any], + has_input: bool, + is_query: bool, + deserialiser: Deserialiser[Any]|None, + ) -> Arguments: """ - Run the container executable with the given arguments - - Input: - Any bytes passed as "input" will be fed into the process' stdin pipe. - - Output: - If "deserialiser" is provided it will be called with a memoryview of a buffer - containing any bytes from the process' stdout; whatever is returned by - "deserialiser" will be returned. - - If "query" is true the return code of the process will be returned. - - Otherwise nothing is returned. - - Note that "deserialiser" and "query" are mutually exclusive; if debugging is - enabled an AssertionError will be raised if both are non-None/non-False, otherwise - "query" is ignored. - - Errors: - If "query" is not true any non-zero return code will cause CalledProcessError to - be raised. 
+		Prefix the command arguments with the arguments needed to execute them inside
+		the container
 		"""
-		# deserialiser = kwargs.pop('deserialiser', None)
-		assert not deserialiser or not query
-
-		data = (
-			b"" if input is None else
-			input.encode() if isinstance(input, str) else
-			bytes(input)
-		)
-		cmd = self.container.get_exec_args([*self.cmd, *args], interactive=bool(data))
-
-		if deserialiser:
-			return exec_io(cmd, data, deserialiser=deserialiser, **kwargs)
-
-		rcode = exec_io(cmd, data, **kwargs)
-		if query:
-			return rcode
-		if not isinstance(rcode, int):
-			raise TypeError(f"got rcode {rcode!r}")
-		if 0 != rcode:
-			raise CalledProcessError(rcode, ' '.join(coerce_args(cmd)))
-		return None
+		return self.container.get_exec_args(cmd, interactive=has_input)
diff --git a/behave_utils/http.py b/behave_utils/http.py
index 9c602201894148d97c21e620ad453ef95057b009..afaf06615a62ba4ad685b2ec60f4effd485709d3 100644
--- a/behave_utils/http.py
+++ b/behave_utils/http.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Dominik Sekotill
+# Copyright 2021,2022 Dominik Sekotill
 #
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -16,11 +16,16 @@
 from typing import Mapping
 from urllib.parse import urlparse
 
 import requests.adapters
-from requests.packages.urllib3 import connection
-from requests.packages.urllib3 import connectionpool
+from urllib3 import connection
+from urllib3 import connectionpool
 
 
-def redirect(session: requests.Session, prefix: str, address: ipaddress.IPv4Address) -> None:
+def redirect(
+	session: requests.Session,
+	prefix: str,
+	address: ipaddress.IPv4Address,
+	certificate: str|None = None,
+) -> None:
 	"""
 	Redirect all requests for "prefix" to a given address
 
@@ -31,11 +36,9 @@ def redirect(session: requests.Session, prefix: str, address: ipaddress.IPv4Addr
-	"prefix" is formated as either "{hostname}[:{port}]" or "{schema}://{hostname}[:{port}]"
-	where "schema" defaults to (and currently only supports) "http".
+	"prefix" is formatted as either "{hostname}[:{port}]" or "{schema}://{hostname}[:{port}]"
+	where "schema" defaults to "http"; "https" is selected when a CA "certificate" is
+	supplied.
 	"""
-	if prefix.startswith("https://"):
-		raise ValueError("https:// prefixes not currently supported")
-	if not prefix.startswith("http://"):
-		prefix = f"http://{prefix}"
-	session.mount(prefix, _DirectedAdapter(address))
+	if not prefix.startswith(("http://", "https://")):
+		prefix = f"http://{prefix}" if certificate is None else f"https://{prefix}"
+	session.mount(prefix, _DirectedAdapter(address, certificate))
 
 
 class _DirectedAdapter(requests.adapters.HTTPAdapter):
@@ -49,13 +52,34 @@
 	function.
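+
+	A usage sketch (the host name and address are illustrative)::
+
+		session = requests.Session()
+		redirect(session, "example.com", ipaddress.IPv4Address("172.16.0.5"))
+		session.get("http://example.com/")  # the request is sent to 172.16.0.5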
""" - def __init__(self, destination: ipaddress.IPv4Address): + def __init__(self, destination: ipaddress.IPv4Address, certificate: str|None): super().__init__() self.destination = destination + self.certificate = certificate - def get_connection(self, url: str, proxies: Mapping[str, str]|None = None) -> _HTTPConnectionPool: + def get_connection(self, url: str, proxies: Mapping[str, str]|None = None) -> connectionpool.HTTPConnectionPool: parts = urlparse(url) - return _HTTPConnectionPool(parts.hostname, parts.port, address=self.destination) + if parts.scheme == "https": + return _HTTPSConnectionPool(parts.hostname, parts.port, address=self.destination) + else: + return _HTTPConnectionPool(parts.hostname, parts.port, address=self.destination) + + def cert_verify( + self, + conn: connection.HTTPConnection, + url: str, + verify: bool|str, + cert: str|tuple[str, str], + ) -> None: + if verify is False: + raise ValueError("Never disable TLS verification") + if verify is not True: + raise ValueError( + "To supply verification certificates please use " + "redirect(session, '{url.scheme}://{url.netloc}', '{self.destination}', Path('{verify}'))", + ) + super().cert_verify(conn, url, True, cert) # type: ignore + conn.ca_cert_data = self.certificate # type: ignore class _HTTPConnectionPool(connectionpool.HTTPConnectionPool): @@ -66,5 +90,11 @@ class _HTTPConnectionPool(connectionpool.HTTPConnectionPool): host = "" def __init__(self, /, address: ipaddress.IPv4Address, **kwargs: Any): - connection.HTTPConnection.__init__(self, **kwargs) + super().__init__(**kwargs) self._dns_host = str(address) + + +class _HTTPSConnectionPool(connectionpool.HTTPSConnectionPool): + + class ConnectionCls(connection.HTTPSConnection, _HTTPConnectionPool.ConnectionCls): + ... diff --git a/behave_utils/json.py b/behave_utils/json.py index 38101bdec2f73d3ea576625d120c2700546885b6..4e8a325e7200d682f6f8fefea43409b063365087 100644 --- a/behave_utils/json.py +++ b/behave_utils/json.py @@ -10,6 +10,7 @@ JSON classes for container types (objects and arrays) from __future__ import annotations +from types import GenericAlias from typing import Any from typing import Callable from typing import TypeVar @@ -35,13 +36,15 @@ class JSONPathMixin: @overload def path(self, path: str, kind: type[T], convert: Callable[[T], C]) -> C: ... - def path(self, path: str, kind: type[T], convert: Callable[[T], C]|None = None) -> T|C: + def path(self, path: str, kind: type[T]|GenericAlias, convert: Callable[[T], C]|None = None) -> T|C: result = JSONPath(path).parse(self) if "*" not in path: try: result = result[0] except IndexError: raise KeyError(path) from None + if isinstance(kind, GenericAlias): + kind = kind.__origin__ if not isinstance(result, kind): raise TypeError(f"{path} is wrong type; expected {kind}; got {type(result)}") if convert is None: diff --git a/behave_utils/proc.py b/behave_utils/proc.py index a39b9507abd467386febab082d538ce1bc0b0499..93923db67a82a750a13a6729c30f453cf20f6d42 100644 --- a/behave_utils/proc.py +++ b/behave_utils/proc.py @@ -1,4 +1,4 @@ -# Copyright 2021 Dominik Sekotill +# Copyright 2021-2022 Dominik Sekotill # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. 
If a copy of the MPL was not distributed with this @@ -11,28 +11,38 @@ Manage processes asynchronously with stdio capture from __future__ import annotations import io -import os import sys +from copy import copy +from functools import partial +from os import PathLike +from os import fspath +from os import write as fdwrite from subprocess import DEVNULL from subprocess import PIPE +from subprocess import CalledProcessError from typing import IO from typing import Any +from typing import BinaryIO from typing import Callable from typing import Iterator +from typing import Literal from typing import Mapping +from typing import MutableMapping from typing import MutableSequence from typing import Sequence +from typing import SupportsBytes +from typing import TextIO from typing import TypeVar from typing import Union from typing import overload +from warnings import warn import trio.abc T = TypeVar('T') Deserialiser = Callable[[memoryview], T] -PathLike = os.PathLike[str] -Argument = Union[PathLike, str] +Argument = Union[str, bytes, PathLike[str], PathLike[bytes]] PathArg = Argument # deprecated Arguments = Sequence[Argument] MutableArguments = MutableSequence[Argument] @@ -43,37 +53,39 @@ def coerce_args(args: Arguments) -> Iterator[str]: """ Ensure path-like arguments are converted to strings """ - return (os.fspath(a) for a in args) + for arg in args: + arg = fspath(arg) + yield arg if isinstance(arg, str) else arg.decode() @overload def exec_io( - cmd: Arguments, - data: bytes = b'', - deserialiser: Deserialiser[T] = ..., + cmd: Arguments, *, + input: bytes = b'', + deserialiser: Deserialiser[T], **kwargs: Any, ) -> T: ... @overload def exec_io( - cmd: Arguments, - data: bytes = b'', + cmd: Arguments, *, + input: bytes = b'', deserialiser: None = None, **kwargs: Any, ) -> int: ... def exec_io( - cmd: Arguments, - data: bytes = b'', + cmd: Arguments, *, + input: bytes = b'', deserialiser: Deserialiser[Any]|None = None, **kwargs: Any, ) -> Any: """ Execute a command, handling output asynchronously - If data is provided it will be fed to the process' stdin. + If input is provided it will be fed to the process' stdin. If a deserialiser is provided it will be used to parse stdout data from the process. 
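+
+	For instance, a sketch of both modes (the command is illustrative, and the
+	deserialiser assumes `behave_utils.json.JSONObject` is imported)::
+
+		rcode = exec_io([b"docker", b"version"])
+		info = exec_io(
+			[b"docker", b"system", b"info", b"--format={{json .}}"],
+			deserialiser=JSONObject.from_string,
+		)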
Stderr and stdout (if no deserialiser is provided) will be written to `sys.stderr` and @@ -84,11 +96,16 @@ def exec_io( """ if deserialiser and 'stdout' in kwargs: raise TypeError("Cannot provide 'deserialiser' with 'stdout' argument") - if data and 'stdin' in kwargs: - raise TypeError("Cannot provide 'data' with 'stdin' argument") + if "data" in kwargs: + if input: + raise TypeError("both 'input' and the deprecated 'data' keywords provided") + warn(DeprecationWarning("the 'data' keyword argument is deprecated, use 'input'")) + input = kwargs.pop("data") + if input and 'stdin' in kwargs: + raise TypeError("Cannot provide 'input' with 'stdin' argument") stdout: IO[str]|IO[bytes]|int = io.BytesIO() if deserialiser else kwargs.pop('stdout', sys.stdout) stderr: IO[str]|IO[bytes]|int = kwargs.pop('stderr', sys.stderr) - proc = trio.run(_exec_io, cmd, data, stdout, stderr, kwargs) + proc = trio.run(_exec_io, cmd, input, stdout, stderr, kwargs) if deserialiser: assert isinstance(stdout, io.BytesIO) return deserialiser(stdout.getbuffer()) @@ -102,14 +119,16 @@ async def _exec_io( stderr: IO[str]|IO[bytes]|int, kwargs: dict[str, Any], ) -> trio.Process: - proc = await trio.open_process( - [*coerce_args(cmd)], - stdin=PIPE if data else DEVNULL, - stdout=PIPE, - stderr=PIPE, - **kwargs, - ) - async with proc, trio.open_nursery() as nursery: + async with trio.open_nursery() as nursery: + proc: trio.Process = await nursery.start( + partial( + trio.run_process, [*coerce_args(cmd)], + stdin=PIPE if data else DEVNULL, + stdout=PIPE, stderr=PIPE, + check=False, + **kwargs, + ), + ) assert proc.stdout is not None and proc.stderr is not None nursery.start_soon(_passthru, proc.stderr, stderr) nursery.start_soon(_passthru, proc.stdout, stdout) @@ -126,16 +145,14 @@ async def _passthru(in_stream: trio.abc.ReceiveStream, out_stream: IO[str]|IO[by out_stream = out_stream.fileno() except (OSError, AttributeError): # cannot get file descriptor, probably a memory buffer - if isinstance(out_stream, io.BytesIO): + if isinstance(out_stream, (BinaryIO, io.BytesIO)): async def write(data: bytes) -> None: - assert isinstance(out_stream, io.BytesIO) + assert isinstance(out_stream, (BinaryIO, io.BytesIO)) out_stream.write(data) - elif isinstance(out_stream, io.StringIO): + else: async def write(data: bytes) -> None: - assert isinstance(out_stream, io.StringIO) + assert isinstance(out_stream, (TextIO, io.StringIO)) out_stream.write(data.decode()) - else: - raise TypeError(f"Unknown IO type: {type(out_stream)}") else: # is/has a file descriptor, out_stream is now that file descriptor async def write(data: bytes) -> None: @@ -144,7 +161,7 @@ async def _passthru(in_stream: trio.abc.ReceiveStream, out_stream: IO[str]|IO[by remaining = len(data) while remaining: await trio.lowlevel.wait_writable(out_stream) - written = os.write(out_stream, data) + written = fdwrite(out_stream, data) data = data[written:] remaining -= written @@ -153,3 +170,131 @@ async def _passthru(in_stream: trio.abc.ReceiveStream, out_stream: IO[str]|IO[by if not data: return await write(data) + + +class Executor(list[Argument]): + """ + Manage calling executables with composable argument lists + + Subclasses may add or amend the argument list just prior to execution by implementing + `get_arguments`. + + Any arguments passed to the constructor will prefix the arguments passed when the object + is called. 
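+
+	For example, a minimal sketch (the git command is purely illustrative)::
+
+		git = Executor("git")
+		git("fetch", "origin")               # raises CalledProcessError on failure
+		branch = git.subcommand("branch")
+		code = branch("--list", query=True)  # returns the process' return code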
+ """ + + T = TypeVar("T") + E = TypeVar("E", bound="Executor") + + def __init__(self, *cmd: Argument): + self[:] = cmd + + @overload + def __call__( + self, + *args: Argument, + input: str|bytes|SupportsBytes|None = ..., + deserialiser: Deserialiser[T], + query: Literal[False] = False, + **kwargs: Any, + ) -> T: ... + + @overload + def __call__( + self, + *args: Argument, + input: str|bytes|SupportsBytes|None = ..., + deserialiser: None = None, + query: Literal[True], + **kwargs: Any, + ) -> int: ... + + @overload + def __call__( + self, + *args: Argument, + input: str|bytes|SupportsBytes|None = ..., + deserialiser: None = None, + query: Literal[False] = False, + **kwargs: Any, + ) -> None: ... + + def __call__( + self, + *args: Argument, + input: str|bytes|SupportsBytes|None = None, + deserialiser: Deserialiser[Any]|None = None, + query: bool = False, + **kwargs: Any, + ) -> Any: + """ + Execute the configure command with the given arguments + + Input: + Any bytes passed as "input" will be fed into the process' stdin pipe. + + Output: + If "deserialiser" is provided it will be called with a memoryview of a buffer + containing any bytes from the process' stdout; whatever is returned by + "deserialiser" will be returned. + + If "query" is true the return code of the process will be returned. + + Otherwise nothing is returned. + + Note that "deserialiser" and "query" are mutually exclusive; if debugging is + enabled an AssertionError will be raised if both are non-None/non-False, otherwise + "query" is ignored. + + Errors: + If "query" is not true any non-zero return code will cause CalledProcessError to + be raised. + """ + assert not deserialiser or not query + + data = ( + b"" if input is None else + input.encode() if isinstance(input, str) else + bytes(input) + ) + cmd = self.get_arguments( + [*self, *args], kwargs, + has_input=bool(data), + is_query=query, + deserialiser=deserialiser, + ) + + if deserialiser: + return exec_io(cmd, input=data, deserialiser=deserialiser, **kwargs) + + rcode = exec_io(cmd, input=data, **kwargs) + if query: + return rcode + if 0 != rcode: + raise CalledProcessError(rcode, ' '.join(coerce_args(cmd))) + return None + + def get_arguments( + self, + cmd: Arguments, + kwargs: MutableMapping[str, Any], + has_input: bool, + is_query: bool, + deserialiser: Deserialiser[Any]|None, + ) -> Arguments: + """ + Override to amend command arguments and kwargs for exec_io() prior to execution + """ + return cmd + + def subcommand(self: E, *args: Argument) -> E: + """ + Return a new Executor instance of the same class with additional arguments appended + + The returned instance is created as a shallow copy; if attribute values need to be + copied, subclasses must implement __copy__(). 
+ (see https://docs.python.org/3/library/copy.html) + """ + new = copy(self) + new.extend(args) + return new diff --git a/pyproject.toml b/pyproject.toml index e98398dce68626882a13ca9c78066b54b605a63a..27aea3b543197ca3034d33c5fbe5c931eb54a4bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "behave-utils" -version = "0.2" +version = "0.3" description = "Utilities for writing Behave step implementations" license = "Apache-2.0" readme = "README.md" @@ -27,10 +27,12 @@ include = [ python = "~=3.9" behave = "~=1.2" jsonpath-python = "~=1.0" -orjson = "~=3.6.1" +orjson = "~=3.6" parse = "~=1.19" requests = "~=2.26" -trio = "~=0.19" +trio = "~=0.20.0" +xdg = "~=5.1" +packaging = ">=21" [tool.isort] @@ -41,6 +43,52 @@ strict = true warn_unused_configs = true warn_unreachable = true mypy_path = ["stubs"] +plugins = ["trio_typing.plugin"] + +[[tool.mypy.overrides]] +module = "coverage.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "tests.coverage" +disallow_subclassing_any = false + [tool.flakehell] base = ".flakehell.toml" + + +[tool.coverage.run] +data_file = "results/coverage.db" +branch = true +source = [ + "behave_utils", +] +plugins = [ + "tests.coverage", +] + +[tool.coverage.report] +precision = 2 +skip_empty = true +exclude_lines = [ + "pragma: no cover", + "if .*\\b__name__\\b", + "if .*\\bTYPE_CHECKING\\b", + "class .*(.*\\bProtocol\\b.*):", + "@overload", +] +partial_branches = [ + "pragma: no branch", + "if .*\\b__debug__\\b", +] + +[tool.coverage.json] +output = "results/coverage.json" + +[tool.coverage.xml] +output = "results/coverage.xml" + +[tool.coverage.html] +directory = "results/coverage.html.d" +show_contexts = true diff --git a/stubs/behave-stubs/fixture.pyi b/stubs/behave-stubs/fixture.pyi index 4af0fab1073053d19ca42c53ba699b253b513515..c1f2c41fc6f8d26ea03092d000426993b70212b7 100644 --- a/stubs/behave-stubs/fixture.pyi +++ b/stubs/behave-stubs/fixture.pyi @@ -1,4 +1,3 @@ -import sys from typing import Any from typing import Iterator from typing import Protocol @@ -13,6 +12,9 @@ C_con = TypeVar("C_con", bound=Context, contravariant=True) R = TypeVar("R") R_co = TypeVar("R_co", covariant=True) +P = TypeVar("P", bound=None) +P_co = TypeVar("P_co", covariant=True) # unused + # There's a lot of @overload-ed functions here as fixtures come in two varieties: # 1) A @contextlib.contextmanager-like generator that yields an arbitrary object once. @@ -20,69 +22,34 @@ R_co = TypeVar("R_co", covariant=True) # # "use_fixture" allows both types of fixture callables to be used in the same way -if sys.version_info >= (3, 10) and False: - # This depends on complete support of ParamSpec in mypy so is disabled for now. - - from typing import ParamSpec - - P = ParamSpec("P") - - - class FixtureCoroutine(Protocol[C_con, P, R_co]): - def __call__(self, _: C_con, /, *__a: P.args, **__k: P.kwargs) -> Iterator[R_co]: ... +# Without ParamSpec no checking is done to ensure the arguments passed to use_fixture +# match the fixture's arguments; fixtures must be able to handle arguments not being +# supplied (except the context); and fixtures must accept ANY arbitrary keyword +# arguments. - class FixtureFunction(Protocol[C_con, P, R_co]): - def __call__(self, _: C_con, /, *__a: P.args, **__k: P.kwargs) -> R_co: ... +class FixtureCoroutine(Protocol[C_con, P_co, R_co]): + def __call__(self, _: C_con, /, *__a: Any, **__k: Any) -> Iterator[R_co]: ... 
- @overload - def use_fixture( - fixture_func: FixtureCoroutine[C_con, P, R], - context: C_con, - *a: P.args, - **k: P.kwargs, - ) -> R: ... - - @overload - def use_fixture( - fixture_func: FixtureFunction[C_con, P, R], - context: C_con, - *a: P.args, - **k: P.kwargs, - ) -> R: ... - -else: - # Without ParamSpec no checking is done to ensure the arguments passed to use_fixture - # match the fixture's arguments; fixtures must be able to handle arguments not being - # supplied (except the context); and fixtures must accept ANY arbitrary keyword - # arguments. +class FixtureFunction(Protocol[C_con, P_co, R_co]): + def __call__(self, _: C_con, /, *__a: Any, **__k: Any) -> R_co: ... - P = TypeVar("P", bound=None) - P_co = TypeVar("P_co", covariant=True) # unused +@overload +def use_fixture( + fixture_func: FixtureCoroutine[C_con, P_co, R_co], + context: C_con, + *a: Any, + **k: Any, +) -> R_co: ... - class FixtureCoroutine(Protocol[C_con, P_co, R_co]): - def __call__(self, _: C_con, /, *__a: Any, **__k: Any) -> Iterator[R_co]: ... - - class FixtureFunction(Protocol[C_con, P_co, R_co]): - def __call__(self, _: C_con, /, *__a: Any, **__k: Any) -> R_co: ... - - - @overload - def use_fixture( - fixture_func: FixtureCoroutine[C_con, P_co, R_co], - context: C_con, - *a: Any, - **k: Any, - ) -> R_co: ... - - @overload - def use_fixture( - fixture_func: FixtureFunction[C_con, P_co, R_co], - context: C_con, - *a: Any, - **k: Any, - ) -> R_co: ... +@overload +def use_fixture( + fixture_func: FixtureFunction[C_con, P_co, R_co], + context: C_con, + *a: Any, + **k: Any, +) -> R_co: ... # "fixture" is a decorator used to mark both types of fixture callables. It can also return diff --git a/stubs/behave-stubs/model.pyi b/stubs/behave-stubs/model.pyi index 9372d29b7090d351af6604d6fe1706f70fed73d2..bcbc9ccebec680154adfde5602bcd7d5129d5161 100644 --- a/stubs/behave-stubs/model.pyi +++ b/stubs/behave-stubs/model.pyi @@ -5,6 +5,7 @@ from typing import Iterator from typing import Literal from typing import Protocol from typing import Sequence +from typing import SupportsIndex from .model_core import BasicStatement from .model_core import Replayable @@ -187,7 +188,7 @@ class Table(Replayable): def __init__(self, headings: Sequence[str], line: int = ..., rows: Sequence[Row] = ...): ... def __eq__(self, other: Any) -> bool: ... def __iter__(self) -> Iterator[Row]: ... - def __getitem__(self, index: int) -> Row: ... + def __getitem__(self, index: SupportsIndex) -> Row: ... def add_row(self, row: Sequence[str], line: int) -> None: ... def add_column(self, column_name: str, values: Iterable[str], default_value: str = ...) -> int: ... @@ -215,13 +216,13 @@ class Row: line: int = ..., comments: Sequence[str] = ..., ): ... - def __getitem__(self, index: int) -> str: ... + def __getitem__(self, index: SupportsIndex) -> str: ... def __eq__(self, other: Any) -> bool: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[str]: ... def items(self) -> Iterator[tuple[str, str]]: ... - def get(self, key: int, default: str = ...) -> str: ... + def get(self, key: SupportsIndex, default: str = ...) -> str: ... def as_dict(self) -> dict[str, str]: ... @@ -246,5 +247,5 @@ class Text(str): def __init__(self, value: str, content_type: Literal["text/plain"] = ..., line: int = ...): ... def line_range(self) -> tuple[int, int]: ... - def replace(self, old: str, new: str, count: int = ...) -> Text: ... + def replace(self, old: str, new: str, count: SupportsIndex = ...) -> Text: ... 
def assert_equals(self, expected: str) -> bool: ... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/coverage.py b/tests/coverage.py new file mode 100644 index 0000000000000000000000000000000000000000..c33528ca5e26d4a314684eda4fe711a4a99e8b54 --- /dev/null +++ b/tests/coverage.py @@ -0,0 +1,47 @@ +# Copyright 2022 Dominik Sekotill +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +""" +Plugin module for test coverage +""" + +from __future__ import annotations + +from types import FrameType +from typing import Any +from unittest import TestCase + +from coverage.plugin import CoveragePlugin +from coverage.plugin_support import Plugins + + +class DynamicContextPlugin(CoveragePlugin): + """ + A dynamic context plugin for coverage.py + + https://coverage.readthedocs.io/en/latest/contexts.html#dynamic-contexts + + This plugin annotates code lines with the names of tests under which the line was + reached. + """ + + def dynamic_context(self, frame: FrameType) -> str|None: # noqa: D102 + if not frame.f_code.co_name.startswith("test"): + return None + try: + inst = frame.f_locals["self"] + except KeyError: + return None + if isinstance(inst, TestCase): + return str(inst) + return None + + +def coverage_init(reg: Plugins, options: dict[str, Any]) -> None: + """ + Initialise this plugin module + """ + reg.add_dynamic_context(DynamicContextPlugin()) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57bed5dddf34f3ec9509c6ff43570900385ebbec --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2022 Dominik Sekotill +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +""" +Unit tests for behave_utils +""" + +import unittest +import warnings + + +class TestCase(unittest.TestCase): + """ + Base class for all project test cases + + Extends the base class provided by `unittest` + """ + + def setUp(self) -> None: # noqa: D102 + warnings.simplefilter("error", category=DeprecationWarning) + + def tearDown(self) -> None: # noqa: D102 + warnings.resetwarnings() diff --git a/tests/unit/proc/__init__.py b/tests/unit/proc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/unit/proc/fixture_output.py b/tests/unit/proc/fixture_output.py new file mode 100755 index 0000000000000000000000000000000000000000..9f038e5561e8637aaf2e30a8b7f59e7800d07d34 --- /dev/null +++ b/tests/unit/proc/fixture_output.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright 2022 Dominik Sekotill +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +""" +A fixture script for producing example outputs +""" + +import argparse +import json +import shutil +import sys +from typing import Callable +from typing import NoReturn + +ErrorC = Callable[[str], NoReturn] + + +def main() -> None: + """ + Produce various example outputs; CLI entrypoint + """ + argp = argparse.ArgumentParser() + subs = argp.add_subparsers(required=True) + + jsonp = subs.add_parser("json", description=json_cmd.__doc__) + jsonp.set_defaults(func=json_cmd, parser=jsonp) + + echop = subs.add_parser("echo", description=echo_cmd.__doc__) + echop.set_defaults(func=echo_cmd, parser=echop) + + rcodep = subs.add_parser("rcode", description=rcode_cmd.__doc__) + rcodep.add_argument("--code", type=int, default=1) + rcodep.set_defaults(func=rcode_cmd, parser=rcodep) + + args = argp.parse_args() + args.func(args, args.parser.error) + + +def json_cmd(args: argparse.Namespace, error: ErrorC) -> None: + """ + Output a sample JSON string + """ + json.dump( + {"example-output": True}, + sys.stdout, + ) + + +def echo_cmd(args: argparse.Namespace, error: ErrorC) -> None: + """ + Echo everything from stdin to stdout + """ + shutil.copyfileobj(sys.stdin, sys.stdout) + + +def rcode_cmd(args: argparse.Namespace, error: ErrorC) -> None: + """ + Return a non-zero return code + """ + if 0 >= args.code or args.code >= 128: + raise error(f"bad value for --code: {args.code}") + raise SystemExit(args.code) + + +main() diff --git a/tests/unit/proc/test_proc.py b/tests/unit/proc/test_proc.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b486839623376d79deab6753d55c6db529edcf --- /dev/null +++ b/tests/unit/proc/test_proc.py @@ -0,0 +1,224 @@ +# Copyright 2022 Dominik Sekotill +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +""" +Unit tests for behave_utils.proc +""" + +import io +import os +import subprocess +import sys +import warnings +from sys import executable as python +from typing import Any + +from behave_utils import json +from behave_utils import proc + +from .. 
diff --git a/tests/unit/proc/__init__.py b/tests/unit/proc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/unit/proc/fixture_output.py b/tests/unit/proc/fixture_output.py
new file mode 100755
index 0000000000000000000000000000000000000000..9f038e5561e8637aaf2e30a8b7f59e7800d07d34
--- /dev/null
+++ b/tests/unit/proc/fixture_output.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# Copyright 2022 Dominik Sekotill
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""
+A fixture script for producing example outputs
+"""
+
+import argparse
+import json
+import shutil
+import sys
+from typing import Callable
+from typing import NoReturn
+
+ErrorC = Callable[[str], NoReturn]
+
+
+def main() -> None:
+	"""
+	Produce various example outputs; CLI entrypoint
+	"""
+	argp = argparse.ArgumentParser()
+	subs = argp.add_subparsers(required=True)
+
+	jsonp = subs.add_parser("json", description=json_cmd.__doc__)
+	jsonp.set_defaults(func=json_cmd, parser=jsonp)
+
+	echop = subs.add_parser("echo", description=echo_cmd.__doc__)
+	echop.set_defaults(func=echo_cmd, parser=echop)
+
+	rcodep = subs.add_parser("rcode", description=rcode_cmd.__doc__)
+	rcodep.add_argument("--code", type=int, default=1)
+	rcodep.set_defaults(func=rcode_cmd, parser=rcodep)
+
+	args = argp.parse_args()
+	args.func(args, args.parser.error)
+
+
+def json_cmd(args: argparse.Namespace, error: ErrorC) -> None:
+	"""
+	Output a sample JSON string
+	"""
+	json.dump(
+		{"example-output": True},
+		sys.stdout,
+	)
+
+
+def echo_cmd(args: argparse.Namespace, error: ErrorC) -> None:
+	"""
+	Echo everything from stdin to stdout
+	"""
+	shutil.copyfileobj(sys.stdin, sys.stdout)
+
+
+def rcode_cmd(args: argparse.Namespace, error: ErrorC) -> None:
+	"""
+	Exit with a non-zero return code
+	"""
+	if not 0 < args.code < 128:
+		# error() raises SystemExit itself and never returns
+		error(f"bad value for --code: {args.code}")
+	raise SystemExit(args.code)
+
+
+# Run only when executed as a script or module (e.g. "python -m"), not when imported
+if __name__ == "__main__":
+	main()
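Note: the fixture is executed as a module, mirroring FIXTURE_CMD in the tests below. A
quick sketch of driving it directly with the standard library (this assumes the project
root is the working directory, so the "tests" package is importable):

    import subprocess, sys

    result = subprocess.run(
        [sys.executable, "-m", "tests.unit.proc.fixture_output", "json"],
        capture_output=True,
    )
    # json.dump({"example-output": True}, ...) serialises True as lowercase "true"
    assert result.returncode == 0
    assert result.stdout == b'{"example-output": true}'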
diff --git a/tests/unit/proc/test_proc.py b/tests/unit/proc/test_proc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b486839623376d79deab6753d55c6db529edcf
--- /dev/null
+++ b/tests/unit/proc/test_proc.py
@@ -0,0 +1,225 @@
+# Copyright 2022 Dominik Sekotill
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""
+Unit tests for behave_utils.proc
+"""
+
+import io
+import os
+import subprocess
+import sys
+import warnings
+from sys import executable as python
+from typing import Any
+
+from behave_utils import json
+from behave_utils import proc
+
+from .. import TestCase
+
+FIXTURE_CMD = [python, "-BESsm", "tests.unit.proc.fixture_output"]
+JSON_OUTPUT = [*FIXTURE_CMD, "json"]
+ECHO_OUTPUT = [*FIXTURE_CMD, "echo"]
+
+TEST_BYTES = b"""lorem ipsum dolorum"""
+
+
+class ExecIOTests(TestCase):
+	"""
+	Tests for the behave_utils.proc.exec_io function
+	"""
+
+	def test_deserialiser(self) -> None:
+		"""
+		Check that calling with a deserialiser correctly deserialises output
+		"""
+		with self.subTest(deserialiser=bytes):
+			output: Any = proc.exec_io(JSON_OUTPUT, deserialiser=bytes)
+			self.assertIsInstance(output, bytes)
+
+		with self.subTest(deserialiser=json.JSONObject):
+			output = proc.exec_io(JSON_OUTPUT, deserialiser=json.JSONObject.from_string)
+			self.assertIsInstance(output, json.JSONObject)
+
+	def test_deserialiser_with_stdout(self) -> None:
+		"""
+		Check that calling with both deserialiser and stdout raises TypeError
+		"""
+		with self.assertRaises(TypeError):
+			proc.exec_io(ECHO_OUTPUT, stdout=sys.stdout, deserialiser=bytes)
+
+	def test_input(self) -> None:
+		"""
+		Check that calling with the "input" argument passes bytes to stdin
+		"""
+		output = proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, deserialiser=bytes)
+
+		self.assertEqual(output, TEST_BYTES)
+
+	def test_data(self) -> None:
+		"""
+		Check that calling with the deprecated "data" argument passes bytes to stdin
+		"""
+		msg_re = r".*'data'.* use 'input'"
+		with warnings.catch_warnings(record=True) as messages:
+			warnings.filterwarnings("ignore")
+			warnings.filterwarnings("always", category=DeprecationWarning, message=msg_re)
+
+			output = proc.exec_io(ECHO_OUTPUT, data=TEST_BYTES, deserialiser=bytes)
+
+		self.assertEqual(output, TEST_BYTES)
+		self.assertEqual(len(messages), 1)
+		self.assertTrue(issubclass(messages[0].category, DeprecationWarning))
+
+	def test_input_with_data(self) -> None:
+		"""
+		Check that calling with both "input" and "data" arguments raises TypeError
+		"""
+		with self.assertRaises(TypeError):
+			proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, data=TEST_BYTES)
+
+	def test_input_with_stdin(self) -> None:
+		"""
+		Check that calling with both "input" and "stdin" arguments raises TypeError
+		"""
+		with self.assertRaises(TypeError):
+			proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, stdin=sys.stdin)
+
+	def test_stdout(self) -> None:
+		"""
+		Check that calling with the "stdout" argument receives bytes from stdout
+		"""
+		with self.subTest(stdout="BytesIO"):
+			bbuff = io.BytesIO()
+
+			code = proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, stdout=bbuff)
+
+			bbuff.seek(0)
+			self.assertEqual(code, 0)
+			self.assertEqual(bbuff.read(), TEST_BYTES)
+
+		with self.subTest(stdout="StringIO"):
+			sbuff = io.StringIO()
+
+			code = proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, stdout=sbuff)
+
+			sbuff.seek(0)
+			self.assertEqual(code, 0)
+			self.assertEqual(sbuff.read(), TEST_BYTES.decode())
+
+		with self.subTest(stdout="pipe"):
+			read_fd, write_fd = os.pipe()
+
+			code = proc.exec_io(ECHO_OUTPUT, input=TEST_BYTES, stdout=write_fd)
+
+			os.close(write_fd)
+			with io.open(read_fd, mode="rb") as pipe:
+				self.assertEqual(pipe.read(), TEST_BYTES)
+			self.assertEqual(code, 0)
+
+	def test_return_code(self) -> None:
+		"""
+		Check that non-zero return codes are returned when no deserialiser is provided
+		"""
+		code = proc.exec_io([*FIXTURE_CMD, "rcode", "--code=5"])
+
+		self.assertEqual(code, 5)
+
+
+class ExecutorTests(TestCase):
+	"""
+	Tests for the behave_utils.proc.Executor class
+	"""
+
+	def test_deserialiser(self) -> None:
+		"""
+		Check that calling with a deserialiser correctly deserialises output
+		"""
+		exe = proc.Executor(*FIXTURE_CMD)
+
+		with self.subTest(deserialiser=bytes):
+			output: Any = exe("json", deserialiser=bytes)
+			self.assertIsInstance(output, bytes)
+
+		with self.subTest(deserialiser=json.JSONObject):
+			output = exe("json", deserialiser=json.JSONObject.from_string)
+			self.assertIsInstance(output, json.JSONObject)
+
+	def test_input(self) -> None:
+		"""
+		Check that calling with the "input" argument passes bytes to stdin
+		"""
+		exe = proc.Executor(*FIXTURE_CMD)
+
+		output = exe("echo", input=TEST_BYTES, deserialiser=bytes)
+
+		self.assertEqual(output, TEST_BYTES)
+
+	def test_stdout(self) -> None:
+		"""
+		Check that calling with the "stdout" argument receives bytes from stdout
+		"""
+		exe = proc.Executor(*FIXTURE_CMD)
+
+		with self.subTest(stdout="BytesIO"):
+			bbuff = io.BytesIO()
+
+			exe("echo", input=TEST_BYTES, stdout=bbuff)
+
+			bbuff.seek(0)
+			self.assertEqual(bbuff.read(), TEST_BYTES)
+
+		with self.subTest(stdout="StringIO"):
+			sbuff = io.StringIO()
+
+			exe("echo", input=TEST_BYTES, stdout=sbuff)
+
+			sbuff.seek(0)
+			self.assertEqual(sbuff.read(), TEST_BYTES.decode())
+
+		with self.subTest(stdout="pipe"):
+			read_fd, write_fd = os.pipe()
+
+			exe("echo", input=TEST_BYTES, stdout=write_fd)
+
+			os.close(write_fd)
+			with io.open(read_fd, mode="rb") as pipe:
+				self.assertEqual(pipe.read(), TEST_BYTES)
+
+	def test_return_code(self) -> None:
+		"""
+		Check that the "query" argument controls whether non-zero return codes raise
+		"""
+		exe = proc.Executor(*FIXTURE_CMD)
+
+		with self.subTest(query=True):
+			exe("rcode", "--code=3", query=True)
+
+		with self.subTest(query=False), self.assertRaises(subprocess.CalledProcessError):
+			exe("rcode", "--code=3", query=False)
+
+		with self.subTest(query=None), self.assertRaises(subprocess.CalledProcessError):
+			exe("rcode", "--code=3")
+
+	def test_subcommand(self) -> None:
+		"""
+		Check that the subcommand method returns a new instance with appended arguments
+		"""
+		class NewExecutor(proc.Executor):
+			...
+
+		with self.subTest(cls=proc.Executor):
+			exe = proc.Executor("foo", "bar").subcommand("baz")
+
+			self.assertIsInstance(exe, proc.Executor)
+			self.assertListEqual(exe, ["foo", "bar", "baz"])
+
+		with self.subTest(cls=NewExecutor):
+			exe = NewExecutor("foo", "bar").subcommand("baz")
+
+			self.assertIsInstance(exe, NewExecutor)
+			self.assertListEqual(exe, ["foo", "bar", "baz"])
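Note: the subcommand test above reflects the intended composition pattern for Executor.
A hypothetical sketch (the command names are illustrative only, not part of the suite);
per the assertions above, Executor instances compare equal to plain lists:

    from behave_utils import proc

    docker = proc.Executor("docker")
    compose = docker.subcommand("compose")

    assert compose == ["docker", "compose"]
    # compose("up", "-d") would then execute "docker compose up -d"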
diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d2f5ed9205ba692c2aabfcaa7e6b1d2151dfd6
--- /dev/null
+++ b/tests/unit/test_json.py
@@ -0,0 +1,174 @@
+# Copyright 2022 Dominik Sekotill
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""
+Unit tests for behave_utils.json
+"""
+
+from json import dumps as json_dumps
+from typing import Any
+
+from behave_utils.json import JSONArray
+from behave_utils.json import JSONObject
+
+from . import TestCase
+
+SAMPLE_OBJECT = {
+	"integer_object": {
+		"first": 1,
+		"second": 2,
+	},
+}
+
+SAMPLE_ARRAY = [
+	SAMPLE_OBJECT,
+	SAMPLE_OBJECT,
+]
+
+SAMPLE_OBJECT_JSON = json_dumps(SAMPLE_OBJECT).encode()
+SAMPLE_ARRAY_JSON = json_dumps(SAMPLE_ARRAY).encode()
+
+
+class JSONObjectTests(TestCase):
+	"""
+	Tests for behave_utils.json.JSONObject
+	"""
+
+	def test_from_dict(self) -> None:
+		"""
+		Check that a JSONObject can be created from a dictionary of values
+		"""
+		instance = JSONObject(SAMPLE_OBJECT)
+
+		self.assertIsInstance(instance, JSONObject)
+		self.assertDictEqual(instance, SAMPLE_OBJECT)
+
+	def test_from_JSONObject(self) -> None:
+		"""
+		Check that a JSONObject can be created from another JSONObject
+		"""
+		source = JSONObject(SAMPLE_OBJECT)
+
+		instance = JSONObject(source)
+
+		self.assertIsInstance(instance, JSONObject)
+		self.assertDictEqual(instance, SAMPLE_OBJECT)
+
+	def test_from_string(self) -> None:
+		"""
+		Check that a JSONObject can be created from a string passed to the from_string method
+		"""
+		instance = JSONObject.from_string(SAMPLE_OBJECT_JSON)
+
+		self.assertIsInstance(instance, JSONObject)
+		self.assertDictEqual(instance, SAMPLE_OBJECT)
+
+	def test_from_string_bad_type(self) -> None:
+		"""
+		Check that a JSON string of the wrong type raises TypeError
+		"""
+		with self.assertRaises(TypeError):
+			JSONObject.from_string(b"""[1,2,3]""")
+
+	def test_path(self) -> None:
+		"""
+		Check that retrieving values with the path method works with various types
+		"""
+		sample = JSONObject(SAMPLE_OBJECT)
+
+		with self.subTest("single return"):
+			output: Any = sample.path("$.integer_object.first", int)
+
+			self.assertIsInstance(output, int)
+
+		with self.subTest("multiple returns"):
+			output = sample.path("$.integer_object[*]", list[int])
+
+			self.assertIsInstance(output, list)
+			self.assertSetEqual(set(output), {1, 2})
+
+		with self.subTest("object return"):
+			output = sample.path("$.integer_object", dict[str, int])
+
+			self.assertIsInstance(output, dict)
+			self.assertDictEqual(output, SAMPLE_OBJECT["integer_object"])
+
+	def test_path_convert(self) -> None:
+		"""
+		Check that passing "convert" to the path method works
+		"""
+		sample = JSONObject(SAMPLE_OBJECT)
+
+		with self.subTest("int -> str"):
+			output: Any = sample.path("$.integer_object.first", int, str)
+
+			self.assertIsInstance(output, str)
+
+		with self.subTest("list -> set"):
+			output = sample.path("$.integer_object[*]", list[int], set[int])
+
+			self.assertIsInstance(output, set)
+			self.assertSetEqual(output, {1, 2})
+
+	def test_path_bad_type(self) -> None:
+		"""
+		Check that TypeError is raised when the type passed to path() does not match
+		"""
+		sample = JSONObject(SAMPLE_OBJECT)
+
+		with self.assertRaises(TypeError):
+			sample.path("$.integer_object.first", str)
+
+	def test_path_missing(self) -> None:
+		"""
+		Check that KeyError is raised when no result is found
+		"""
+		sample = JSONObject(SAMPLE_OBJECT)
+
+		with self.assertRaises(KeyError):
+			sample.path("$.integer_object.last", int)
+
+
+class JSONArrayTests(TestCase):
+	"""
+	Tests for behave_utils.json.JSONArray
+	"""
+
+	def test_from_list(self) -> None:
+		"""
+		Check that a JSONArray can be created from a list of values
+		"""
+		instance = JSONArray(SAMPLE_ARRAY)
+
+		self.assertIsInstance(instance, JSONArray)
+		self.assertListEqual(instance, SAMPLE_ARRAY)
+
+	def test_from_JSONArray(self) -> None:
+		"""
+		Check that a JSONArray can be created from another JSONArray
+		"""
+		source = JSONArray(SAMPLE_ARRAY)
+
+		instance = JSONArray(source)
+
+		self.assertIsInstance(instance, JSONArray)
+		self.assertListEqual(instance, SAMPLE_ARRAY)
+
+	def test_from_string(self) -> None:
+		"""
+		Check that a JSONArray can be created from a string passed to the from_string method
+		"""
+		instance = JSONArray.from_string(SAMPLE_ARRAY_JSON)
+
+		self.assertIsInstance(instance, JSONArray)
+		self.assertListEqual(instance, SAMPLE_ARRAY)
+
+	def test_from_string_bad_type(self) -> None:
+		"""
+		Check that a JSON string of the wrong type raises TypeError
+		"""
+		with self.assertRaises(TypeError):
+			JSONArray.from_string(b"""{"first": 1}""")
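Note: a short sketch of the path()/convert behaviour exercised above, using the same
sample data as the tests (the JSONPath expressions are those the tests themselves use):

    from behave_utils.json import JSONObject

    obj = JSONObject({"integer_object": {"first": 1, "second": 2}})

    first = obj.path("$.integer_object.first", int)                 # -> 1
    as_set = obj.path("$.integer_object[*]", list[int], set[int])   # -> {1, 2}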
diff --git a/tests/unit/test_secret.py b/tests/unit/test_secret.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb59c133b778842c40fa5c08f674c04cc21a0609
--- /dev/null
+++ b/tests/unit/test_secret.py
@@ -0,0 +1,30 @@
+# Copyright 2022 Dominik Sekotill
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""
+Unit tests for behave_utils.secret
+"""
+
+from behave_utils import secret
+
+from . import TestCase
+
+
+class MakeSecretTests(TestCase):
+	"""
+	Tests for behave_utils.secret.make_secret
+	"""
+
+	def test_length(self) -> None:
+		"""
+		Check that make_secret returns a string of the requested size
+		"""
+		for num in (20, 40):
+			with self.subTest(f"{num} chars"):
+				output = secret.make_secret(num)
+
+				self.assertIsInstance(output, str)
+				self.assertEqual(len(output), num)
diff --git a/tests/unit/test_url.py b/tests/unit/test_url.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b400a8d03a7c1a5c0a76de32799182e9b46a18
--- /dev/null
+++ b/tests/unit/test_url.py
@@ -0,0 +1,47 @@
+# Copyright 2022 Dominik Sekotill
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""
+Unit tests for behave_utils.url
+"""
+
+from behave_utils.url import URL
+
+from . import TestCase
+
+
+class URLTests(TestCase):
+	"""
+	Tests for behave_utils.url.URL
+	"""
+
+	def test_join(self) -> None:
+		"""
+		Check that the division operator does a URL join
+		"""
+		with self.subTest("check type"):
+			self.assertIsInstance(URL("https://example.com/foo") / "bar", URL)
+
+		with self.subTest("no slash"):
+			self.assertEqual(URL("https://example.com/foo") / "bar", "https://example.com/bar")
+
+		with self.subTest("slash"):
+			self.assertEqual(URL("https://example.com/foo/") / "bar", "https://example.com/foo/bar")
+
+		with self.subTest("both slash"):
+			self.assertEqual(URL("https://example.com/foo/") / "bar/", "https://example.com/foo/bar/")
+
+		with self.subTest("absolute"):
+			self.assertEqual(URL("https://example.com/foo/bar/") / "/bar", "https://example.com/bar")
+
+	def test_append(self) -> None:
+		"""
+		Check that the addition operator appends text to the URL
+		"""
+		url = URL("https://example.com/foo") + "bar"
+
+		self.assertEqual(url, "https://example.com/foobar")
+		self.assertIsInstance(url, URL)
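Note: the join cases above follow relative-reference resolution as in
urllib.parse.urljoin (a trailing slash keeps the last path segment; its absence replaces
it). A hypothetical composition sketch, assuming joins chain as the single-step tests
suggest:

    from behave_utils.url import URL

    api = URL("https://example.com/api/")
    status = api / "v1/" / "status"

    assert status == "https://example.com/api/v1/status"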