diff --git a/spicepy/_client.py b/spicepy/_client.py index c7ce822..6c3597e 100644 --- a/spicepy/_client.py +++ b/spicepy/_client.py @@ -236,8 +236,32 @@ def _user_agent(custom_user_agent=None): ) return (str.encode("user-agent"), str.encode(config.SPICE_USER_AGENT)) - def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None): - self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs) + def __init__( + self, + grpc: str, + api_key: str, + tls_root_certs, + user_agent=None, + tls_client_certificate: str | Path | None = None, + tls_client_key: str | Path | None = None, + ): + connect_kwargs = {"tls_root_certs": tls_root_certs} + if tls_client_certificate is not None and tls_client_key is not None: + cert_path = ( + tls_client_certificate + if isinstance(tls_client_certificate, Path) + else Path(tls_client_certificate) + ) + key_path = ( + tls_client_key + if isinstance(tls_client_key, Path) + else Path(tls_client_key) + ) + with open(cert_path, "rb") as f: + connect_kwargs["cert_chain"] = f.read() + with open(key_path, "rb") as f: + connect_kwargs["private_key"] = f.read() + self._flight_client = flight.connect(grpc, **connect_kwargs) self._api_key = api_key self.headers = [_SpiceFlight._user_agent(user_agent)] self._flight_options = flight.FlightCallOptions( @@ -312,17 +336,39 @@ def __init__( http_url: str = config.DEFAULT_HTTP_URL, tls_root_cert: str | Path | None = None, user_agent: str | None = None, + tls_client_certificate: str | Path | None = None, + tls_client_key: str | Path | None = None, ): # pylint: disable=R0913 + # Validate that client cert and key are either both set or both unset + has_cert = tls_client_certificate is not None + has_key = tls_client_key is not None + if has_cert != has_key: + missing = "tls_client_key" if has_cert else "tls_client_certificate" + raise ValueError( + f"Both tls_client_certificate and tls_client_key must be " + f"provided together for mTLS. {missing} is missing." + ) + tls_root_certs = _Cert(tls_root_cert).tls_root_certs self._flight = _SpiceFlight( - flight_url, api_key or "", tls_root_certs, user_agent + flight_url, + api_key or "", + tls_root_certs, + user_agent, + tls_client_certificate=tls_client_certificate, + tls_client_key=tls_client_key, ) self.api_key = api_key self._flight_url = flight_url self._user_agent = user_agent self._adbc_client: _ADBCClient | None = None - self.http = HttpRequests(http_url, self._headers(user_agent)) + self.http = HttpRequests( + http_url, + self._headers(user_agent), + tls_client_certificate=tls_client_certificate, + tls_client_key=tls_client_key, + ) def _headers(self, user_agent=None) -> dict[str, str]: headers = { diff --git a/spicepy/_http.py b/spicepy/_http.py index 09fe414..bc3f5b1 100644 --- a/spicepy/_http.py +++ b/spicepy/_http.py @@ -1,6 +1,7 @@ from collections.abc import Callable from dataclasses import dataclass import datetime +from pathlib import Path from typing import Any, Literal from requests import Response, Session @@ -28,13 +29,23 @@ def to_dict(self) -> dict[str, Any]: class HttpRequests: - def __init__(self, base_url: str, headers: dict[str, str]) -> None: + def __init__( + self, + base_url: str, + headers: dict[str, str], + tls_client_certificate: str | Path | None = None, + tls_client_key: str | Path | None = None, + ) -> None: self.session = self._create_session(headers) # set the user-agent header if "user-agent" not in self.session.headers: self.session.headers["user-agent"] = SPICE_USER_AGENT + # Configure client certificate for mTLS on HTTP requests + if tls_client_certificate is not None and tls_client_key is not None: + self.session.cert = (str(tls_client_certificate), str(tls_client_key)) + self.base_url = base_url # pylint: disable=R0913