Skip to content

Commit 03fb860

Browse files
committed
add: test case.
1 parent 30a80b0 commit 03fb860

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

turu-snowflake/src/turu/snowflake/record/async_record_cursor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import AsyncIterator
1+
from typing import AsyncIterator, Generic, cast
22

33
import turu.core.record
44
import turu.snowflake.async_cursor
@@ -11,14 +11,14 @@
1111

1212
class AsyncRecordCursor( # type: ignore[override]
1313
turu.core.record.AsyncRecordCursor,
14-
turu.snowflake.async_cursor.AsyncCursor[
14+
Generic[
1515
GenericRowType,
1616
GenericPandasDataFrame,
1717
GenericPyArrowTable,
1818
],
1919
):
2020
async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
21-
df = await super().fetch_pandas_all(**kwargs)
21+
df = cast(GenericPandasDataFrame, await self._cursor.fetch_pandas_all(**kwargs)) # type: ignore[assignment]
2222

2323
if isinstance(self._recorder, turu.core.record.CsvRecorder):
2424
if limit := self._recorder._options.get("limit"):
@@ -35,7 +35,10 @@ async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
3535
async def fetch_pandas_batches(
3636
self, **kwargs
3737
) -> AsyncIterator[GenericPandasDataFrame]:
38-
batches = super().fetch_pandas_batches(**kwargs)
38+
batches = cast(
39+
AsyncIterator[GenericPandasDataFrame],
40+
self._cursor.fetch_pandas_batches(**kwargs), # type: ignore[assignment]
41+
)
3942
if isinstance(self._recorder, turu.core.record.CsvRecorder):
4043
if limit := self._recorder._options.get("limit"):
4144
async for batch in batches:

turu-snowflake/src/turu/snowflake/record/record_cursor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterator
1+
from typing import Generic, Iterator, cast
22

33
import turu.core.record
44
import turu.snowflake.cursor
@@ -11,12 +11,10 @@
1111

1212
class RecordCursor( # type: ignore[override]
1313
turu.core.record.RecordCursor,
14-
turu.snowflake.cursor.Cursor[
15-
GenericRowType, GenericPandasDataFrame, GenericPyArrowTable
16-
],
14+
Generic[GenericRowType, GenericPandasDataFrame, GenericPyArrowTable],
1715
):
1816
def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
19-
df = super().fetch_pandas_all(**kwargs)
17+
df = cast(GenericPandasDataFrame, self._cursor.fetch_pandas_all(**kwargs)) # type: ignore[assignment]
2018

2119
if isinstance(self._recorder, turu.core.record.CsvRecorder):
2220
if limit := self._recorder._options.get("limit"):
@@ -31,7 +29,10 @@ def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
3129
return df
3230

3331
def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
34-
batches = super().fetch_pandas_batches(**kwargs)
32+
batches = cast(
33+
Iterator[GenericPandasDataFrame],
34+
self._cursor.fetch_pandas_batches(**kwargs), # type: ignore[assignment]
35+
)
3536
if isinstance(self._recorder, turu.core.record.CsvRecorder):
3637
if limit := self._recorder._options.get("limit"):
3738
for batch in batches:

turu-snowflake/tests/turu/test_snowflake_mock.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77
import turu.snowflake
8+
from turu.core.record import record_to_csv
89
from turu.snowflake.features import (
910
USE_PANDAS,
1011
USE_PANDERA,
@@ -416,3 +417,38 @@ class RowModel(pa.DataFrameModel):
416417
).execute_map(RowModel, "select 1 as ID union all select 2 ID") as cursor:
417418
with pytest.raises(pandera.errors.SchemaInitError):
418419
cursor.fetch_pandas_all()
420+
421+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
422+
def test_record_to_csv_and_fetch_pandas_all_with_limit_options(
423+
self, mock_connection: turu.snowflake.MockConnection
424+
):
425+
import pandas as pd # type: ignore[import]
426+
427+
with tempfile.NamedTemporaryFile() as file:
428+
with record_to_csv(
429+
file.name,
430+
mock_connection.inject_response(
431+
pd.DataFrame, pd.DataFrame({"ID": list(range(10))})
432+
).execute_map(pd.DataFrame, "select 1"),
433+
limit=5,
434+
) as cursor:
435+
assert cursor.fetch_pandas_all().equals(
436+
pd.DataFrame({"ID": list(range(5))})
437+
)
438+
439+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
440+
def test_record_to_csv_and_fetch_pandas_batches_with_limit_options(
441+
self, mock_connection: turu.snowflake.MockConnection
442+
):
443+
import pandas as pd # type: ignore[import]
444+
445+
with tempfile.NamedTemporaryFile() as file:
446+
with record_to_csv(
447+
file.name,
448+
mock_connection.inject_response(
449+
pd.DataFrame, pd.DataFrame({"ID": list(range(10))})
450+
).execute_map(pd.DataFrame, "select 1"),
451+
limit=5,
452+
) as cursor:
453+
for batch in cursor.fetch_pandas_batches():
454+
assert batch.equals(pd.DataFrame({"ID": list(range(5))}))

turu-snowflake/tests/turu/test_snowflake_mock_async.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77
import turu.snowflake
8+
from turu.core.record import record_to_csv
89
from turu.snowflake.features import USE_PANDAS, USE_PANDERA, USE_PYARROW, PyArrowTable
910
from typing_extensions import Never
1011

@@ -508,3 +509,40 @@ class RowModel(pa.DataFrameModel):
508509
.execute_map(RowModel, "select 1 as ID union all select 2 as ID")
509510
) as cursor:
510511
assert (await cursor.fetch_pandas_all()).equals(expected)
512+
513+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
514+
@pytest.mark.asyncio
515+
async def test_record_to_csv_and_fetch_pandas_all_with_limit_options(
516+
self, mock_async_connection: turu.snowflake.MockAsyncConnection
517+
):
518+
import pandas as pd # type: ignore[import]
519+
520+
with tempfile.NamedTemporaryFile() as file:
521+
async with record_to_csv(
522+
file.name,
523+
await mock_async_connection.inject_response(
524+
pd.DataFrame, pd.DataFrame({"ID": list(range(10))})
525+
).execute_map(pd.DataFrame, "select 1"),
526+
limit=5,
527+
) as cursor:
528+
assert (await cursor.fetch_pandas_all()).equals(
529+
pd.DataFrame({"ID": list(range(5))})
530+
)
531+
532+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
533+
@pytest.mark.asyncio
534+
async def test_record_to_csv_and_fetch_pandas_batches_with_limit_options(
535+
self, mock_async_connection: turu.snowflake.MockAsyncConnection
536+
):
537+
import pandas as pd # type: ignore[import]
538+
539+
with tempfile.NamedTemporaryFile() as file:
540+
async with record_to_csv(
541+
file.name,
542+
await mock_async_connection.inject_response(
543+
pd.DataFrame, pd.DataFrame({"ID": list(range(10))})
544+
).execute_map(pd.DataFrame, "select 1"),
545+
limit=5,
546+
) as cursor:
547+
async for batch in cursor.fetch_pandas_batches():
548+
assert batch.equals(pd.DataFrame({"ID": list(range(5))}))

0 commit comments

Comments
 (0)