Skip to content

Commit ece70f9

Browse files
[io] feat: switch to pandas_gbq to write data to BigQuery
1 parent a944948 commit ece70f9

File tree

3 files changed

+68
-178
lines changed

3 files changed

+68
-178
lines changed

libs/io/garf_io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414

1515
"""Writing GarfReport to anywhere."""
1616

17-
__version__ = '0.0.13'
17+
__version__ = '0.0.14'

libs/io/garf_io/writers/bigquery_writer.py

Lines changed: 29 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,20 @@
1616
from __future__ import annotations
1717

1818
import os
19+
from typing import Literal
1920

2021
try:
22+
import pandas as pd
23+
import pandas_gbq
2124
from google.cloud import bigquery
2225
except ImportError as e:
2326
raise ImportError(
2427
'Please install garf-io with BigQuery support - `pip install garf-io[bq]`'
2528
) from e
2629

27-
import datetime
2830
import logging
29-
from collections.abc import Sequence
3031

3132
import numpy as np
32-
import pandas as pd
33-
import proto # type: ignore
34-
from garf_core import parsers
3533
from garf_core import report as garf_report
3634
from google.cloud import exceptions as google_cloud_exceptions
3735

@@ -40,6 +38,13 @@
4038

4139
logger = logging.getLogger(__name__)
4240

41+
_WRITE_DISPOSITION_MAPPING = {
42+
'WRITE_TRUNCATE': 'replace',
43+
'WRITE_TRUNCATE_DATA': 'replace',
44+
'WRITE_APPEND': 'append',
45+
'WRITE_EMPTY': 'fail',
46+
}
47+
4348

