|
3 | 3 | import platform |
4 | 4 | import threading |
5 | 5 | from pathlib import Path |
6 | | -from typing import Any, Optional, Union |
| 6 | +from typing import Any |
7 | 7 |
|
8 | 8 | import certifi |
9 | 9 | import pyarrow as pa |
@@ -54,8 +54,8 @@ class _ADBCClient: |
54 | 54 | def __init__( |
55 | 55 | self, |
56 | 56 | uri: str, |
57 | | - api_key: Optional[str] = None, |
58 | | - user_agent: Optional[str] = None, |
| 57 | + api_key: str | None = None, |
| 58 | + user_agent: str | None = None, |
59 | 59 | ): |
60 | 60 | if not ADBC_AVAILABLE: |
61 | 61 | raise ImportError( |
@@ -120,7 +120,7 @@ def _create_param_batch( |
120 | 120 |
|
121 | 121 | # Create parameter arrays (each with a single row) |
122 | 122 | param_arrays = [] |
123 | | - for value, arrow_type in zip(param_values, param_types): |
| 123 | + for value, arrow_type in zip(param_values, param_types, strict=True): |
124 | 124 | param_arrays.append(pa.array([value], type=arrow_type)) |
125 | 125 |
|
126 | 126 | # Create parameter schema with positional field names ($1, $2, etc.) |
@@ -268,19 +268,19 @@ class Client: |
268 | 268 | # pylint: disable=R0917 |
269 | 269 | def __init__( |
270 | 270 | self, |
271 | | - api_key: Optional[str] = None, |
| 271 | + api_key: str | None = None, |
272 | 272 | flight_url: str = config.DEFAULT_LOCAL_FLIGHT_URL, |
273 | 273 | http_url: str = config.DEFAULT_HTTP_URL, |
274 | | - tls_root_cert: Union[str, Path, None] = None, |
275 | | - user_agent: Optional[str] = None, |
| 274 | + tls_root_cert: str | Path | None = None, |
| 275 | + user_agent: str | None = None, |
276 | 276 | ): # pylint: disable=R0913 |
277 | 277 | tls_root_certs = _Cert(tls_root_cert).tls_root_certs |
278 | 278 | self._flight = _SpiceFlight(flight_url, api_key or "", tls_root_certs, user_agent) |
279 | 279 |
|
280 | 280 | self.api_key = api_key |
281 | 281 | self._flight_url = flight_url |
282 | 282 | self._user_agent = user_agent |
283 | | - self._adbc_client: Optional[_ADBCClient] = None |
| 283 | + self._adbc_client: _ADBCClient | None = None |
284 | 284 | self.http = HttpRequests(http_url, self._headers(user_agent)) |
285 | 285 |
|
286 | 286 | def _headers(self, user_agent=None) -> dict[str, str]: |
@@ -373,7 +373,7 @@ def query_with_params( |
373 | 373 | adbc = self._ensure_adbc_client() |
374 | 374 | return adbc.query_with_params(sql, params) |
375 | 375 |
|
376 | | - def refresh_dataset(self, dataset: str, refresh_opts: Optional[RefreshOpts] = None) -> Any: |
| 376 | + def refresh_dataset(self, dataset: str, refresh_opts: RefreshOpts | None = None) -> Any: |
377 | 377 | response = self.http.send_request( |
378 | 378 | "POST", |
379 | 379 | f"/v1/datasets/{dataset}/acceleration/refresh", |
|
0 commit comments