@@ -236,8 +236,22 @@ def _user_agent(custom_user_agent=None):
236236 )
237237 return (str .encode ("user-agent" ), str .encode (config .SPICE_USER_AGENT ))
238238
239- def __init__ (self , grpc : str , api_key : str , tls_root_certs , user_agent = None ):
240- self ._flight_client = flight .connect (grpc , tls_root_certs = tls_root_certs )
239+ def __init__ (self , grpc : str , api_key : str , tls_root_certs , user_agent = None ,
240+ tls_client_certificate : str | Path | None = None ,
241+ tls_client_key : str | Path | None = None ):
242+ connect_kwargs = {"tls_root_certs" : tls_root_certs }
243+ if tls_client_certificate is not None and tls_client_key is not None :
244+ cert_path = (
245+ tls_client_certificate
246+ if isinstance (tls_client_certificate , Path )
247+ else Path (tls_client_certificate )
248+ )
249+ key_path = tls_client_key if isinstance (tls_client_key , Path ) else Path (tls_client_key )
250+ with open (cert_path , "rb" ) as f :
251+ connect_kwargs ["cert_chain" ] = f .read ()
252+ with open (key_path , "rb" ) as f :
253+ connect_kwargs ["private_key" ] = f .read ()
254+ self ._flight_client = flight .connect (grpc , ** connect_kwargs )
241255 self ._api_key = api_key
242256 self .headers = [_SpiceFlight ._user_agent (user_agent )]
243257 self ._flight_options = flight .FlightCallOptions (
@@ -312,17 +326,38 @@ def __init__(
312326 http_url : str = config .DEFAULT_HTTP_URL ,
313327 tls_root_cert : str | Path | None = None ,
314328 user_agent : str | None = None ,
329+ tls_client_certificate : str | Path | None = None ,
330+ tls_client_key : str | Path | None = None ,
315331 ): # pylint: disable=R0913
332+ # Validate that client cert and key are either both set or both unset
333+ has_cert = tls_client_certificate is not None
334+ has_key = tls_client_key is not None
335+ if has_cert != has_key :
336+ missing = (
337+ "tls_client_key" if has_cert else "tls_client_certificate"
338+ )
339+ raise ValueError (
340+ f"Both tls_client_certificate and tls_client_key must be "
341+ f"provided together for mTLS. { missing } is missing."
342+ )
343+
316344 tls_root_certs = _Cert (tls_root_cert ).tls_root_certs
317345 self ._flight = _SpiceFlight (
318- flight_url , api_key or "" , tls_root_certs , user_agent
346+ flight_url , api_key or "" , tls_root_certs , user_agent ,
347+ tls_client_certificate = tls_client_certificate ,
348+ tls_client_key = tls_client_key ,
319349 )
320350
321351 self .api_key = api_key
322352 self ._flight_url = flight_url
323353 self ._user_agent = user_agent
324354 self ._adbc_client : _ADBCClient | None = None
325- self .http = HttpRequests (http_url , self ._headers (user_agent ))
355+ self .http = HttpRequests (
356+ http_url ,
357+ self ._headers (user_agent ),
358+ tls_client_certificate = tls_client_certificate ,
359+ tls_client_key = tls_client_key ,
360+ )
326361
327362 def _headers (self , user_agent = None ) -> dict [str , str ]:
328363 headers = {
0 commit comments