Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions spicepy/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,32 @@ def _user_agent(custom_user_agent=None):
)
return (str.encode("user-agent"), str.encode(config.SPICE_USER_AGENT))

def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None):
self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs)
def __init__(
self,
grpc: str,
api_key: str,
tls_root_certs,
user_agent=None,
tls_client_certificate: str | Path | None = None,
tls_client_key: str | Path | None = None,
):
connect_kwargs = {"tls_root_certs": tls_root_certs}
if tls_client_certificate is not None and tls_client_key is not None:
cert_path = (
tls_client_certificate
if isinstance(tls_client_certificate, Path)
else Path(tls_client_certificate)
)
key_path = (
tls_client_key
if isinstance(tls_client_key, Path)
else Path(tls_client_key)
)
with open(cert_path, "rb") as f:
connect_kwargs["cert_chain"] = f.read()
with open(key_path, "rb") as f:
connect_kwargs["private_key"] = f.read()
self._flight_client = flight.connect(grpc, **connect_kwargs)
self._api_key = api_key
self.headers = [_SpiceFlight._user_agent(user_agent)]
self._flight_options = flight.FlightCallOptions(
Expand Down Expand Up @@ -312,17 +336,39 @@ def __init__(
http_url: str = config.DEFAULT_HTTP_URL,
tls_root_cert: str | Path | None = None,
user_agent: str | None = None,
tls_client_certificate: str | Path | None = None,
tls_client_key: str | Path | None = None,
): # pylint: disable=R0913
# Validate that client cert and key are either both set or both unset
has_cert = tls_client_certificate is not None
has_key = tls_client_key is not None
if has_cert != has_key:
missing = "tls_client_key" if has_cert else "tls_client_certificate"
raise ValueError(
f"Both tls_client_certificate and tls_client_key must be "
f"provided together for mTLS. {missing} is missing."
)

tls_root_certs = _Cert(tls_root_cert).tls_root_certs
self._flight = _SpiceFlight(
flight_url, api_key or "", tls_root_certs, user_agent
flight_url,
api_key or "",
tls_root_certs,
user_agent,
tls_client_certificate=tls_client_certificate,
tls_client_key=tls_client_key,
)

self.api_key = api_key
self._flight_url = flight_url
self._user_agent = user_agent
self._adbc_client: _ADBCClient | None = None
self.http = HttpRequests(http_url, self._headers(user_agent))
self.http = HttpRequests(
http_url,
self._headers(user_agent),
tls_client_certificate=tls_client_certificate,
tls_client_key=tls_client_key,
)

def _headers(self, user_agent=None) -> dict[str, str]:
headers = {
Expand Down
13 changes: 12 additions & 1 deletion spicepy/_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from dataclasses import dataclass
import datetime
from pathlib import Path
from typing import Any, Literal

from requests import Response, Session
Expand Down Expand Up @@ -28,13 +29,23 @@ def to_dict(self) -> dict[str, Any]:


class HttpRequests:
def __init__(self, base_url: str, headers: dict[str, str]) -> None:
def __init__(
self,
base_url: str,
headers: dict[str, str],
tls_client_certificate: str | Path | None = None,
tls_client_key: str | Path | None = None,
) -> None:
self.session = self._create_session(headers)

# set the user-agent header
if "user-agent" not in self.session.headers:
self.session.headers["user-agent"] = SPICE_USER_AGENT

# Configure client certificate for mTLS on HTTP requests
if tls_client_certificate is not None and tls_client_key is not None:
self.session.cert = (str(tls_client_certificate), str(tls_client_key))

self.base_url = base_url

# pylint: disable=R0913
Expand Down
Loading