Commit 7ca8d05

Support pandas in BigQuery cache
1 parent f7b88eb commit 7ca8d05

File tree

5 files changed: +59 −39 lines changed

- airbyte/caches/base.py (+19 −3)
- airbyte/caches/bigquery.py (+31 −13)
- pyproject.toml (+1)
- tests/integration_tests/cloud/test_cloud_sql_reads.py (+3 −14)
- tests/integration_tests/test_all_cache_types.py (+5 −9)

Diff for: airbyte/caches/base.py

+19-3
```diff
@@ -28,6 +28,8 @@
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
+    from sqlalchemy.engine import Engine
+
     from airbyte._message_iterators import AirbyteMessageIterator
     from airbyte.caches._state_backend_base import StateBackendBase
     from airbyte.progress import ProgressTracker
@@ -66,7 +68,9 @@ class CacheBase(SqlConfig, AirbyteWriterInterface):
     paired_destination_config_class: ClassVar[type | None] = None
 
     @property
-    def paired_destination_config(self) -> Any | dict[str, Any]:  # noqa: ANN401  # Allow Any return type
+    def paired_destination_config(
+        self,
+    ) -> Any | dict[str, Any]:  # noqa: ANN401  # Allow Any return type
         """Return a dictionary of destination configuration values."""
         raise NotImplementedError(
             f"The type '{type(self).__name__}' does not define an equivalent destination "
@@ -177,6 +181,14 @@ def get_record_processor(
 
     # Read methods:
 
+    def _read_to_pandas_dataframe(
+        self,
+        table_name: str,
+        con: Engine,
+        **kwargs,
+    ) -> pd.DataFrame:
+        return pd.read_sql_table(table_name, con=con, **kwargs)
+
     def get_records(
         self,
         stream_name: str,
@@ -191,7 +203,11 @@ def get_pandas_dataframe(
         """Return a Pandas data frame with the stream's data."""
         table_name = self._read_processor.get_sql_table_name(stream_name)
         engine = self.get_sql_engine()
-        return pd.read_sql_table(table_name, engine, schema=self.schema_name)
+        return self._read_to_pandas_dataframe(
+            table_name=table_name,
+            con=engine,
+            schema=self.schema_name,
+        )
 
     def get_arrow_dataset(
         self,
@@ -204,7 +220,7 @@ def get_arrow_dataset(
         engine = self.get_sql_engine()
 
         # Read the table in chunks to handle large tables which does not fits in memory
-        pandas_chunks = pd.read_sql_table(
+        pandas_chunks = self._read_to_pandas_dataframe(
             table_name=table_name,
             con=engine,
             schema=self.schema_name,
```
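
The base-class change turns the pandas read into an overridable hook: both get_pandas_dataframe() and get_arrow_dataset() now route through _read_to_pandas_dataframe(), which defaults to pd.read_sql_table(), so a backend that cannot use SQLAlchemy table reflection only has to override one method. A minimal sketch of such an override, assuming a hypothetical subclass (QueryBasedCache and its query-based read are illustrative, not part of this commit):

```python
import pandas as pd
from sqlalchemy.engine import Engine

from airbyte.caches.base import CacheBase


class QueryBasedCache(CacheBase):  # hypothetical subclass, for illustration only
    def _read_to_pandas_dataframe(
        self,
        table_name: str,
        con: Engine,
        **kwargs,
    ) -> pd.DataFrame:
        # Route the read through a plain SELECT instead of pd.read_sql_table();
        # the schema kwarg passed by the base class is folded into the table name.
        schema = kwargs.pop("schema", None)
        qualified_name = f"{schema}.{table_name}" if schema else table_name
        return pd.read_sql_query(f"SELECT * FROM {qualified_name}", con=con, **kwargs)
```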

Diff for: airbyte/caches/bigquery.py

+31-13
```diff
@@ -19,19 +19,23 @@
 
 from typing import TYPE_CHECKING, ClassVar, NoReturn
 
+import pandas as pd
+import pandas_gbq
 from airbyte_api.models import DestinationBigquery
+from google.oauth2.service_account import Credentials
 
 from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
 from airbyte.caches.base import (
     CacheBase,
 )
-from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
 from airbyte.destinations._translate_cache_to_dest import (
     bigquery_cache_to_destination_configuration,
 )
 
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from airbyte.shared.sql_processor import SqlProcessorBase
 
 
@@ -48,21 +52,35 @@ def paired_destination_config(self) -> DestinationBigquery:
         """Return a dictionary of destination configuration values."""
         return bigquery_cache_to_destination_configuration(cache=self)
 
-    def get_arrow_dataset(
+    def _read_to_pandas_dataframe(
         self,
-        stream_name: str,
-        *,
-        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
-    ) -> NoReturn:
-        """Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.
-
-        See: https://github.com/airbytehq/PyAirbyte/issues/165
-        """
-        raise NotImplementedError(
-            "BigQuery doesn't currently support to_arrow"
-            "Please consider using a different cache implementation for these functionalities."
+        table_name: str,
+        chunksize: int | None = None,
+        **kwargs,
+    ) -> pd.DataFrame | Iterator[pd.DataFrame]:
+        # Pop unused kwargs, maybe not the best way to do this
+        kwargs.pop("con", None)
+        kwargs.pop("schema", None)
+
+        # Read the table using pandas_gbq
+        credentials = Credentials.from_service_account_file(self.credentials_path)
+        result = pandas_gbq.read_gbq(
+            f"{self.project_name}.{self.dataset_name}.{table_name}",
+            project_id=self.project_name,
+            credentials=credentials,
+            **kwargs,
         )
 
+        # Cast result to DataFrame if it's not already a DataFrame
+        if not isinstance(result, pd.DataFrame):
+            result = pd.DataFrame(result)
+
+        # Return chunks as iterator if chunksize is provided
+        if chunksize is not None:
+            return (result[i : i + chunksize] for i in range(0, len(result), chunksize))
+
+        return result
+
 
 # Expose the Cache class and also the Config class.
 __all__ = [
```
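
pandas_gbq.read_gbq() always returns a single DataFrame, so the chunked reads used by get_arrow_dataset() are emulated in the override above by slicing that DataFrame into chunksize-row pieces. A small self-contained sketch of that slicing (the toy DataFrame is illustrative):

```python
import pandas as pd

# Stand-in for the single DataFrame returned by pandas_gbq.read_gbq().
df = pd.DataFrame({"id": range(25)})
chunksize = 10

# Same generator expression as in _read_to_pandas_dataframe() above.
chunks = (df[i : i + chunksize] for i in range(0, len(df), chunksize))

sizes = [len(chunk) for chunk in chunks]
assert sizes == [10, 10, 5]  # full chunks first, remainder in the last chunk
```

Unlike pd.read_sql_table(chunksize=...), this still materializes the full table in memory before slicing, so very large BigQuery tables remain memory-bound.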

Diff for: pyproject.toml

+1
```diff
@@ -29,6 +29,7 @@ jsonschema = ">=3.2.0,<5.0"
 orjson = "^3.10"
 overrides = "^7.4.0"
 pandas = { version = ">=1.5.3,<3.0" }
+pandas-gbq = ">=0.26.1"
 pendulum = "<=3.0.0"
 psycopg = {extras = ["binary", "pool"], version = "^3.1.19"}
 psycopg2-binary = "^2.9.9"
```

Diff for: tests/integration_tests/cloud/test_cloud_sql_reads.py

+3-14
```diff
@@ -68,14 +68,8 @@ def test_read_from_deployed_connection(
 
     dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
     assert dataset.stream_name == "users"
-    data_as_list = list(dataset)
-    assert len(data_as_list) == 100
-
-    # TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
-    # pandas_df = dataset.to_pandas()
-
-    pandas_df = pd.DataFrame(data_as_list)
 
+    pandas_df = dataset.to_pandas()
     assert pandas_df.shape == (100, 20)
 
     # Check that no values are null
@@ -177,15 +171,10 @@ def test_read_from_previous_job(
     assert "users" in sync_result.stream_names
     dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
     assert dataset.stream_name == "users"
-    data_as_list = list(dataset)
-    assert len(data_as_list) == 100
-
-    # TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
-    # pandas_df = dataset.to_pandas()
-
-    pandas_df = pd.DataFrame(data_as_list)
 
+    pandas_df = dataset.to_pandas()
     assert pandas_df.shape == (100, 20)
+
     for col in pandas_df.columns:
         # Check that no values are null
         assert pandas_df[col].notnull().all()
```

Diff for: tests/integration_tests/test_all_cache_types.py

+5-9
```diff
@@ -158,15 +158,11 @@ def test_faker_read(
     assert "Read **0** records" not in status_msg
     assert f"Read **{configured_count}** records" in status_msg
 
-    if "bigquery" not in new_generic_cache.get_sql_alchemy_url():
-        # BigQuery doesn't support to_arrow
-        # https://github.com/airbytehq/PyAirbyte/issues/165
-        arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
-        assert arrow_dataset.count_rows() == FAKER_SCALE_A
-        assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10
-
-    # TODO: Uncomment this line after resolving https://github.com/airbytehq/PyAirbyte/issues/165
-    # assert len(result["users"].to_pandas()) == FAKER_SCALE_A
+    arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
+    assert arrow_dataset.count_rows() == FAKER_SCALE_A
+    assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10
+
+    assert len(read_result["users"].to_pandas()) == FAKER_SCALE_A
 
 
 @pytest.mark.requires_creds
```
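
Taken together, the tests now exercise to_pandas() and to_arrow() against every cache type, BigQuery included. A rough end-user sketch, assuming a GCP project, dataset, and service-account key file (the config values below are placeholders, not from this commit):

```python
import airbyte as ab
from airbyte.caches.bigquery import BigQueryCache

cache = BigQueryCache(
    project_name="my-gcp-project",             # placeholder GCP project
    dataset_name="pyairbyte_cache",            # placeholder BigQuery dataset
    credentials_path="/path/to/sa-key.json",   # placeholder service-account key
)

source = ab.get_source("source-faker", config={"count": 100}, install_if_missing=True)
source.select_streams(["users"])
read_result = source.read(cache=cache)

# Served through pandas_gbq under the hood after this change.
users_df = read_result["users"].to_pandas()
print(users_df.shape)

# The arrow path works as well; chunking is emulated by slicing the DataFrame.
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
print(arrow_dataset.count_rows())
```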
