11import json
22import os
3+ from pathlib import Path
34import platform
45import threading
5- from pathlib import Path
66from typing import Any
77
88import certifi
2121
2222
2323def is_macos_arm64 () -> bool :
24- return platform .platform ().lower ().startswith ("macos" ) and platform .machine () == "arm64"
24+ return (
25+ platform .platform ().lower ().startswith ("macos" )
26+ and platform .machine () == "arm64"
27+ )
2528
2629
2730try :
@@ -85,7 +88,9 @@ def _init_connection(self):
8588 # Add authentication if API key provided
8689 if self ._api_key :
8790 db_kwargs [adbc_driver_manager .DatabaseOptions .USERNAME .value ] = ""
88- db_kwargs [adbc_driver_manager .DatabaseOptions .PASSWORD .value ] = self ._api_key
91+ db_kwargs [adbc_driver_manager .DatabaseOptions .PASSWORD .value ] = (
92+ self ._api_key
93+ )
8994
9095 # Create low-level database and connection (avoids dbapi autocommit warning)
9196 self ._db = adbc_driver_flightsql .connect (self ._uri , db_kwargs = db_kwargs )
@@ -110,7 +115,11 @@ def _create_param_batch(
110115
111116 for param in params :
112117 # Check if param is a tuple of (value, arrow_type)
113- if isinstance (param , tuple ) and len (param ) == 2 and isinstance (param [1 ], pa .DataType ):
118+ if (
119+ isinstance (param , tuple )
120+ and len (param ) == 2
121+ and isinstance (param [1 ], pa .DataType )
122+ ):
114123 value , arrow_type = param
115124 param_values .append (value )
116125 param_types .append (arrow_type )
@@ -124,7 +133,9 @@ def _create_param_batch(
124133 param_arrays .append (pa .array ([value ], type = arrow_type ))
125134
126135 # Create parameter schema with positional field names ($1, $2, etc.)
127- param_fields = [pa .field (f"${ i + 1 } " , param_types [i ]) for i in range (len (params ))]
136+ param_fields = [
137+ pa .field (f"${ i + 1 } " , param_types [i ]) for i in range (len (params ))
138+ ]
128139 param_schema = pa .schema (param_fields )
129140
130141 return pa .record_batch (param_arrays , schema = param_schema )
@@ -189,7 +200,11 @@ def __init__(
189200 tls_root_cert ,
190201 ):
191202 if tls_root_cert is not None :
192- tls_root_cert = tls_root_cert if isinstance (tls_root_cert , Path ) else Path (tls_root_cert )
203+ tls_root_cert = (
204+ tls_root_cert
205+ if isinstance (tls_root_cert , Path )
206+ else Path (tls_root_cert )
207+ )
193208 else :
194209 tls_root_cert = Path (certifi .where ())
195210
@@ -208,14 +223,19 @@ def _user_agent(custom_user_agent=None):
208223
209224 # Prepend the custom user agent (if provided) to the default user agent
210225 if custom_user_agent :
211- return (str .encode ("user-agent" ), str .encode (f"{ custom_user_agent } { config .SPICE_USER_AGENT } " ))
226+ return (
227+ str .encode ("user-agent" ),
228+ str .encode (f"{ custom_user_agent } { config .SPICE_USER_AGENT } " ),
229+ )
212230 return (str .encode ("user-agent" ), str .encode (config .SPICE_USER_AGENT ))
213231
214232 def __init__ (self , grpc : str , api_key : str , tls_root_certs , user_agent = None ):
215233 self ._flight_client = flight .connect (grpc , tls_root_certs = tls_root_certs )
216234 self ._api_key = api_key
217235 self .headers = [_SpiceFlight ._user_agent (user_agent )]
218- self ._flight_options = flight .FlightCallOptions (headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS )
236+ self ._flight_options = flight .FlightCallOptions (
237+ headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS
238+ )
219239 self ._authenticate ()
220240
221241 def _authenticate (self ):
@@ -224,30 +244,42 @@ def _authenticate(self):
224244 self ._flight_client .authenticate_basic_token ("" , self ._api_key ),
225245 _SpiceFlight ._user_agent (),
226246 ]
227- self ._flight_options = flight .FlightCallOptions (headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS )
247+ self ._flight_options = flight .FlightCallOptions (
248+ headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS
249+ )
228250 else :
229251 self .headers = [_SpiceFlight ._user_agent ()]
230- self ._flight_options = flight .FlightCallOptions (headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS )
252+ self ._flight_options = flight .FlightCallOptions (
253+ headers = self .headers , timeout = DEFAULT_QUERY_TIMEOUT_SECS
254+ )
231255
232256 def query (self , query : str , ** kwargs ) -> flight .FlightStreamReader :
233257 timeout = kwargs .get ("timeout" )
234258
235259 if timeout is not None :
236260 if not isinstance (timeout , int ) or timeout <= 0 :
237261 raise ValueError ("Timeout must be a positive integer" )
238- self ._flight_options = flight .FlightCallOptions (headers = self .headers , timeout = timeout )
262+ self ._flight_options = flight .FlightCallOptions (
263+ headers = self .headers , timeout = timeout
264+ )
239265
240266 flight_info = self ._flight_client .get_flight_info (
241267 flight .FlightDescriptor .for_command (query ), self ._flight_options
242268 )
243269
244270 try :
245- reader = self ._threaded_flight_do_get (ticket = flight_info .endpoints [0 ].ticket )
271+ reader = self ._threaded_flight_do_get (
272+ ticket = flight_info .endpoints [0 ].ticket
273+ )
246274 except flight .FlightUnauthenticatedError :
247275 self ._authenticate ()
248- reader = self ._threaded_flight_do_get (ticket = flight_info .endpoints [0 ].ticket )
276+ reader = self ._threaded_flight_do_get (
277+ ticket = flight_info .endpoints [0 ].ticket
278+ )
249279 except flight .FlightTimedOutError as exc :
250- raise TimeoutError (f"Query timed out and was canceled after { timeout } seconds." ) from exc
280+ raise TimeoutError (
281+ f"Query timed out and was canceled after { timeout } seconds."
282+ ) from exc
251283
252284 return reader
253285
@@ -275,7 +307,9 @@ def __init__(
275307 user_agent : str | None = None ,
276308 ): # pylint: disable=R0913
277309 tls_root_certs = _Cert (tls_root_cert ).tls_root_certs
278- self ._flight = _SpiceFlight (flight_url , api_key or "" , tls_root_certs , user_agent )
310+ self ._flight = _SpiceFlight (
311+ flight_url , api_key or "" , tls_root_certs , user_agent
312+ )
279313
280314 self .api_key = api_key
281315 self ._flight_url = flight_url
@@ -369,15 +403,23 @@ def query_with_params(
369403 ValueError: If params is None
370404 """
371405 if params is None :
372- raise ValueError ("params must be a list, not None. Use [] for queries without parameters." )
406+ raise ValueError (
407+ "params must be a list, not None. Use [] for queries without parameters."
408+ )
373409 adbc = self ._ensure_adbc_client ()
374410 return adbc .query_with_params (sql , params )
375411
376- def refresh_dataset (self , dataset : str , refresh_opts : RefreshOpts | None = None ) -> Any :
412+ def refresh_dataset (
413+ self , dataset : str , refresh_opts : RefreshOpts | None = None
414+ ) -> Any :
377415 response = self .http .send_request (
378416 "POST" ,
379417 f"/v1/datasets/{ dataset } /acceleration/refresh" ,
380- body = (json .dumps (refresh_opts .to_dict ()) if refresh_opts is not None else json .dumps ({})),
418+ body = (
419+ json .dumps (refresh_opts .to_dict ())
420+ if refresh_opts is not None
421+ else json .dumps ({})
422+ ),
381423 headers = {"Content-Type" : "application/json" },
382424 )
383425
0 commit comments