Verified Commit e1ed6c40 authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Handle shallow clones in copyright hook

Fixes #4
parent b0f090d2
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -169,6 +169,12 @@ ignore = [

	# Unfortunately a lot of single quotes strings used in this project already
	"Q000",

	# DISABLE "Call date.today()"
	# I don't know why someone thought that a timezone aware date object was
	# ever a thing that would be needed, considering it does not represent
	# a specific point in time.
	"DTZ011",
]

[lint.per-file-ignores]
+56 −23
Original line number Diff line number Diff line
#!/usr/bin/env python3
#  Copyright 2021, 2022  Dominik Sekotill <dom.sekotill@kodo.org.uk>
#  Copyright 2021, 2022, 2025  Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
@@ -21,19 +21,18 @@ import argparse
import os
import re
import sys
import time
from collections.abc import Iterable
from collections.abc import Iterator
from datetime import date
from functools import lru_cache
from itertools import product
from pathlib import Path
from subprocess import DEVNULL
from subprocess import PIPE
from subprocess import run
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List


def check_file(path: Path, year: str, min_size: int) -> bool:
def check_file(path: Path, year: int, min_size: int) -> bool:
	"""
	Check for an up-to-date copyright notice in the first few lines of the given file
	"""
@@ -44,7 +43,7 @@ def check_file(path: Path, year: str, min_size: int) -> bool:
	return bool(re.search(f'\\b(?:copyright)\\b.*\\b{year}\\b', lines, re.I))


def filter_excluded(paths: Iterable[Path]) -> List[Path]:
def filter_excluded(paths: Iterable[Path]) -> list[Path]:
	"""
	Return a list of files from the input that are not excluded by git
	"""
@@ -54,31 +53,33 @@ def filter_excluded(paths: Iterable[Path]) -> List[Path]:
	return list(split_paths(proc.stdout))


def get_file_years(paths: List[Path]) -> Dict[Path, str]:
def get_file_years(paths: list[Path], ignored_commits: set[str]) -> dict[Path, int]:
	"""
	Return a mapping of paths to the year they where last changed (if they are tracked)
	"""
	output = dict()
	output = dict[Path, int]()

	cmd = [
		'git', 'log', '--topo-order',
		'--format=format:%ad', '--date=format:%Y',
		'--format=format:%ad %H', '--date=format:%Y',
		'--name-only', '-z', '--',
	]
	cmd.extend(p.as_posix() for p in paths)
	proc = run(cmd, stdout=PIPE, check=True)

	regex = re.compile(br'(?P<year>[0-9]{4,})(?:\n|\r|\r\n)(?P<files>.*?)(\x00\x00|$)')
	for match in regex.finditer(proc.stdout):
		year = match.group('year').decode()
		for path in split_paths(match.group('files')):
			if path not in output:
				output[path] = year
	for result in proc.stdout.split(b'\0\0'):
		commits, files = result.splitlines()
		combos = product(split_commits(commits), split_paths(files))
		for (commit_year, commit_hash), path in combos:
			if commit_hash in ignored_commits:
				continue
			if output.setdefault(path, commit_year) < commit_year:
				output[path] = commit_year

	return output


def get_changed(paths: List[Path]) -> Iterator[Path]:
def get_changed(paths: list[Path]) -> Iterator[Path]:
	"""
	Return an iterator of changed paths
	"""
@@ -88,11 +89,31 @@ def get_changed(paths: List[Path]) -> Iterator[Path]:
	return split_paths(proc.stdout)


def get_grafted() -> set[str]:
	"""
	Return a set of grafted commits (from shallow clones)
	"""
	shallow_file = git_dir() / "shallow"
	if not shallow_file.exists():
		return set()
	with shallow_file.open() as file:
		return set(file.read().splitlines())


def split_paths(paths: bytes) -> Iterator[Path]:
	"""
	Return an iterator of Paths from a null-separated byte string list
	"""
	return (Path(p.decode()) for p in paths.split(b'\x00') if p != b'')
	return (Path(p.decode()) for p in paths.split(b"\0") if p != b'')


def split_commits(commits: bytes) -> Iterator[tuple[int, str]]:
	"""
	Return a iterator of commit (year, hash) tuples from a null-separated byte string list
	"""
	for commit in commits.split(b"\0"):
		year, ref = commit.split(b" ", maxsplit=1)
		yield int(year), ref.decode()


@lru_cache(1)
@@ -105,6 +126,16 @@ def has_commits() -> bool:
	return proc.returncode == 0


@lru_cache(1)
def git_dir() -> Path:
	"""
	Return the git-dir for the current worktree
	"""
	cmd = ["git", "rev-parse", "--git-dir"]
	proc = run(cmd, stdout=PIPE, check=True)
	return Path(proc.stdout.strip().decode("utf-8"))


def cli_parser() -> argparse.ArgumentParser:
	"""
	Return an argparse parser
@@ -130,15 +161,17 @@ def main() -> None:
	opts = cli_parser().parse_args()
	paths = filter_excluded(opts.files)

	year = time.strftime('%Y')
	years = {path: year for path in paths if path.is_file()}
	year = date.today().year
	if has_commits():
		years.update(get_file_years(paths))
		grafts = get_grafted()
		years = get_file_years(paths, grafts)
		years.update((path, year) for path in get_changed(paths))
	else:
		years = {path: year for path in paths if path.is_file()}

	missing = []
	for path in paths:
		if path.is_file() and not check_file(path, years[path], opts.min_size):
		if path.is_file() and path in years and not check_file(path, years[path], opts.min_size):
			missing.append(path)

	if missing: