Skip to content

Commit 5c8b75e

Browse files
feat: add mTLS client certificate support (#156)
1 parent 889f6a9 commit 5c8b75e

2 files changed

Lines changed: 62 additions & 5 deletions

File tree

spicepy/_client.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,32 @@ def _user_agent(custom_user_agent=None):
236236
)
237237
return (str.encode("user-agent"), str.encode(config.SPICE_USER_AGENT))
238238

239-
def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None):
240-
self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs)
239+
def __init__(
240+
self,
241+
grpc: str,
242+
api_key: str,
243+
tls_root_certs,
244+
user_agent=None,
245+
tls_client_certificate: str | Path | None = None,
246+
tls_client_key: str | Path | None = None,
247+
):
248+
connect_kwargs = {"tls_root_certs": tls_root_certs}
249+
if tls_client_certificate is not None and tls_client_key is not None:
250+
cert_path = (
251+
tls_client_certificate
252+
if isinstance(tls_client_certificate, Path)
253+
else Path(tls_client_certificate)
254+
)
255+
key_path = (
256+
tls_client_key
257+
if isinstance(tls_client_key, Path)
258+
else Path(tls_client_key)
259+
)
260+
with open(cert_path, "rb") as f:
261+
connect_kwargs["cert_chain"] = f.read()
262+
with open(key_path, "rb") as f:
263+
connect_kwargs["private_key"] = f.read()
264+
self._flight_client = flight.connect(grpc, **connect_kwargs)
241265
self._api_key = api_key
242266
self.headers = [_SpiceFlight._user_agent(user_agent)]
243267
self._flight_options = flight.FlightCallOptions(
@@ -312,17 +336,39 @@ def __init__(
312336
http_url: str = config.DEFAULT_HTTP_URL,
313337
tls_root_cert: str | Path | None = None,
314338
user_agent: str | None = None,
339+
tls_client_certificate: str | Path | None = None,
340+
tls_client_key: str | Path | None = None,
315341
): # pylint: disable=R0913
342+
# Validate that client cert and key are either both set or both unset
343+
has_cert = tls_client_certificate is not None
344+
has_key = tls_client_key is not None
345+
if has_cert != has_key:
346+
missing = "tls_client_key" if has_cert else "tls_client_certificate"
347+
raise ValueError(
348+
f"Both tls_client_certificate and tls_client_key must be "
349+
f"provided together for mTLS. {missing} is missing."
350+
)
351+
316352
tls_root_certs = _Cert(tls_root_cert).tls_root_certs
317353
self._flight = _SpiceFlight(
318-
flight_url, api_key or "", tls_root_certs, user_agent
354+
flight_url,
355+
api_key or "",
356+
tls_root_certs,
357+
user_agent,
358+
tls_client_certificate=tls_client_certificate,
359+
tls_client_key=tls_client_key,
319360
)
320361

321362
self.api_key = api_key
322363
self._flight_url = flight_url
323364
self._user_agent = user_agent
324365
self._adbc_client: _ADBCClient | None = None
325-
self.http = HttpRequests(http_url, self._headers(user_agent))
366+
self.http = HttpRequests(
367+
http_url,
368+
self._headers(user_agent),
369+
tls_client_certificate=tls_client_certificate,
370+
tls_client_key=tls_client_key,
371+
)
326372

327373
def _headers(self, user_agent=None) -> dict[str, str]:
328374
headers = {

spicepy/_http.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass
33
import datetime
4+
from pathlib import Path
45
from typing import Any, Literal
56

67
from requests import Response, Session
@@ -28,13 +29,23 @@ def to_dict(self) -> dict[str, Any]:
2829

2930

3031
class HttpRequests:
31-
def __init__(self, base_url: str, headers: dict[str, str]) -> None:
32+
def __init__(
33+
self,
34+
base_url: str,
35+
headers: dict[str, str],
36+
tls_client_certificate: str | Path | None = None,
37+
tls_client_key: str | Path | None = None,
38+
) -> None:
3239
self.session = self._create_session(headers)
3340

3441
# set the user-agent header
3542
if "user-agent" not in self.session.headers:
3643
self.session.headers["user-agent"] = SPICE_USER_AGENT
3744

45+
# Configure client certificate for mTLS on HTTP requests
46+
if tls_client_certificate is not None and tls_client_key is not None:
47+
self.session.cert = (str(tls_client_certificate), str(tls_client_key))
48+
3849
self.base_url = base_url
3950

4051
# pylint: disable=R0913

0 commit comments

Comments
 (0)