Skip to content

Commit 005bd48

Browse files
committed
Add black formatting check to CI lint workflow
Add black>=26.3.1 to test dependencies and a lint-black job to the lint workflow. Adjust ruff isort config for black compatibility: add force-sort-within-sections and remove lines-after-imports=2.
1 parent 5b91eb7 commit 005bd48

11 files changed

Lines changed: 218 additions & 64 deletions

File tree

.github/workflows/lint.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,27 @@ jobs:
8484
run: |
8585
mypy spicepy --ignore-missing-imports
8686
87+
lint-black:
88+
runs-on: ubuntu-latest
89+
strategy:
90+
fail-fast: false
91+
matrix:
92+
python-version: ['3.11', '3.12', '3.13', '3.14']
93+
name: Lint with black (Python ${{ matrix.python-version }})
94+
steps:
95+
- uses: actions/checkout@v6
96+
- name: Set up Python ${{ matrix.python-version }}
97+
uses: actions/setup-python@v6
98+
with:
99+
python-version: ${{ matrix.python-version }}
100+
cache: 'pip'
101+
- name: Install requirements
102+
run: |
103+
pip install ".[test]"
104+
- name: Check formatting with black
105+
run: |
106+
black --check spicepy tests
107+
87108
security-check:
88109
runs-on: ubuntu-latest
89110
name: Security scan with bandit

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ test = [
5050
"pytest_httpserver==1.1.5",
5151
"types-requests>=2.33.0",
5252
"pandas-stubs>=3.0.0",
53+
"black>=26.3.1",
5354
"bandit>=1.9.4",
5455
"pandas>=3.0.2",
5556
"pyarrow>=23.0.1",
@@ -189,7 +190,7 @@ ignore = [
189190
[tool.ruff.lint.isort]
190191
known-first-party = ["spicepy"]
191192
force-single-line = false
192-
lines-after-imports = 2
193+
force-sort-within-sections = true
193194

194195
[tool.ruff.lint.pydocstyle]
195196
convention = "google"

spicepy/_client.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22
import os
3+
from pathlib import Path
34
import platform
45
import threading
5-
from pathlib import Path
66
from typing import Any
77

88
import certifi
@@ -21,7 +21,10 @@
2121

2222

2323
def is_macos_arm64() -> bool:
24-
return platform.platform().lower().startswith("macos") and platform.machine() == "arm64"
24+
return (
25+
platform.platform().lower().startswith("macos")
26+
and platform.machine() == "arm64"
27+
)
2528

2629

2730
try:
@@ -85,7 +88,9 @@ def _init_connection(self):
8588
# Add authentication if API key provided
8689
if self._api_key:
8790
db_kwargs[adbc_driver_manager.DatabaseOptions.USERNAME.value] = ""
88-
db_kwargs[adbc_driver_manager.DatabaseOptions.PASSWORD.value] = self._api_key
91+
db_kwargs[adbc_driver_manager.DatabaseOptions.PASSWORD.value] = (
92+
self._api_key
93+
)
8994

9095
# Create low-level database and connection (avoids dbapi autocommit warning)
9196
self._db = adbc_driver_flightsql.connect(self._uri, db_kwargs=db_kwargs)
@@ -110,7 +115,11 @@ def _create_param_batch(
110115

111116
for param in params:
112117
# Check if param is a tuple of (value, arrow_type)
113-
if isinstance(param, tuple) and len(param) == 2 and isinstance(param[1], pa.DataType):
118+
if (
119+
isinstance(param, tuple)
120+
and len(param) == 2
121+
and isinstance(param[1], pa.DataType)
122+
):
114123
value, arrow_type = param
115124
param_values.append(value)
116125
param_types.append(arrow_type)
@@ -124,7 +133,9 @@ def _create_param_batch(
124133
param_arrays.append(pa.array([value], type=arrow_type))
125134

126135
# Create parameter schema with positional field names ($1, $2, etc.)
127-
param_fields = [pa.field(f"${i + 1}", param_types[i]) for i in range(len(params))]
136+
param_fields = [
137+
pa.field(f"${i + 1}", param_types[i]) for i in range(len(params))
138+
]
128139
param_schema = pa.schema(param_fields)
129140

130141
return pa.record_batch(param_arrays, schema=param_schema)
@@ -189,7 +200,11 @@ def __init__(
189200
tls_root_cert,
190201
):
191202
if tls_root_cert is not None:
192-
tls_root_cert = tls_root_cert if isinstance(tls_root_cert, Path) else Path(tls_root_cert)
203+
tls_root_cert = (
204+
tls_root_cert
205+
if isinstance(tls_root_cert, Path)
206+
else Path(tls_root_cert)
207+
)
193208
else:
194209
tls_root_cert = Path(certifi.where())
195210

@@ -208,14 +223,19 @@ def _user_agent(custom_user_agent=None):
208223

209224
# Prepend the custom user agent (if provided) to the default user agent
210225
if custom_user_agent:
211-
return (str.encode("user-agent"), str.encode(f"{custom_user_agent} {config.SPICE_USER_AGENT}"))
226+
return (
227+
str.encode("user-agent"),
228+
str.encode(f"{custom_user_agent} {config.SPICE_USER_AGENT}"),
229+
)
212230
return (str.encode("user-agent"), str.encode(config.SPICE_USER_AGENT))
213231

214232
def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None):
215233
self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs)
216234
self._api_key = api_key
217235
self.headers = [_SpiceFlight._user_agent(user_agent)]
218-
self._flight_options = flight.FlightCallOptions(headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS)
236+
self._flight_options = flight.FlightCallOptions(
237+
headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS
238+
)
219239
self._authenticate()
220240

221241
def _authenticate(self):
@@ -224,30 +244,42 @@ def _authenticate(self):
224244
self._flight_client.authenticate_basic_token("", self._api_key),
225245
_SpiceFlight._user_agent(),
226246
]
227-
self._flight_options = flight.FlightCallOptions(headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS)
247+
self._flight_options = flight.FlightCallOptions(
248+
headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS
249+
)
228250
else:
229251
self.headers = [_SpiceFlight._user_agent()]
230-
self._flight_options = flight.FlightCallOptions(headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS)
252+
self._flight_options = flight.FlightCallOptions(
253+
headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS
254+
)
231255

232256
def query(self, query: str, **kwargs) -> flight.FlightStreamReader:
233257
timeout = kwargs.get("timeout")
234258

235259
if timeout is not None:
236260
if not isinstance(timeout, int) or timeout <= 0:
237261
raise ValueError("Timeout must be a positive integer")
238-
self._flight_options = flight.FlightCallOptions(headers=self.headers, timeout=timeout)
262+
self._flight_options = flight.FlightCallOptions(
263+
headers=self.headers, timeout=timeout
264+
)
239265

240266
flight_info = self._flight_client.get_flight_info(
241267
flight.FlightDescriptor.for_command(query), self._flight_options
242268
)
243269

244270
try:
245-
reader = self._threaded_flight_do_get(ticket=flight_info.endpoints[0].ticket)
271+
reader = self._threaded_flight_do_get(
272+
ticket=flight_info.endpoints[0].ticket
273+
)
246274
except flight.FlightUnauthenticatedError:
247275
self._authenticate()
248-
reader = self._threaded_flight_do_get(ticket=flight_info.endpoints[0].ticket)
276+
reader = self._threaded_flight_do_get(
277+
ticket=flight_info.endpoints[0].ticket
278+
)
249279
except flight.FlightTimedOutError as exc:
250-
raise TimeoutError(f"Query timed out and was canceled after {timeout} seconds.") from exc
280+
raise TimeoutError(
281+
f"Query timed out and was canceled after {timeout} seconds."
282+
) from exc
251283

252284
return reader
253285

@@ -275,7 +307,9 @@ def __init__(
275307
user_agent: str | None = None,
276308
): # pylint: disable=R0913
277309
tls_root_certs = _Cert(tls_root_cert).tls_root_certs
278-
self._flight = _SpiceFlight(flight_url, api_key or "", tls_root_certs, user_agent)
310+
self._flight = _SpiceFlight(
311+
flight_url, api_key or "", tls_root_certs, user_agent
312+
)
279313

280314
self.api_key = api_key
281315
self._flight_url = flight_url
@@ -369,15 +403,23 @@ def query_with_params(
369403
ValueError: If params is None
370404
"""
371405
if params is None:
372-
raise ValueError("params must be a list, not None. Use [] for queries without parameters.")
406+
raise ValueError(
407+
"params must be a list, not None. Use [] for queries without parameters."
408+
)
373409
adbc = self._ensure_adbc_client()
374410
return adbc.query_with_params(sql, params)
375411

376-
def refresh_dataset(self, dataset: str, refresh_opts: RefreshOpts | None = None) -> Any:
412+
def refresh_dataset(
413+
self, dataset: str, refresh_opts: RefreshOpts | None = None
414+
) -> Any:
377415
response = self.http.send_request(
378416
"POST",
379417
f"/v1/datasets/{dataset}/acceleration/refresh",
380-
body=(json.dumps(refresh_opts.to_dict()) if refresh_opts is not None else json.dumps({})),
418+
body=(
419+
json.dumps(refresh_opts.to_dict())
420+
if refresh_opts is not None
421+
else json.dumps({})
422+
),
381423
headers={"Content-Type": "application/json"},
382424
)
383425

spicepy/_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import datetime
21
from collections.abc import Callable
32
from dataclasses import dataclass
3+
import datetime
44
from typing import Any, Literal
55

66
from requests import Response, Session

spicepy/config.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
from importlib.metadata import version
12
import os
23
import platform
3-
from importlib.metadata import version
4-
54

65
DEFAULT_FLIGHT_URL = os.environ.get("SPICE_FLIGHT_URL", "grpc+tls://flight.spiceai.io")
76
DEFAULT_HTTP_URL = os.environ.get("SPICE_HTTP_URL", "https://data.spiceai.io")
87

9-
DEFAULT_LOCAL_FLIGHT_URL = os.environ.get("SPICE_LOCAL_FLIGHT_URL", "grpc://localhost:50051")
8+
DEFAULT_LOCAL_FLIGHT_URL = os.environ.get(
9+
"SPICE_LOCAL_FLIGHT_URL", "grpc://localhost:50051"
10+
)
1011
DEFAULT_LOCAL_HTTP_URL = os.environ.get("SPICE_LOCAL_HTTP_URL", "http://localhost:8090")
1112

1213

@@ -19,7 +20,9 @@
1920
# Default is the system information of the current system, e.g. `Linux/5.4.0-1043-aws x86_64`.
2021
###
2122
def get_user_agent(
22-
client_name: str | None = None, client_version: str | None = None, client_system: str | None = None
23+
client_name: str | None = None,
24+
client_version: str | None = None,
25+
client_system: str | None = None,
2326
) -> str:
2427
package_version = version("spicepy") if client_version is None else client_version
2528
system = platform.system()
@@ -28,7 +31,9 @@ def get_user_agent(
2831
if arch == "AMD64":
2932
arch = "x86_64"
3033

31-
system_info = f"{system}/{release} {arch}" if client_system is None else client_system
34+
system_info = (
35+
f"{system}/{release} {arch}" if client_system is None else client_system
36+
)
3237
client = "spicepy" if client_name is None else client_name
3338
return f"{client}/{package_version} ({system_info})"
3439

tests/conftest.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,26 @@
22

33
from __future__ import annotations
44

5-
import os
65
from collections.abc import Generator
6+
import os
77
from unittest.mock import MagicMock, patch
88

99
import pytest
1010

11-
1211
# ============== Markers ==============
1312

1413

1514
def pytest_configure(config: pytest.Config) -> None:
1615
"""Register custom markers."""
17-
config.addinivalue_line("markers", "unit: Unit tests that don't require external services")
18-
config.addinivalue_line("markers", "integration: Integration tests that require Spice runtime")
19-
config.addinivalue_line("markers", "cloud: Tests that require Spice.ai cloud connection")
16+
config.addinivalue_line(
17+
"markers", "unit: Unit tests that don't require external services"
18+
)
19+
config.addinivalue_line(
20+
"markers", "integration: Integration tests that require Spice runtime"
21+
)
22+
config.addinivalue_line(
23+
"markers", "cloud: Tests that require Spice.ai cloud connection"
24+
)
2025
config.addinivalue_line("markers", "slow: Tests that take a long time to run")
2126

2227

@@ -55,7 +60,9 @@ def is_adbc_available() -> bool:
5560
return False
5661

5762

58-
skip_if_no_adbc = pytest.mark.skipif(not is_adbc_available(), reason="ADBC driver not installed")
63+
skip_if_no_adbc = pytest.mark.skipif(
64+
not is_adbc_available(), reason="ADBC driver not installed"
65+
)
5966

6067

6168
# ============== Mock Fixtures ==============

tests/test_client.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations
44

55
import os
6-
import threading
76
from pathlib import Path
7+
import threading
88
from unittest.mock import MagicMock, patch
99

1010
import pyarrow as pa
@@ -98,7 +98,10 @@ def test_client_with_custom_urls(
9898
mock_cert.tls_root_certs = b"cert"
9999
mock_cert_class.return_value = mock_cert
100100

101-
client = Client(flight_url="grpc+tls://custom.spiceai.io", http_url="https://custom-data.spiceai.io")
101+
client = Client(
102+
flight_url="grpc+tls://custom.spiceai.io",
103+
http_url="https://custom-data.spiceai.io",
104+
)
102105

103106
assert client._flight_url == "grpc+tls://custom.spiceai.io"
104107

@@ -340,7 +343,9 @@ def test_query_with_params_passes_correct_args(
340343
client = Client(flight_url="grpc://localhost:50051", api_key="test-key")
341344
result = client.query_with_params("SELECT * FROM t WHERE id = $1", [42])
342345

343-
mock_adbc.query_with_params.assert_called_once_with("SELECT * FROM t WHERE id = $1", [42])
346+
mock_adbc.query_with_params.assert_called_once_with(
347+
"SELECT * FROM t WHERE id = $1", [42]
348+
)
344349
assert result == mock_reader
345350

346351
@patch("spicepy._client._SpiceFlight")
@@ -781,7 +786,9 @@ def test_spice_flight_query_basic(
781786
)
782787

783788
# Mock the _threaded_flight_do_get method directly
784-
with patch.object(flight_instance, "_threaded_flight_do_get", return_value=mock_reader):
789+
with patch.object(
790+
flight_instance, "_threaded_flight_do_get", return_value=mock_reader
791+
):
785792
result = flight_instance.query("SELECT 1")
786793

787794
assert result is mock_reader
@@ -814,7 +821,9 @@ def test_spice_flight_query_with_timeout(
814821
)
815822

816823
# Mock the _threaded_flight_do_get method directly
817-
with patch.object(flight_instance, "_threaded_flight_do_get", return_value=mock_reader):
824+
with patch.object(
825+
flight_instance, "_threaded_flight_do_get", return_value=mock_reader
826+
):
818827
flight_instance.query("SELECT 1", timeout=60)
819828

820829
@patch("spicepy._client.flight")
@@ -878,7 +887,9 @@ def mock_threaded_do_get(*args, **kwargs):
878887
raise FlightUnauthenticatedError("")
879888
return mock_reader
880889

881-
with patch.object(flight_instance, "_threaded_flight_do_get", side_effect=mock_threaded_do_get):
890+
with patch.object(
891+
flight_instance, "_threaded_flight_do_get", side_effect=mock_threaded_do_get
892+
):
882893
result = flight_instance.query("SELECT 1")
883894

884895
assert result == mock_reader

0 commit comments

Comments
 (0)