Skip to content

Commit 48797cb

Browse files
authored
Merge pull request #88 from yassun7010/add_use_method
feat: add use_* methods.
2 parents bacbb35 + 3440a18 commit 48797cb

File tree

2 files changed

+85
-20
lines changed

2 files changed

+85
-20
lines changed

turu-snowflake/src/turu/snowflake/record/async_record_cursor.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class AsyncRecordCursor( # type: ignore[override]
1818
],
1919
):
2020
async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
21-
df = cast(GenericPandasDataFrame, await self._cursor.fetch_pandas_all(**kwargs)) # type: ignore[assignment]
21+
df = await self._sf_cursor.fetch_pandas_all(**kwargs)
2222

2323
if isinstance(self._recorder, turu.core.record.CsvRecorder):
2424
if limit := self._recorder._options.get("limit"):
@@ -35,10 +35,8 @@ async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
3535
async def fetch_pandas_batches(
3636
self, **kwargs
3737
) -> AsyncIterator[GenericPandasDataFrame]:
38-
batches = cast(
39-
AsyncIterator[GenericPandasDataFrame],
40-
self._cursor.fetch_pandas_batches(**kwargs), # type: ignore[assignment]
41-
)
38+
batches = self._sf_cursor.fetch_pandas_batches(**kwargs)
39+
4240
if isinstance(self._recorder, turu.core.record.CsvRecorder):
4341
if limit := self._recorder._options.get("limit"):
4442
async for batch in batches:
@@ -52,7 +50,7 @@ async def fetch_pandas_batches(
5250
yield batch
5351

5452
async def fetch_arrow_all(self) -> GenericPyArrowTable:
55-
table = cast(GenericPyArrowTable, await self._cursor.fetch_arrow_all()) # type: ignore[assignment]
53+
table = await self._sf_cursor.fetch_arrow_all()
5654

5755
if isinstance(self._recorder, turu.core.record.CsvRecorder):
5856
if limit := self._recorder._options.get("limit"):
@@ -67,10 +65,8 @@ async def fetch_arrow_all(self) -> GenericPyArrowTable:
6765
return table
6866

6967
async def fetch_arrow_batches(self) -> AsyncIterator[GenericPyArrowTable]:
70-
batches = cast(
71-
AsyncIterator[GenericPyArrowTable],
72-
self._cursor.fetch_arrow_batches(), # type: ignore[assignment]
73-
)
68+
batches = self._sf_cursor.fetch_arrow_batches()
69+
7470
if isinstance(self._recorder, turu.core.record.CsvRecorder):
7571
if limit := self._recorder._options.get("limit"):
7672
async for batch in batches:
@@ -82,3 +78,39 @@ async def fetch_arrow_batches(self) -> AsyncIterator[GenericPyArrowTable]:
8278

8379
async for batch in batches:
8480
yield batch
81+
82+
def use_warehouse(self, warehouse: str, /) -> "AsyncRecordCursor":
83+
"""Use a warehouse in cursor."""
84+
85+
self._sf_cursor.use_warehouse(warehouse)
86+
87+
return self
88+
89+
def use_database(self, database: str, /) -> "AsyncRecordCursor":
90+
"""Use a database in cursor."""
91+
92+
self._sf_cursor.use_database(database)
93+
94+
return self
95+
96+
def use_schema(self, schema: str, /) -> "AsyncRecordCursor":
97+
"""Use a schema in cursor."""
98+
99+
self._sf_cursor.use_schema(schema)
100+
101+
return self
102+
103+
def use_role(self, role: str, /) -> "AsyncRecordCursor":
104+
"""Use a role in cursor."""
105+
106+
self._sf_cursor.use_role(role)
107+
108+
return self
109+
110+
@property
111+
def _sf_cursor(
112+
self,
113+
) -> turu.snowflake.async_cursor.AsyncCursor[
114+
GenericRowType, GenericPandasDataFrame, GenericPyArrowTable
115+
]:
116+
return cast(turu.snowflake.async_cursor.AsyncCursor, self._cursor)

turu-snowflake/src/turu/snowflake/record/record_cursor.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
GenericPandasDataFrame,
88
GenericPyArrowTable,
99
)
10+
from typing_extensions import Self
1011

1112

1213
class RecordCursor( # type: ignore[override]
1314
turu.core.record.RecordCursor,
1415
Generic[GenericRowType, GenericPandasDataFrame, GenericPyArrowTable],
1516
):
1617
def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
17-
df = cast(GenericPandasDataFrame, self._cursor.fetch_pandas_all(**kwargs)) # type: ignore[assignment]
18+
df = self._sf_cursor.fetch_pandas_all(**kwargs)
1819

1920
if isinstance(self._recorder, turu.core.record.CsvRecorder):
2021
if limit := self._recorder._options.get("limit"):
@@ -29,10 +30,8 @@ def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
2930
return df
3031

3132
def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
32-
batches = cast(
33-
Iterator[GenericPandasDataFrame],
34-
self._cursor.fetch_pandas_batches(**kwargs), # type: ignore[assignment]
35-
)
33+
batches = self._sf_cursor.fetch_pandas_batches(**kwargs)
34+
3635
if isinstance(self._recorder, turu.core.record.CsvRecorder):
3736
if limit := self._recorder._options.get("limit"):
3837
for batch in batches:
@@ -45,7 +44,7 @@ def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
4544
return batches
4645

4746
def fetch_arrow_all(self) -> GenericPyArrowTable:
48-
table = cast(GenericPyArrowTable, self._cursor.fetch_arrow_all()) # type: ignore[assignment]
47+
table = self._sf_cursor.fetch_arrow_all()
4948

5049
if isinstance(self._recorder, turu.core.record.CsvRecorder):
5150
if limit := self._recorder._options.get("limit"):
@@ -60,10 +59,8 @@ def fetch_arrow_all(self) -> GenericPyArrowTable:
6059
return table
6160

6261
def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:
63-
batches = cast(
64-
Iterator[GenericPyArrowTable],
65-
self._cursor.fetch_arrow_batches(), # type: ignore[assignment]
66-
)
62+
batches = self._sf_cursor.fetch_arrow_batches()
63+
6764
if isinstance(self._recorder, turu.core.record.CsvRecorder):
6865
if limit := self._recorder._options.get("limit"):
6966
for batch in batches:
@@ -74,3 +71,39 @@ def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:
7471
return
7572

7673
return batches
74+
75+
def use_warehouse(self, warehouse: str, /) -> Self:
76+
"""Use a warehouse in cursor."""
77+
78+
self._sf_cursor.use_warehouse(warehouse)
79+
80+
return self
81+
82+
def use_database(self, database: str, /) -> Self:
83+
"""Use a database in cursor."""
84+
85+
self._sf_cursor.use_database(database)
86+
87+
return self
88+
89+
def use_schema(self, schema: str, /) -> Self:
90+
"""Use a schema in cursor."""
91+
92+
self._sf_cursor.use_schema(schema)
93+
94+
return self
95+
96+
def use_role(self, role: str, /) -> Self:
97+
"""Use a role in cursor."""
98+
99+
self._sf_cursor.use_role(role)
100+
101+
return self
102+
103+
@property
104+
def _sf_cursor(
105+
self,
106+
) -> turu.snowflake.cursor.Cursor[
107+
GenericRowType, GenericPandasDataFrame, GenericPyArrowTable
108+
]:
109+
return cast(turu.snowflake.cursor.Cursor, self._cursor)

0 commit comments

Comments
 (0)