4449
class BigQueryWriterError(exceptions.GarfIoError):
4550
"""BigQueryWriter specific errors."""
@@ -60,9 +65,8 @@ def __init__(
6065
project: str | None = os.getenv('GOOGLE_CLOUD_PROJECT'),
6166
dataset: str = 'garf',
6267
location: str = 'US',
63-
write_disposition: bigquery.WriteDisposition | str = (
64-
bigquery.WriteDisposition.WRITE_TRUNCATE
65-
),
68+
write_disposition: bigquery.WriteDisposition
69+
| Literal['append', 'replace', 'fail'] = 'replace',
6670
**kwargs,
6771
):
6872
"""Initializes BigQueryWriter.
@@ -83,11 +87,20 @@ def __init__(
8387
self.project = project
8488
self.dataset_id = f'{project}.{dataset}'
8589
self.location = location
86-
if isinstance(write_disposition, str):
87-
write_disposition = getattr(
88-
bigquery.WriteDisposition, write_disposition.upper()
90+
if write_disposition in ('replace', 'append', 'fail'):
91+
self.write_disposition = write_disposition
92+
elif isinstance(write_disposition, bigquery.WriteDisposition):
93+
self.write_disposition = _WRITE_DISPOSITION_MAPPING.get(
94+
write_disposition.name
95+
)
96+
elif _WRITE_DISPOSITION_MAPPING.get(write_disposition.upper()):
97+
self.write_disposition = _WRITE_DISPOSITION_MAPPING.get(
98+
write_disposition.upper()
99+
)
100+
else:
101+
raise BigQueryWriterError(
102+
'Unsupported writer disposition, choose one of: replace, append, fail'
89103
)
90-
self.write_disposition = write_disposition
91104

92105
def __str__(self) -> str:
93106
return f'[BigQuery] - {self.dataset_id} at {self.location} location.'
@@ -118,19 +131,9 @@ def write(self, report: garf_report.GarfReport, destination: str) -> str:
118131
Name of the table in `dataset.table` format.
119132
"""
120133
report = self.format_for_write(report)
121-
schema = _define_schema(report)
122134
destination = formatter.format_extension(destination)
123135
_ = self.create_or_get_dataset()
124-
table = self._create_or_get_table(
125-
f'{self.dataset_id}.{destination}', schema
126-
)
127-
job_config = bigquery.LoadJobConfig(
128-
write_disposition=self.write_disposition,
129-
schema=schema,
130-
source_format='CSV',
131-
max_bad_records=len(report),
132-
)
133-
136+
table = f'{self.dataset_id}.{destination}'
134137
if not report:
135138
df = pd.DataFrame(
136139
data=report.results_placeholder, columns=report.column_names
@@ -139,123 +142,8 @@ def write(self, report: garf_report.GarfReport, destination: str) -> str:
139142
df = report.to_pandas()
140143
df = df.replace({np.nan: None})
141144
logger.debug('Writing %d rows of data to %s', len(df), destination)
142-
job = self.client.load_table_from_dataframe(
143-
dataframe=df, destination=table, job_config=job_config
145+
pandas_gbq.to_gbq(
146+
dataframe=df, destination_table=table, if_exists=self.write_disposition
144147
)
145-
try:
146-
job.result()
147-
logger.debug('Writing to %s is completed', destination)
148-
except google_cloud_exceptions.BadRequest as e:
149-
raise ValueError(f'Unable to save data to BigQuery! {str(e)}') from e
148+
logger.debug('Writing to %s is completed', destination)
150149
return f'[BigQuery] - at {self.dataset_id}.{destination}'
151-
152-
def _create_or_get_table(
153-
self, table_name: str, schema: Sequence[bigquery.SchemaField]
154-
) -> bigquery.Table:
155-
"""Gets existing table or create a new one.
156-
157-
Args:
158-
table_name: Name of the table in BigQuery.
159-
schema: Schema of the table if one should be created.
160-
161-
Returns:
162-
BigQuery table object.
163-
"""
164-
try:
165-
table = self.client.get_table(table_name)
166-
except google_cloud_exceptions.NotFound:
167-
table_ref = bigquery.Table(table_name, schema=schema)
168-
table = self.client.create_table(table_ref)
169-
table = self.client.get_table(table_name)
170-
return table
171-
172-
173-
def _define_schema(
174-
report: garf_report.GarfReport,
175-
) -> list[bigquery.SchemaField]:
176-
"""Infers schema from GarfReport.
177-
178-
Args:
179-
report: GarfReport to infer schema from.
180-
181-
Returns:
182-
Schema fields for a given report.
183-
184-
"""
185-
result_types = _get_result_types(report)
186-
return _get_bq_schema(result_types)
187-
188-
189-
def _get_result_types(
190-
report: garf_report.GarfReport,
191-
) -> dict[str, dict[str, parsers.ApiRowElement]]:
192-
"""Maps each column of report to BigQuery field type and repeated status.
193-
194-
Fields types are inferred based on report results or results placeholder.
195-
196-
Args:
197-
report: GarfReport to infer field types from.
198-
199-
Returns:
200-
Mapping between each column of report and its field type.
201-
"""
202-
result_types: dict[str, dict[str, parsers.ApiRowElement]] = {}
203-
column_names = report.column_names
204-
for row in report.results or report.results_placeholder:
205-
if set(column_names) == set(result_types.keys()):
206-
break
207-
for i, field in enumerate(row):
208-
if field is None or column_names[i] in result_types:
209-
continue
210-
field_type = type(field)
211-
if field_type in [
212-
list,
213-
proto.marshal.collections.repeated.RepeatedComposite,
214-
proto.marshal.collections.repeated.Repeated,
215-
]:
216-
repeated = True
217-
field_type = str if len(field) == 0 else type(field[0])
218-
else:
219-
field_type = type(field)
220-
repeated = False
221-
result_types[column_names[i]] = {
222-
'field_type': field_type,
223-
'repeated': repeated,
224-
}
225-
return result_types
226-
227-
228-
def _get_bq_schema(
229-
types: dict[str, dict[str, parsers.ApiRowElement]],
230-
) -> list[bigquery.SchemaField]:
231-
"""Converts report fields types to BigQuery schema fields.
232-
233-
Args:
234-
types: Mapping between column names and its field type.
235-
236-
Returns:
237-
BigQuery schema fields corresponding to GarfReport.
238-
"""
239-
type_mapping = {
240-
list: 'REPEATED',
241-
str: 'STRING',
242-
datetime.datetime: 'DATETIME',
243-
datetime.date: 'DATE',
244-
int: 'INT64',
245-
float: 'FLOAT64',
246-
bool: 'BOOL',
247-
proto.marshal.collections.repeated.RepeatedComposite: 'REPEATED',
248-
proto.marshal.collections.repeated.Repeated: 'REPEATED',
249-
}
250-
251-
schema: list[bigquery.SchemaField] = []
252-
for key, value in types.items():
253-
field_type = type_mapping.get(value.get('field_type'))
254-
schema.append(
255-
bigquery.SchemaField(
256-
name=key,
257-
field_type=field_type if field_type else 'STRING',
258-
mode='REPEATED' if value.get('repeated') else 'NULLABLE',
259-
)
260-
)
261-
return schema

libs/io/tests/unit/writers/test_bigquery_writer.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,46 +13,48 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import os
17+
18+
import garf_core
1619
import pytest
1720
from garf_io.writers import bigquery_writer
1821
from google.cloud import bigquery
1922

2023

2124
class TestBigQueryWriter:
22-
@pytest.fixture
23-
def bq_writer(self):
24-
return bigquery_writer.BigQueryWriter(project='test', dataset='test')
25-
26-
def test_get_results_types_returns_correct_mapping(self, sample_data):
27-
result_types = bigquery_writer._get_result_types(sample_data)
28-
assert result_types == {
29-
'column_1': {'field_type': int, 'repeated': False},
30-
'column_2': {'field_type': str, 'repeated': False},
31-
'column_3': {'field_type': int, 'repeated': True},
32-
}
33-
34-
def test_define_schema_returns_correct_schema_fields(self, sample_data):
35-
schema = bigquery_writer._define_schema(sample_data)
36-
assert schema == [
37-
bigquery.SchemaField(
38-
'column_1', 'INT64', 'NULLABLE', None, None, (), None
39-
),
40-
bigquery.SchemaField(
41-
'column_2', 'STRING', 'NULLABLE', None, None, (), None
42-
),
43-
bigquery.SchemaField(
44-
'column_3', 'INT64', 'REPEATED', None, None, (), None
45-
),
46-
]
25+
@pytest.mark.skipif(
26+
not os.environ.get('GOOGLE_CLOUD_PROJECT'),
27+
reason='GOOGLE_CLOUD_PROJECT env variable not set.',
28+
)
29+
def test_write(self):
30+
writer = bigquery_writer.BigQueryWriter(array_handling='arrays')
31+
report = garf_core.GarfReport(
32+
results=[
33+
[{'key': ['one', 'two']}, 'three'],
34+
],
35+
column_names=['column1', 'column2'],
36+
)
37+
result = writer.write(report, 'test')
38+
assert result
4739

48-
def test_define_schema_correctly_handles_dates(self, sample_data_with_dates):
49-
schema = bigquery_writer._define_schema(sample_data_with_dates)
50-
assert schema == [
51-
bigquery.SchemaField(
52-
'column_1', 'INT64', 'NULLABLE', None, None, (), None
53-
),
54-
bigquery.SchemaField(
55-
'datetime', 'DATETIME', 'NULLABLE', None, None, (), None
56-
),
57-
bigquery.SchemaField('date', 'DATE', 'NULLABLE', None, None, (), None),
58-
]
40+
@pytest.mark.parametrize(
41+
('disposition', 'expected'),
42+
[
43+
('append', 'append'),
44+
('replace', 'replace'),
45+
('fail', 'fail'),
46+
('write_append', 'append'),
47+
('write_truncate', 'replace'),
48+
('write_truncate_data', 'replace'),
49+
('write_empty', 'fail'),
50+
(bigquery.WriteDisposition.WRITE_APPEND, 'append'),
51+
(bigquery.WriteDisposition.WRITE_TRUNCATE, 'replace'),
52+
(bigquery.WriteDisposition.WRITE_TRUNCATE_DATA, 'replace'),
53+
(bigquery.WriteDisposition.WRITE_EMPTY, 'fail'),
54+
],
55+
)
56+
def test_init_creates_correct_write_disposition(self, disposition, expected):
57+
writer = bigquery_writer.BigQueryWriter(
58+
project='test', write_disposition=disposition
59+
)
60+
assert writer.write_disposition == expected

0 commit comments

Comments
 (0)