@@ -379,14 +379,72 @@ async def test_record_pandas_dataframe(
379379 "select 1 as ID union all select 2 AS ID" ,
380380 ),
381381 ) as cursor :
382- expected = pd .DataFrame (
383- {"ID" : [1 , 2 ]},
384- dtype = "int8" ,
385- )
382+ expected = pd .DataFrame ({"ID" : [1 , 2 ]}, dtype = "int8" )
383+
384+ assert_frame_equal (await cursor .fetch_pandas_all (), expected )
385+
386+ assert (
387+ Path (file .name ).read_text ()
388+ == dedent (
389+ """
390+ ID
391+ 1
392+ 2
393+ """
394+ ).lstrip ()
395+ )
396+
397+ @pytest .mark .skipif (not USE_PANDAS , reason = "pandas is not installed" )
398+ @pytest .mark .asyncio
399+ async def test_record_pandas_dataframe_without_header_option (
400+ self , async_connection : turu .snowflake .AsyncConnection
401+ ):
402+ import pandas as pd # type: ignore[import]
403+ from pandas .testing import assert_frame_equal # type: ignore[import]
404+
405+ with tempfile .NamedTemporaryFile () as file :
406+ async with record_to_csv (
407+ file .name ,
408+ await async_connection .execute_map (
409+ pd .DataFrame ,
410+ "select 1 as ID union all select 2 AS ID" ,
411+ ),
412+ header = False ,
413+ ) as cursor :
414+ expected = pd .DataFrame ({"ID" : [1 , 2 ]}, dtype = "int8" )
415+
416+ assert_frame_equal (await cursor .fetch_pandas_all (), expected )
417+
418+ assert (
419+ Path (file .name ).read_text ()
420+ == dedent (
421+ """
422+ 1
423+ 2
424+ """
425+ ).lstrip ()
426+ )
427+
428+ @pytest .mark .skipif (not USE_PANDAS , reason = "pandas is not installed" )
429+ @pytest .mark .asyncio
430+ async def test_record_pandas_dataframe_with_limit_option (
431+ self , async_connection : turu .snowflake .AsyncConnection
432+ ):
433+ import pandas as pd # type: ignore[import]
434+ from pandas .testing import assert_frame_equal # type: ignore[import]
435+
436+ with tempfile .NamedTemporaryFile () as file :
437+ async with record_to_csv (
438+ file .name ,
439+ await async_connection .execute_map (
440+ pd .DataFrame ,
441+ "select value::integer as ID from table(flatten(ARRAY_GENERATE_RANGE(1, 10)))" ,
442+ ),
443+ limit = 2 ,
444+ ) as cursor :
445+ expected = pd .DataFrame ({"ID" : list (range (1 , 10 ))}, dtype = "object" )
386446
387447 assert_frame_equal (await cursor .fetch_pandas_all (), expected )
388- for row in expected .values :
389- print (row )
390448
391449 assert (
392450 Path (file .name ).read_text ()
0 commit comments