77 GenericPandasDataFrame ,
88 GenericPyArrowTable ,
99)
10+ from typing_extensions import Self
1011
1112
1213class RecordCursor ( # type: ignore[override]
1314 turu .core .record .RecordCursor ,
1415 Generic [GenericRowType , GenericPandasDataFrame , GenericPyArrowTable ],
1516):
1617 def fetch_pandas_all (self , ** kwargs ) -> GenericPandasDataFrame :
17- df = cast ( GenericPandasDataFrame , self ._cursor .fetch_pandas_all (** kwargs )) # type: ignore[assignment]
18+ df = self ._sf_cursor .fetch_pandas_all (** kwargs )
1819
1920 if isinstance (self ._recorder , turu .core .record .CsvRecorder ):
2021 if limit := self ._recorder ._options .get ("limit" ):
@@ -29,10 +30,8 @@ def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
2930 return df
3031
3132 def fetch_pandas_batches (self , ** kwargs ) -> Iterator [GenericPandasDataFrame ]:
32- batches = cast (
33- Iterator [GenericPandasDataFrame ],
34- self ._cursor .fetch_pandas_batches (** kwargs ), # type: ignore[assignment]
35- )
33+ batches = self ._sf_cursor .fetch_pandas_batches (** kwargs )
34+
3635 if isinstance (self ._recorder , turu .core .record .CsvRecorder ):
3736 if limit := self ._recorder ._options .get ("limit" ):
3837 for batch in batches :
@@ -45,7 +44,7 @@ def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
4544 return batches
4645
4746 def fetch_arrow_all (self ) -> GenericPyArrowTable :
48- table = cast ( GenericPyArrowTable , self ._cursor .fetch_arrow_all ()) # type: ignore[assignment]
47+ table = self ._sf_cursor .fetch_arrow_all ()
4948
5049 if isinstance (self ._recorder , turu .core .record .CsvRecorder ):
5150 if limit := self ._recorder ._options .get ("limit" ):
@@ -60,10 +59,8 @@ def fetch_arrow_all(self) -> GenericPyArrowTable:
6059 return table
6160
6261 def fetch_arrow_batches (self ) -> Iterator [GenericPyArrowTable ]:
63- batches = cast (
64- Iterator [GenericPyArrowTable ],
65- self ._cursor .fetch_arrow_batches (), # type: ignore[assignment]
66- )
62+ batches = self ._sf_cursor .fetch_arrow_batches ()
63+
6764 if isinstance (self ._recorder , turu .core .record .CsvRecorder ):
6865 if limit := self ._recorder ._options .get ("limit" ):
6966 for batch in batches :
@@ -74,3 +71,39 @@ def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:
7471 return
7572
7673 return batches
74+
75+ def use_warehouse (self , warehouse : str , / ) -> Self :
76+ """Use a warehouse in cursor."""
77+
78+ self ._sf_cursor .use_warehouse (warehouse )
79+
80+ return self
81+
82+ def use_database (self , database : str , / ) -> Self :
83+ """Use a database in cursor."""
84+
85+ self ._sf_cursor .use_database (database )
86+
87+ return self
88+
89+ def use_schema (self , schema : str , / ) -> Self :
90+ """Use a schema in cursor."""
91+
92+ self ._sf_cursor .use_schema (schema )
93+
94+ return self
95+
96+ def use_role (self , role : str , / ) -> Self :
97+ """Use a role in cursor."""
98+
99+ self ._sf_cursor .use_role (role )
100+
101+ return self
102+
103+ @property
104+ def _sf_cursor (
105+ self ,
106+ ) -> turu .snowflake .cursor .Cursor [
107+ GenericRowType , GenericPandasDataFrame , GenericPyArrowTable
108+ ]:
109+ return cast (turu .snowflake .cursor .Cursor , self ._cursor )
0 commit comments