Skip to content

Commit 55253d5

Browse files
feat: add mTLS client certificate support
1 parent 889f6a9 commit 55253d5

2 files changed

Lines changed: 51 additions & 5 deletions

File tree

spicepy/_client.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,22 @@ def _user_agent(custom_user_agent=None):
236236
)
237237
return (str.encode("user-agent"), str.encode(config.SPICE_USER_AGENT))
238238

239-
def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None):
240-
self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs)
239+
def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None,
240+
tls_client_certificate: str | Path | None = None,
241+
tls_client_key: str | Path | None = None):
242+
connect_kwargs = {"tls_root_certs": tls_root_certs}
243+
if tls_client_certificate is not None and tls_client_key is not None:
244+
cert_path = (
245+
tls_client_certificate
246+
if isinstance(tls_client_certificate, Path)
247+
else Path(tls_client_certificate)
248+
)
249+
key_path = tls_client_key if isinstance(tls_client_key, Path) else Path(tls_client_key)
250+
with open(cert_path, "rb") as f:
251+
connect_kwargs["cert_chain"] = f.read()
252+
with open(key_path, "rb") as f:
253+
connect_kwargs["private_key"] = f.read()
254+
self._flight_client = flight.connect(grpc, **connect_kwargs)
241255
self._api_key = api_key
242256
self.headers = [_SpiceFlight._user_agent(user_agent)]
243257
self._flight_options = flight.FlightCallOptions(
@@ -312,17 +326,38 @@ def __init__(
312326
http_url: str = config.DEFAULT_HTTP_URL,
313327
tls_root_cert: str | Path | None = None,
314328
user_agent: str | None = None,
329+
tls_client_certificate: str | Path | None = None,
330+
tls_client_key: str | Path | None = None,
315331
): # pylint: disable=R0913
332+
# Validate that client cert and key are either both set or both unset
333+
has_cert = tls_client_certificate is not None
334+
has_key = tls_client_key is not None
335+
if has_cert != has_key:
336+
missing = (
337+
"tls_client_key" if has_cert else "tls_client_certificate"
338+
)
339+
raise ValueError(
340+
f"Both tls_client_certificate and tls_client_key must be "
341+
f"provided together for mTLS. {missing} is missing."
342+
)
343+
316344
tls_root_certs = _Cert(tls_root_cert).tls_root_certs
317345
self._flight = _SpiceFlight(
318-
flight_url, api_key or "", tls_root_certs, user_agent
346+
flight_url, api_key or "", tls_root_certs, user_agent,
347+
tls_client_certificate=tls_client_certificate,
348+
tls_client_key=tls_client_key,
319349
)
320350

321351
self.api_key = api_key
322352
self._flight_url = flight_url
323353
self._user_agent = user_agent
324354
self._adbc_client: _ADBCClient | None = None
325-
self.http = HttpRequests(http_url, self._headers(user_agent))
355+
self.http = HttpRequests(
356+
http_url,
357+
self._headers(user_agent),
358+
tls_client_certificate=tls_client_certificate,
359+
tls_client_key=tls_client_key,
360+
)
326361

327362
def _headers(self, user_agent=None) -> dict[str, str]:
328363
headers = {

spicepy/_http.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass
33
import datetime
4+
from pathlib import Path
45
from typing import Any, Literal
56

67
from requests import Response, Session
@@ -28,13 +29,23 @@ def to_dict(self) -> dict[str, Any]:
2829

2930

3031
class HttpRequests:
31-
def __init__(self, base_url: str, headers: dict[str, str]) -> None:
32+
def __init__(
33+
self,
34+
base_url: str,
35+
headers: dict[str, str],
36+
tls_client_certificate: str | Path | None = None,
37+
tls_client_key: str | Path | None = None,
38+
) -> None:
3239
self.session = self._create_session(headers)
3340

3441
# set the user-agent header
3542
if "user-agent" not in self.session.headers:
3643
self.session.headers["user-agent"] = SPICE_USER_AGENT
3744

45+
# Configure client certificate for mTLS on HTTP requests
46+
if tls_client_certificate is not None and tls_client_key is not None:
47+
self.session.cert = (str(tls_client_certificate), str(tls_client_key))
48+
3849
self.base_url = base_url
3950

4051
# pylint: disable=R0913

0 commit comments

Comments
 (0)