Commit e78ae4dc authored by Dom Sekotill's avatar Dom Sekotill
Browse files

Support passing certificates to redirect

parent 65e617d1
Loading
Loading
Loading
Loading
+27 −4
Original line number Diff line number Diff line
@@ -20,7 +20,12 @@ from requests.packages.urllib3 import connection
from requests.packages.urllib3 import connectionpool


def redirect(session: requests.Session, prefix: str, address: ipaddress.IPv4Address) -> None:
def redirect(
	session: requests.Session,
	prefix: str,
	address: ipaddress.IPv4Address,
	certificate: str|None = None,
) -> None:
	"""
	Redirect all requests for "prefix" to a given address

@@ -32,8 +37,8 @@ def redirect(session: requests.Session, prefix: str, address: ipaddress.IPv4Addr
	where "schema" defaults to (and currently only supports) "http".
	"""
	if not prefix.startswith("http://") or prefix.startswith("https://"):
		prefix = f"http://{prefix}"
	session.mount(prefix, _DirectedAdapter(address))
		prefix = f"http://{prefix}" if certificate is None else f"https://{prefix}"
	session.mount(prefix, _DirectedAdapter(address, certificate))


class _DirectedAdapter(requests.adapters.HTTPAdapter):
@@ -47,9 +52,10 @@ class _DirectedAdapter(requests.adapters.HTTPAdapter):
	function.
	"""

	def __init__(self, destination: ipaddress.IPv4Address):
	def __init__(self, destination: ipaddress.IPv4Address, certificate: str|None):
		super().__init__()
		self.destination = destination
		self.certificate = certificate

	def get_connection(self, url: str, proxies: Mapping[str, str]|None = None) -> connectionpool.HTTPConnectionPool:
		parts = urlparse(url)
@@ -58,6 +64,23 @@ class _DirectedAdapter(requests.adapters.HTTPAdapter):
		else:
			return _HTTPConnectionPool(parts.hostname, parts.port, address=self.destination)

	def cert_verify(
		self,
		conn: connection.HTTPConnection,
		url: str,
		verify: bool|str,
		cert: str|tuple[str, str],
	) -> None:
		if verify is False:
			raise ValueError("Never disable TLS verification")
		if verify is not True:
			raise ValueError(
				"To supply verification certificates please use "
				"redirect(session, '{url.scheme}://{url.netloc}', '{self.destination}', Path('{verify}'))",
			)
		super().cert_verify(conn, url, True, cert)  # type: ignore
		conn.ca_cert_data = self.certificate  # type: ignore


class _HTTPConnectionPool(connectionpool.HTTPConnectionPool):