Skip to content

Commit 81e3206

Browse files
authored
Merge branch 'main' into dependabot/pip/ruff-0.2.1
2 parents 71b6cdb + 1aa8d84 commit 81e3206

File tree

4 files changed

+109
-10
lines changed

4 files changed

+109
-10
lines changed

turu-snowflake/src/turu/snowflake/record/async_record_cursor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@ async def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame":
88

99
if isinstance(self._recorder, turu.core.record.CsvRecorder):
1010
if limit := self._recorder._options.get("limit"):
11-
df.head(limit).to_csv(self._recorder.file, index=False)
11+
record_df = df.head(limit)
1212

1313
else:
14-
df.to_csv(self._recorder.file, index=False)
14+
record_df = df
15+
16+
record_df.to_csv(
17+
self._recorder.file,
18+
index=False,
19+
header=self._recorder._options.get("header", True),
20+
)
1521

1622
return df

turu-snowflake/src/turu/snowflake/record/record_cursor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@ def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame":
88

99
if isinstance(self._recorder, turu.core.record.CsvRecorder):
1010
if limit := self._recorder._options.get("limit"):
11-
df.head(limit).to_csv(self._recorder.file, index=False)
11+
record_df = df.head(limit)
1212

1313
else:
14-
df.to_csv(self._recorder.file, index=False)
14+
record_df = df
15+
16+
record_df.to_csv(
17+
self._recorder.file,
18+
index=False,
19+
header=self._recorder._options.get("header", True),
20+
)
1521

1622
return df

turu-snowflake/tests/turu/test_snowflake.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,35 @@ def test_record_pandas_dataframe(self, connection: turu.snowflake.Connection):
266266
).lstrip()
267267
)
268268

269+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
270+
def test_record_pandas_dataframe_without_header_option(
271+
self, connection: turu.snowflake.Connection
272+
):
273+
import pandas as pd # type: ignore[import]
274+
from pandas.testing import assert_frame_equal # type: ignore[import]
275+
276+
with tempfile.NamedTemporaryFile() as file:
277+
with record_to_csv(
278+
file.name,
279+
connection.execute_map(
280+
pd.DataFrame, "select 1 as ID union all select 2 AS ID"
281+
),
282+
header=False,
283+
) as cursor:
284+
expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8")
285+
286+
assert_frame_equal(cursor.fetch_pandas_all(), expected)
287+
288+
assert (
289+
Path(file.name).read_text()
290+
== dedent(
291+
"""
292+
1
293+
2
294+
"""
295+
).lstrip()
296+
)
297+
269298
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
270299
def test_record_pandas_dataframe_with_limit_option(
271300
self, connection: turu.snowflake.Connection

turu-snowflake/tests/turu/test_snowflake_async.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,14 +379,72 @@ async def test_record_pandas_dataframe(
379379
"select 1 as ID union all select 2 AS ID",
380380
),
381381
) as cursor:
382-
expected = pd.DataFrame(
383-
{"ID": [1, 2]},
384-
dtype="int8",
385-
)
382+
expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8")
383+
384+
assert_frame_equal(await cursor.fetch_pandas_all(), expected)
385+
386+
assert (
387+
Path(file.name).read_text()
388+
== dedent(
389+
"""
390+
ID
391+
1
392+
2
393+
"""
394+
).lstrip()
395+
)
396+
397+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
398+
@pytest.mark.asyncio
399+
async def test_record_pandas_dataframe_without_header_option(
400+
self, async_connection: turu.snowflake.AsyncConnection
401+
):
402+
import pandas as pd # type: ignore[import]
403+
from pandas.testing import assert_frame_equal # type: ignore[import]
404+
405+
with tempfile.NamedTemporaryFile() as file:
406+
async with record_to_csv(
407+
file.name,
408+
await async_connection.execute_map(
409+
pd.DataFrame,
410+
"select 1 as ID union all select 2 AS ID",
411+
),
412+
header=False,
413+
) as cursor:
414+
expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8")
415+
416+
assert_frame_equal(await cursor.fetch_pandas_all(), expected)
417+
418+
assert (
419+
Path(file.name).read_text()
420+
== dedent(
421+
"""
422+
1
423+
2
424+
"""
425+
).lstrip()
426+
)
427+
428+
@pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed")
429+
@pytest.mark.asyncio
430+
async def test_record_pandas_dataframe_with_limit_option(
431+
self, async_connection: turu.snowflake.AsyncConnection
432+
):
433+
import pandas as pd # type: ignore[import]
434+
from pandas.testing import assert_frame_equal # type: ignore[import]
435+
436+
with tempfile.NamedTemporaryFile() as file:
437+
async with record_to_csv(
438+
file.name,
439+
await async_connection.execute_map(
440+
pd.DataFrame,
441+
"select value::integer as ID from table(flatten(ARRAY_GENERATE_RANGE(1, 10)))",
442+
),
443+
limit=2,
444+
) as cursor:
445+
expected = pd.DataFrame({"ID": list(range(1, 10))}, dtype="object")
386446

387447
assert_frame_equal(await cursor.fetch_pandas_all(), expected)
388-
for row in expected.values:
389-
print(row)
390448

391449
assert (
392450
Path(file.name).read_text()

0 commit comments

Comments
 (0)