Skip to content

Commit 188d89e

Browse files
authored
Merge branch 'main' into dependabot/pip/mkdocs-material-9.5.7
2 parents e74e6c2 + a956d4f commit 188d89e

File tree

4 files changed

+50
-12
lines changed

4 files changed

+50
-12
lines changed

turu-core/src/turu/core/mock/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class TuruMockStore:
1111
def __init__(self):
12-
self._data: List[Tuple[Optional[Type], Sequence | None | Exception]] = []
12+
self._data: List[Tuple[Optional[Type], Union[Sequence, None, Exception]]] = []
1313
self._counter = 0
1414

1515
@overload
@@ -37,15 +37,18 @@ def inject_response(self, row_type, response):
3737
def provide_response(
3838
self,
3939
row_type: Optional[Type],
40-
) -> "Sequence[Any] | None":
40+
) -> "Optional[Sequence[Any]]":
4141
self._counter += 1
4242

4343
if len(self._data) == 0:
4444
raise TuruMockStoreDataNotFoundError(self._counter)
4545

4646
_row_type, _response = self._data.pop(0)
4747

48-
if _row_type is not row_type:
48+
if _row_type is not row_type and not (
49+
row_type.__module__ == _row_type.__module__
50+
and row_type.__name__ == _row_type.__name__
51+
):
4952
raise TuruMockResponseTypeMismatchError(row_type, _row_type, self._counter)
5053

5154
if isinstance(_response, Exception):

turu-snowflake/src/turu/snowflake/record/async_record_cursor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
class AsyncRecordCursor(turu.core.record.AsyncRecordCursor):
66
async def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame":
77
df: PandasDataFrame = self._raw_cursor.fetch_pandas_all(**kwargs) # type: ignore
8+
89
if isinstance(self._recorder, turu.core.record.CsvRecorder):
9-
df.to_csv(self._recorder.file, index=False)
10+
if limit := self._recorder._options.get("limit"):
11+
df.head(limit).to_csv(self._recorder.file, index=False)
12+
13+
else:
14+
df.to_csv(self._recorder.file, index=False)
1015

1116
return df

turu-snowflake/src/turu/snowflake/record/record_cursor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
class RecordCursor(turu.core.record.RecordCursor):
66
def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame":
77
df: PandasDataFrame = self._raw_cursor.fetch_pandas_all(**kwargs) # type: ignore
8+
89
if isinstance(self._recorder, turu.core.record.CsvRecorder):
9-
df.to_csv(self._recorder.file, index=False)
10+
if limit := self._recorder._options.get("limit"):
11+
df.head(limit).to_csv(self._recorder.file, index=False)
12+
13+
else:
14+
df.to_csv(self._recorder.file, index=False)
1015

1116
return df

turu-snowflake/tests/turu/test_snowflake.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,22 +244,47 @@ def test_record_pandas_dataframe(self, connection: turu.snowflake.Connection):
244244
import pandas as pd # type: ignore[import]
245245
from pandas.testing import assert_frame_equal # type: ignore[import]
246246

247+
with tempfile.NamedTemporaryFile() as file:
248+
with record_to_csv(
249+
file.name,
250+
connection.execute_map(
251+
pd.DataFrame, "select 1 as ID union all select 2 AS ID"
252+
),
253+
) as cursor:
254+
expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8")
255+
256+
assert_frame_equal(cursor.fetch_pandas_all(), expected)
257+
258+
assert (
259+
Path(file.name).read_text()
260+
== dedent(
261+
"""
262+
ID
263+
1
264+
2
265+
"""
266+
).lstrip()
267+
)
268+
269+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
270+
def test_record_pandas_dataframe_with_limit_option(
271+
self, connection: turu.snowflake.Connection
272+
):
273+
import pandas as pd # type: ignore[import]
274+
from pandas.testing import assert_frame_equal # type: ignore[import]
275+
247276
with tempfile.NamedTemporaryFile() as file:
248277
with record_to_csv(
249278
file.name,
250279
connection.execute_map(
251280
pd.DataFrame,
252-
"select 1 as ID union all select 2 AS ID",
281+
"select value::integer as ID from table(flatten(ARRAY_GENERATE_RANGE(1, 10)))",
253282
),
283+
limit=2,
254284
) as cursor:
255-
expected = pd.DataFrame(
256-
{"ID": [1, 2]},
257-
dtype="int8",
258-
)
285+
expected = pd.DataFrame({"ID": list(range(1, 10))}, dtype="object")
259286

260287
assert_frame_equal(cursor.fetch_pandas_all(), expected)
261-
for row in expected.values:
262-
print(row)
263288

264289
assert (
265290
Path(file.name).read_text()

0 commit comments

Comments
 (0)