Skip to content

Commit d896707

Browse files
committed
fix: typehint.
1 parent e8506aa commit d896707

File tree

9 files changed

+39
-24
lines changed

9 files changed

+39
-24
lines changed

turu-snowflake/poetry.lock

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

turu-snowflake/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ pytest-cov = "^4.1.0"
6363
pytest-asyncio = "^0.23.2"
6464
pandas = "^2.1.4"
6565
pyarrow = "^14.0.2"
66+
pyarrow-stubs = "^10.0.1.7"
6667
numpy = "^1.26.3"
6768

6869
[tool.taskipy.tasks]

turu-snowflake/src/turu/snowflake/async_cursor.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,12 +290,17 @@ async def fetchall(self) -> List[GenericRowType]:
290290
async def fetch_arrow_all(self) -> GenericPyArrowTable:
291291
"""Fetches a single Arrow Table."""
292292

293-
return self._raw_cursor.fetch_arrow_all(force_return_table=True)
293+
return cast(
294+
GenericPyArrowTable,
295+
self._raw_cursor.fetch_arrow_all(force_return_table=True),
296+
)
294297

295298
async def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:
296299
"""Fetches Arrow Tables in batches, where 'batch' refers to Snowflake Chunk."""
297300

298-
return self._raw_cursor.fetch_arrow_batches()
301+
return cast(
302+
Iterator[GenericPyArrowTable], self._raw_cursor.fetch_arrow_batches()
303+
)
299304

300305
async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFlame:
301306
"""Fetch Pandas dataframes."""

turu-snowflake/src/turu/snowflake/cursor.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -324,7 +324,9 @@ def fetch_arrow_all(self) -> GenericPyArrowTable:
324324
def fetch_arrow_batches(self) -> "Iterator[GenericPyArrowTable]":
325325
"""Fetches Arrow Tables in batches, where 'batch' refers to Snowflake Chunk."""
326326

327-
return self._raw_cursor.fetch_arrow_batches()
327+
return cast(
328+
Iterator[GenericPyArrowTable], self._raw_cursor.fetch_arrow_batches()
329+
)
328330

329331
def fetch_pandas_all(self, **kwargs) -> "GenericPandasDataFlame":
330332
"""Fetch a single Pandas dataframe."""

turu-snowflake/src/turu/snowflake/features.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, TypeVar
1+
from typing import TypeVar
22

33
from turu.core.features import _NotSupportFeature
44
from typing_extensions import Never, TypeAlias
@@ -20,11 +20,8 @@
2020
import pyarrow # type: ignore[import] # noqa: F401
2121

2222
USE_PYARROW = True
23-
if TYPE_CHECKING:
24-
PyArrowTable: TypeAlias = _NotSupportFeature # type: ignore
23+
PyArrowTable = pyarrow.Table # type: ignore
2524

26-
else:
27-
PyArrowTable: TypeAlias = pyarrow.Table
2825

2926
except ImportError:
3027
USE_PYARROW = False

turu-snowflake/src/turu/snowflake/mock_async_connection.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def inject_response_from_csv( # type: ignore[override]
127127

128128
self.inject_response(
129129
row_type,
130-
cast(Any, pyarrow.csv.read_csv(filepath, **options)),
130+
pyarrow.csv.read_csv(filepath, **options), # type: ignore
131131
)
132132

133133
else:

turu-snowflake/tests/turu/test_snowflake.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,10 @@ def test_cursor_use_role(self, connection: turu.snowflake.Connection):
153153
def test_fetch_arrow_all(self, connection: turu.snowflake.Connection):
154154
import pyarrow as pa
155155

156-
expected = pa.table(
156+
expected: pa.Table = pa.table(
157157
data=[pa.array([1, 2], type=pa.int8())],
158158
schema=pa.schema([pa.field("ID", pa.int8(), False)]),
159-
)
159+
) # type: ignore
160160

161161
with connection.execute("select 1 as ID union all select 2 as ID") as cursor:
162162
assert cursor.fetch_arrow_all() == expected

turu-snowflake/tests/turu/test_snowflake_mock.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -173,10 +173,10 @@ def test_cursor_use_role(self, mock_connection: turu.snowflake.MockConnection):
173173
def test_fetch_arrow_all(self, mock_connection: turu.snowflake.MockConnection):
174174
import pyarrow as pa
175175

176-
expected = pa.table(
176+
expected: pa.Table = pa.table(
177177
data=[pa.array([1, 2], type=pa.int8())],
178178
schema=pa.schema([pa.field("ID", pa.int8(), False)]),
179-
)
179+
) # type: ignore
180180

181181
mock_connection.inject_response(PyArrowTable, expected)
182182

@@ -192,10 +192,10 @@ def test_fetch_arrow_all(self, mock_connection: turu.snowflake.MockConnection):
192192
def test_fetch_arrow_batches(self, mock_connection: turu.snowflake.MockConnection):
193193
import pyarrow as pa
194194

195-
expected = pa.table(
195+
expected: pa.Table = pa.table(
196196
data=[pa.array([1, 2], type=pa.int8())],
197197
schema=pa.schema([pa.field("ID", pa.int8(), False)]),
198-
)
198+
) # type: ignore
199199

200200
with mock_connection.inject_response(PyArrowTable, expected).execute_map(
201201
PyArrowTable, "select 1 as ID union all select 2 as ID"
@@ -232,10 +232,10 @@ def test_inject_pyarrow_response_from_csv(
232232
):
233233
import pyarrow as pa
234234

235-
expected = pa.table(
235+
expected: pa.Table = pa.table(
236236
data=[pa.array([1, 2], type=pa.int64())],
237237
schema=pa.schema([pa.field("ID", pa.int64())]),
238-
)
238+
) # type: ignore
239239

240240
with tempfile.NamedTemporaryFile() as file:
241241
Path(file.name).write_text(

turu-snowflake/tests/turu/test_snowflake_mock_async.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -229,10 +229,10 @@ async def test_fetch_arrow_all(
229229
):
230230
import pyarrow as pa
231231

232-
expected = pa.table(
232+
expected: pa.Table = pa.table(
233233
data=[pa.array([1, 2], type=pa.int8())],
234234
schema=pa.schema([pa.field("ID", pa.int8(), False)]),
235-
)
235+
) # type: ignore
236236

237237
mock_async_connection.inject_response(PyArrowTable, expected)
238238

@@ -251,10 +251,10 @@ async def test_fetch_arrow_batches(
251251
):
252252
import pyarrow as pa
253253

254-
expected = pa.table(
254+
expected: pa.Table = pa.table(
255255
data=[pa.array([1, 2], type=pa.int8())],
256256
schema=pa.schema([pa.field("ID", pa.int8(), False)]),
257-
)
257+
) # type: ignore
258258

259259
async with await mock_async_connection.inject_response(
260260
PyArrowTable, expected
@@ -344,10 +344,10 @@ async def test_inject_pyarrow_response_from_csv(
344344
):
345345
import pyarrow as pa
346346

347-
expected = pa.table(
347+
expected: pa.Table = pa.table(
348348
data=[pa.array([1, 2], type=pa.int64())],
349349
schema=pa.schema([pa.field("ID", pa.int64())]),
350-
)
350+
) # type: ignore
351351

352352
with tempfile.NamedTemporaryFile() as file:
353353
Path(file.name).write_text(

0 commit comments

Comments (0)