 from __future__ import annotations
 
 import os
+from typing import Literal
 
 try:
+  import pandas as pd
+  import pandas_gbq
   from google.cloud import bigquery
 except ImportError as e:
   raise ImportError(
     'Please install garf-io with BigQuery support - `pip install garf-io[bq]`'
   ) from e
 
-import datetime
 import logging
-from collections.abc import Sequence
 
 import numpy as np
-import pandas as pd
-import proto  # type: ignore
-from garf_core import parsers
 from garf_core import report as garf_report
 from google.cloud import exceptions as google_cloud_exceptions
 
 from garf_io import exceptions
 from garf_io import formatter
 
 logger = logging.getLogger(__name__)
 
+_WRITE_DISPOSITION_MAPPING = {
+  'WRITE_TRUNCATE': 'replace',
+  'WRITE_TRUNCATE_DATA': 'replace',
+  'WRITE_APPEND': 'append',
+  'WRITE_EMPTY': 'fail',
+}
+
 
 class BigQueryWriterError(exceptions.GarfIoError):
   """BigQueryWriter specific errors."""
@@ -60,9 +65,8 @@ def __init__(
     project: str | None = os.getenv('GOOGLE_CLOUD_PROJECT'),
     dataset: str = 'garf',
     location: str = 'US',
-    write_disposition: bigquery.WriteDisposition | str = (
-      bigquery.WriteDisposition.WRITE_TRUNCATE
-    ),
+    write_disposition: bigquery.WriteDisposition
+    | Literal['append', 'replace', 'fail'] = 'replace',
     **kwargs,
   ):
     """Initializes BigQueryWriter.
@@ -83,11 +87,18 @@ def __init__(
     self.project = project
     self.dataset_id = f'{project}.{dataset}'
     self.location = location
-    if isinstance(write_disposition, str):
-      write_disposition = getattr(
-        bigquery.WriteDisposition, write_disposition.upper()
-      )
-    self.write_disposition = write_disposition
+    if write_disposition in ('replace', 'append', 'fail'):
+      self.write_disposition = write_disposition
+    # bigquery.WriteDisposition constants are plain strings, so normalizing
+    # with upper() covers both them and free-form input such as 'Append'.
+    elif disposition := _WRITE_DISPOSITION_MAPPING.get(
+      str(write_disposition).upper()
+    ):
+      self.write_disposition = disposition
+    else:
+      raise BigQueryWriterError(
+        'Unsupported write disposition, choose one of: replace, append, fail'
+      )
 
   def __str__(self) -> str:
     return f'[BigQuery] - {self.dataset_id} at {self.location} location.'
@@ -118,19 +129,9 @@ def write(self, report: garf_report.GarfReport, destination: str) -> str:
       Name of the table in `dataset.table` format.
     """
     report = self.format_for_write(report)
-    schema = _define_schema(report)
     destination = formatter.format_extension(destination)
     _ = self.create_or_get_dataset()
-    table = self._create_or_get_table(
-      f'{self.dataset_id}.{destination}', schema
-    )
-    job_config = bigquery.LoadJobConfig(
-      write_disposition=self.write_disposition,
-      schema=schema,
-      source_format='CSV',
-      max_bad_records=len(report),
-    )
-
+    table = f'{self.dataset_id}.{destination}'
     if not report:
       df = pd.DataFrame(
         data=report.results_placeholder, columns=report.column_names
@@ -139,123 +140,8 @@ def write(self, report: garf_report.GarfReport, destination: str) -> str:
       df = report.to_pandas()
     df = df.replace({np.nan: None})
     logger.debug('Writing %d rows of data to %s', len(df), destination)
-    job = self.client.load_table_from_dataframe(
-      dataframe=df, destination=table, job_config=job_config
+    pandas_gbq.to_gbq(
+      dataframe=df, destination_table=table, if_exists=self.write_disposition
     )
-    try:
-      job.result()
-      logger.debug('Writing to %s is completed', destination)
-    except google_cloud_exceptions.BadRequest as e:
-      raise ValueError(f'Unable to save data to BigQuery! {str(e)}') from e
+    logger.debug('Writing to %s is completed', destination)
     return f'[BigQuery] - at {self.dataset_id}.{destination}'
-
-  def _create_or_get_table(
-    self, table_name: str, schema: Sequence[bigquery.SchemaField]
-  ) -> bigquery.Table:
-    """Gets existing table or create a new one.
-
-    Args:
-      table_name: Name of the table in BigQuery.
-      schema: Schema of the table if one should be created.
-
-    Returns:
-      BigQuery table object.
-    """
-    try:
-      table = self.client.get_table(table_name)
-    except google_cloud_exceptions.NotFound:
-      table_ref = bigquery.Table(table_name, schema=schema)
-      table = self.client.create_table(table_ref)
-      table = self.client.get_table(table_name)
-    return table
-
-
-def _define_schema(
-  report: garf_report.GarfReport,
-) -> list[bigquery.SchemaField]:
-  """Infers schema from GarfReport.
-
-  Args:
-    report: GarfReport to infer schema from.
-
-  Returns:
-    Schema fields for a given report.
-
-  """
-  result_types = _get_result_types(report)
-  return _get_bq_schema(result_types)
-
-
-def _get_result_types(
-  report: garf_report.GarfReport,
-) -> dict[str, dict[str, parsers.ApiRowElement]]:
-  """Maps each column of report to BigQuery field type and repeated status.
-
-  Fields types are inferred based on report results or results placeholder.
-
-  Args:
-    report: GarfReport to infer field types from.
-
-  Returns:
-    Mapping between each column of report and its field type.
-  """
-  result_types: dict[str, dict[str, parsers.ApiRowElement]] = {}
-  column_names = report.column_names
-  for row in report.results or report.results_placeholder:
-    if set(column_names) == set(result_types.keys()):
-      break
-    for i, field in enumerate(row):
-      if field is None or column_names[i] in result_types:
-        continue
-      field_type = type(field)
-      if field_type in [
-        list,
-        proto.marshal.collections.repeated.RepeatedComposite,
-        proto.marshal.collections.repeated.Repeated,
-      ]:
-        repeated = True
-        field_type = str if len(field) == 0 else type(field[0])
-      else:
-        field_type = type(field)
-        repeated = False
-      result_types[column_names[i]] = {
-        'field_type': field_type,
-        'repeated': repeated,
-      }
-  return result_types
-
-
-def _get_bq_schema(
-  types: dict[str, dict[str, parsers.ApiRowElement]],
-) -> list[bigquery.SchemaField]:
-  """Converts report fields types to BigQuery schema fields.
-
-  Args:
-    types: Mapping between column names and its field type.
-
-  Returns:
-    BigQuery schema fields corresponding to GarfReport.
-  """
-  type_mapping = {
-    list: 'REPEATED',
-    str: 'STRING',
-    datetime.datetime: 'DATETIME',
-    datetime.date: 'DATE',
-    int: 'INT64',
-    float: 'FLOAT64',
-    bool: 'BOOL',
-    proto.marshal.collections.repeated.RepeatedComposite: 'REPEATED',
-    proto.marshal.collections.repeated.Repeated: 'REPEATED',
-  }
-
-  schema: list[bigquery.SchemaField] = []
-  for key, value in types.items():
-    field_type = type_mapping.get(value.get('field_type'))
-    schema.append(
-      bigquery.SchemaField(
-        name=key,
-        field_type=field_type if field_type else 'STRING',
-        mode='REPEATED' if value.get('repeated') else 'NULLABLE',
-      )
-    )
-  return schema
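
In short, the commit drops the hand-rolled schema inference and load-job plumbing in favor of pandas_gbq, which infers the BigQuery schema from the DataFrame itself. Below is a minimal sketch of how the new disposition handling behaves; the normalize() helper is hypothetical and exists only to illustrate the __init__ logic, assuming the _WRITE_DISPOSITION_MAPPING dict from this commit.

# Hypothetical illustration of the disposition handling; not part of the commit.
_WRITE_DISPOSITION_MAPPING = {
  'WRITE_TRUNCATE': 'replace',
  'WRITE_TRUNCATE_DATA': 'replace',
  'WRITE_APPEND': 'append',
  'WRITE_EMPTY': 'fail',
}


def normalize(write_disposition) -> str:
  # pandas_gbq's native if_exists values pass through unchanged.
  if write_disposition in ('replace', 'append', 'fail'):
    return write_disposition
  # bigquery.WriteDisposition constants are plain strings ('WRITE_TRUNCATE',
  # ...), so upper() covers both them and free-form user input.
  if mapped := _WRITE_DISPOSITION_MAPPING.get(str(write_disposition).upper()):
    return mapped
  raise ValueError('Unsupported write disposition')


assert normalize('replace') == 'replace'
assert normalize('WRITE_APPEND') == 'append'  # legacy BigQuery constant
assert normalize('write_empty') == 'fail'     # case-insensitive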
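
And a usage sketch of the rewritten writer, assuming Application Default Credentials are configured for pandas_gbq and that report is a garf_core GarfReport produced elsewhere; the import path is an assumption, not something this diff confirms.

from garf_io.writers import bigquery_writer  # assumed module path

writer = bigquery_writer.BigQueryWriter(
  project='my-gcp-project',    # placeholder project id
  dataset='garf',
  write_disposition='append',  # 'replace' (default), 'append', or 'fail'
)
# create_or_get_dataset() ensures the dataset exists, then pandas_gbq.to_gbq
# loads the report's DataFrame, creating the table and schema as needed.
writer.write(report, 'campaign_performance')  # -> garf.campaign_performance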