@@ -236,8 +236,32 @@ def _user_agent(custom_user_agent=None):
236236 )
237237 return (str .encode ("user-agent" ), str .encode (config .SPICE_USER_AGENT ))
238238
239- def __init__ (self , grpc : str , api_key : str , tls_root_certs , user_agent = None ):
240- self ._flight_client = flight .connect (grpc , tls_root_certs = tls_root_certs )
239+ def __init__ (
240+ self ,
241+ grpc : str ,
242+ api_key : str ,
243+ tls_root_certs ,
244+ user_agent = None ,
245+ tls_client_certificate : str | Path | None = None ,
246+ tls_client_key : str | Path | None = None ,
247+ ):
248+ connect_kwargs = {"tls_root_certs" : tls_root_certs }
249+ if tls_client_certificate is not None and tls_client_key is not None :
250+ cert_path = (
251+ tls_client_certificate
252+ if isinstance (tls_client_certificate , Path )
253+ else Path (tls_client_certificate )
254+ )
255+ key_path = (
256+ tls_client_key
257+ if isinstance (tls_client_key , Path )
258+ else Path (tls_client_key )
259+ )
260+ with open (cert_path , "rb" ) as f :
261+ connect_kwargs ["cert_chain" ] = f .read ()
262+ with open (key_path , "rb" ) as f :
263+ connect_kwargs ["private_key" ] = f .read ()
264+ self ._flight_client = flight .connect (grpc , ** connect_kwargs )
241265 self ._api_key = api_key
242266 self .headers = [_SpiceFlight ._user_agent (user_agent )]
243267 self ._flight_options = flight .FlightCallOptions (
@@ -312,17 +336,39 @@ def __init__(
312336 http_url : str = config .DEFAULT_HTTP_URL ,
313337 tls_root_cert : str | Path | None = None ,
314338 user_agent : str | None = None ,
339+ tls_client_certificate : str | Path | None = None ,
340+ tls_client_key : str | Path | None = None ,
315341 ): # pylint: disable=R0913
342+ # Validate that client cert and key are either both set or both unset
343+ has_cert = tls_client_certificate is not None
344+ has_key = tls_client_key is not None
345+ if has_cert != has_key :
346+ missing = "tls_client_key" if has_cert else "tls_client_certificate"
347+ raise ValueError (
348+ f"Both tls_client_certificate and tls_client_key must be "
349+ f"provided together for mTLS. { missing } is missing."
350+ )
351+
316352 tls_root_certs = _Cert (tls_root_cert ).tls_root_certs
317353 self ._flight = _SpiceFlight (
318- flight_url , api_key or "" , tls_root_certs , user_agent
354+ flight_url ,
355+ api_key or "" ,
356+ tls_root_certs ,
357+ user_agent ,
358+ tls_client_certificate = tls_client_certificate ,
359+ tls_client_key = tls_client_key ,
319360 )
320361
321362 self .api_key = api_key
322363 self ._flight_url = flight_url
323364 self ._user_agent = user_agent
324365 self ._adbc_client : _ADBCClient | None = None
325- self .http = HttpRequests (http_url , self ._headers (user_agent ))
366+ self .http = HttpRequests (
367+ http_url ,
368+ self ._headers (user_agent ),
369+ tls_client_certificate = tls_client_certificate ,
370+ tls_client_key = tls_client_key ,
371+ )
326372
327373 def _headers (self , user_agent = None ) -> dict [str , str ]:
328374 headers = {
0 commit comments