-
Notifications
You must be signed in to change notification settings - Fork 158
Expand file tree
/
Copy pathclients.py
More file actions
438 lines (366 loc) · 14.5 KB
/
clients.py
File metadata and controls
438 lines (366 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import copy
import logging
from typing import TYPE_CHECKING, Optional
import warnings
import google.oauth2.service_account
from google.cloud import bigquery
from google.api_core import client_options
import ibis
import pandas
from data_validation import client_info, consts, exceptions
from data_validation.secret_manager import SecretManagerBuilder
from third_party.ibis.ibis_bigquery.api import bigquery_connect
from third_party.ibis.ibis_cloud_spanner.api import spanner_connect
from third_party.ibis.ibis_impala.api import impala_connect
from third_party.ibis.ibis_mssql.api import mssql_connect
from third_party.ibis.ibis_redshift.api import redshift_connect
if TYPE_CHECKING:
import ibis.expr.schema as sch
import ibis.expr.types as ir
# Remove Ibis's default row limit so validations always operate on full result sets.
ibis.options.sql.default_limit = None
# Filter Ibis MySQL error when loading client.table()
warnings.filterwarnings(
    "ignore",
    "`BaseBackend.database` is deprecated; use equivalent methods in the backend",
)
# Backends that connect through SQLAlchemy; used to choose the correct
# table/schema access pattern elsewhere in this module.
IBIS_ALCHEMY_BACKENDS = [
    "mysql",
    "oracle",
    "postgres",
    "db2",
    "db2_zos",
    "mssql",
    "redshift",
    "snowflake",
    "sybase",
]
def _raise_missing_client_error(msg):
def get_client_call(*args, **kwargs):
raise Exception(msg)
return get_client_call
# Optional backends: each connector below depends on a driver package that may
# not be installed. On ImportError the connect function is replaced with a stub
# that raises an actionable "pip install ..." message when actually called.
# Teradata requires teradatasql and licensing
try:
    from third_party.ibis.ibis_teradata.api import teradata_connect
except ImportError:
    msg = "pip install teradatasql (requires Teradata licensing)"
    teradata_connect = _raise_missing_client_error(msg)
# Oracle requires python-oracldb driver
try:
    from third_party.ibis.ibis_oracle.api import oracle_connect
except ImportError:
    oracle_connect = _raise_missing_client_error("pip install oracledb")
# Snowflake requires snowflake-connector-python and snowflake-sqlalchemy
try:
    from third_party.ibis.ibis_snowflake.api import snowflake_connect
except ImportError:
    snowflake_connect = _raise_missing_client_error(
        "pip install snowflake-connector-python && pip install snowflake-sqlalchemy"
    )
# DB2 requires ibm_db_sa
try:
    from third_party.ibis.ibis_db2.api import db2_connect
    from third_party.ibis.ibis_db2_zos.api import db2_zos_connect
except ImportError:
    db2_connect = _raise_missing_client_error("pip install ibm_db_sa")
    db2_zos_connect = _raise_missing_client_error("pip install ibm_db_sa")
# Sybase requires sqlalchemy_sybase package.
try:
    from third_party.ibis.ibis_sybase.api import sybase_connect
except ImportError:
    sybase_connect = _raise_missing_client_error("pip install sqlalchemy_sybase")
def get_google_bigquery_client(
    project_id: str,
    credentials=None,
    api_endpoint: Optional[str] = None,
    quota_project_id: Optional[str] = None,
):
    """Return a google.cloud.bigquery.Client configured for this tool.

    project_id: Project used for jobs unless quota_project_id is set.
    credentials: Optional explicit credentials object; default credentials
        are used when None.
    api_endpoint: Optional custom BigQuery API endpoint.
    quota_project_id: Optional project to attribute quota to; when set it
        also becomes the client's project.
    """
    info = client_info.get_http_client_info()
    # Force UTC so time-zone dependent SQL behaves deterministically.
    job_config = bigquery.QueryJobConfig(
        connection_properties=[bigquery.ConnectionProperty("time_zone", "UTC")]
    )
    effective_project = quota_project_id or project_id
    options = None
    if api_endpoint or quota_project_id:
        # quota_project_id is already None when unset; no conditional needed.
        options = client_options.ClientOptions(
            api_endpoint=api_endpoint,
            quota_project_id=quota_project_id,
        )
    return bigquery.Client(
        project=effective_project,
        client_info=info,
        credentials=credentials,
        default_query_job_config=job_config,
        client_options=options,
    )
def _get_google_bqstorage_client(
    credentials=None,
    api_endpoint: Optional[str] = None,
    quota_project_id: Optional[str] = None,
):
    """Return a BigQuery Storage read client.

    credentials/api_endpoint/quota_project_id mirror the parameters of
    get_google_bigquery_client.
    """
    options = None
    if api_endpoint or quota_project_id:
        # quota_project_id is already None when unset; pass it straight through.
        options = client_options.ClientOptions(
            api_endpoint=api_endpoint,
            quota_project_id=quota_project_id,
        )
    # Imported lazily: the storage API is only needed when this helper is used.
    from google.cloud import bigquery_storage_v1 as bigquery_storage

    return bigquery_storage.BigQueryReadClient(
        credentials=credentials,
        client_options=options,
    )
def get_bigquery_client(
    project_id: str,
    dataset_id: str = "",
    credentials=None,
    api_endpoint: Optional[str] = None,
    storage_api_endpoint: Optional[str] = None,
    client_project_id: Optional[str] = None,
):
    """Return an Ibis BigQuery backend wired up with pre-built Google clients."""
    google_client = get_google_bigquery_client(
        project_id,
        credentials=credentials,
        api_endpoint=api_endpoint,
        quota_project_id=client_project_id,
    )
    # Only build a storage client when a custom storage endpoint was requested.
    bqstorage_client = (
        _get_google_bqstorage_client(
            credentials=credentials,
            api_endpoint=storage_api_endpoint,
            quota_project_id=client_project_id,
        )
        if storage_api_endpoint
        else None
    )
    return bigquery_connect(
        project_id=project_id or client_project_id,
        dataset_id=dataset_id,
        credentials=credentials,
        bigquery_client=google_client,
        bqstorage_client=bqstorage_client,
    )
def get_pandas_client(table_name, file_path, file_type):
    """Return pandas client and env with file loaded into DataFrame

    table_name (str): Table name to use as reference for file data
    file_path (str): The local, s3, or GCS file path to the data
    file_type (str): The file type of the file (csv, json, orc or parquet)
    """
    # Dispatch table instead of an if/elif chain; keys are supported file types.
    readers = {
        "csv": pandas.read_csv,
        "json": pandas.read_json,
        "orc": pandas.read_orc,
        "parquet": pandas.read_parquet,
    }
    if file_type not in readers:
        raise ValueError(f"Unknown Pandas File Type: {file_type}")
    df = readers[file_type](file_path)
    return ibis.pandas.connect({table_name: df})
def is_sqlalchemy_backend(client):
    """Return True if client is an Ibis SQLAlchemy-based backend.

    Any failure to read client.name is treated as "not a SQLAlchemy backend".
    """
    try:
        # `in` already yields a bool; the previous bool() wrapper was redundant.
        return client.name in IBIS_ALCHEMY_BACKENDS
    except Exception:
        return False
def is_oracle_client(client):
    """Return True if client is the Oracle backend.

    When no Oracle backend has been installed OracleBackend is not a class,
    which can surface here as a TypeError; treat that as "not Oracle".
    """
    try:
        return "oracle" == client.name
    except TypeError:
        return False
def get_ibis_table(client, schema_name, table_name, database_name=None):
    """Return Ibis Table for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object
    table_name (str): Table name of table object
    database_name (str): Database name (generally default is used)
    """
    # These engines take database and schema as separate keyword arguments.
    schema_arg_backends = (
        "oracle",
        "postgres",
        "db2",
        "db2_zos",
        "mssql",
        "redshift",
        "sybase",
    )
    backend = client.name
    if backend in schema_arg_backends:
        return client.table(table_name, database=database_name, schema=schema_name)
    if backend == "pandas":
        return client.table(table_name, schema=schema_name)
    # Remaining backends treat the schema name as the database.
    return client.table(table_name, database=schema_name)
def get_ibis_query(client, query) -> "ir.Table":
    """Return Ibis Table from query expression for Supplied Client."""
    query_expr = client.sql(query)
    # Normalise all columns in the query to lower case.
    # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/992
    rename_map = {name: name.lower() for name in query_expr.columns}
    return query_expr.relabel(rename_map)
def get_ibis_table_schema(client, schema_name: str, table_name: str) -> "sch.Schema":
    """Return Ibis Table Schema for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object (SQLAlchemy backends only)
    table_name (str): Table name of table object
    """
    if not is_sqlalchemy_backend(client):
        return client.get_schema(table_name, schema_name)
    return client.table(table_name, schema=schema_name).schema()
def get_ibis_query_schema(client, query_str) -> "sch.Schema":
    """Return the schema of the result set produced by query_str."""
    if is_sqlalchemy_backend(client):
        return get_ibis_query(client, query_str).schema()
    # NJ: I'm not happy about calling a private method but don't see how I can avoid it.
    # Ibis does not expose a public method like it does for get_schema().
    return client._get_schema_using_query(query_str)
def list_schemas(client):
    """Return a list of schemas in the DB, or [None] when unsupported."""
    lister = getattr(client, "list_databases", None)
    if lister is None:
        # Backend has no concept of listing schemas/databases.
        return [None]
    try:
        return lister()
    except NotImplementedError:
        return [None]
def list_tables(client, schema_name, tables_only=True):
    """Return a list of tables in the DB schema.

    When tables_only is True and the backend is not pandas, the project's
    dvt_list_tables helper is used instead of the stock list_tables.
    """
    if tables_only and client.name != "pandas":
        lister = client.dvt_list_tables
    else:
        lister = client.list_tables
    # These backends are listed without a database/schema argument.
    if client.name in ("redshift", "snowflake", "pandas"):
        return lister()
    return lister(database=schema_name)
def get_all_tables(client, allowed_schemas=None, tables_only=True):
    """Return a list of tuples with database and table names.

    client (IbisClient): Client to use for tables
    allowed_schemas (List[str]): List of schemas to pull.
    """
    results = []
    for schema_name in list_schemas(client):
        if allowed_schemas and schema_name not in allowed_schemas:
            continue
        try:
            tables = list_tables(client, schema_name, tables_only=tables_only)
        except Exception as e:
            # Best effort: a schema we cannot list should not abort the scan.
            logging.warning(f"List Tables Error: {schema_name} -> {e}")
            continue
        results.extend((schema_name, table_name) for table_name in tables)
    return results
def get_data_client(connection_config):
    """Return DataClient client from given configuration.

    connection_config (dict): Connection config containing at minimum a
        consts.SOURCE_TYPE entry, plus backend-specific keys; optionally
        Secret Manager settings and a service account key path. The dict is
        deep-copied so the caller's config is never mutated.

    Raises:
        Exception: when the source type has no entry in CLIENT_LOOKUP.
        exceptions.DataClientConnectionFailure: when the connect call fails.
    """
    connection_config = copy.deepcopy(connection_config)
    source_type = connection_config.pop(consts.SOURCE_TYPE)
    secret_manager_type = connection_config.pop(consts.SECRET_MANAGER_TYPE, None)
    secret_manager_project_id = connection_config.pop(
        consts.SECRET_MANAGER_PROJECT_ID, None
    )
    decrypted_connection_config = {}
    # Resolve any secret references in the config values before connecting.
    if secret_manager_type is not None:
        sm = SecretManagerBuilder().build(secret_manager_type.lower())
        for config_item in connection_config:
            decrypted_connection_config[config_item] = sm.maybe_secret(
                secret_manager_project_id, connection_config[config_item]
            )
    else:
        decrypted_connection_config = connection_config
    # The ibis_bigquery.connect expects a credentials object, not a string.
    if consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH in decrypted_connection_config:
        key_path = decrypted_connection_config.pop(
            consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH
        )
        if key_path:
            decrypted_connection_config["credentials"] = (
                google.oauth2.service_account.Credentials.from_service_account_file(
                    key_path
                )
            )
    if source_type not in CLIENT_LOOKUP:
        msg = 'ConfigurationError: Source type "{source_type}" is not supported'.format(
            source_type=source_type
        )
        raise Exception(msg)
    try:
        data_client = CLIENT_LOOKUP[source_type](**decrypted_connection_config)
        # Tag the client so downstream code can identify its origin.
        data_client._source_type = source_type
    except Exception as e:
        msg = 'Connection Type "{source_type}" could not connect: {error}'.format(
            source_type=source_type, error=str(e)
        )
        raise exceptions.DataClientConnectionFailure(msg)
    return data_client
@contextmanager
def get_data_client_ctx(*args, **kwargs):
    """Provide get_data_client() via a context manager."""
    client = None
    try:
        client = get_data_client(*args, **kwargs)
        yield client
    finally:
        # TODO When we upgrade Ibis beyond 5.x this try/except may become redundant.
        # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/1376
        close = getattr(client, "close", None)
        if close is not None:
            try:
                close()
            except Exception as exc:
                # No need to reraise, we can silently fail if exiting throws up an issue.
                logging.warning("Exception closing connection: %s", str(exc))
def get_max_column_length(client):
    """Return the max column length supported by client.

    client (IbisClient): Client to use for tables
    """
    if is_oracle_client(client):
        # We can't reliably know which Version class client.version is stored in
        # because it is out of our control. Therefore using string identification
        # of Oracle <= 12.1 to avoid exceptions of this nature:
        #   TypeError: '<' not supported between instances of 'Version' and 'Version'
        version_str = str(client.version)
        if version_str.startswith(("10", "11")) or version_str.startswith("12.1"):
            return 30
    return 128
def get_max_in_list_size(client, in_list_over_expressions=False):
    """Return the backend's cap on IN-list size, or None when unconstrained."""
    if client.name == "snowflake":
        # This is a workaround for Snowflake limitation:
        #   SQL compilation error: In-list contains more than 50 non-constant values
        # in_list_over_expressions signals that the list contents are cast
        # expressions rather than simple literals.
        return 50 if in_list_over_expressions else 16000
    if is_oracle_client(client):
        # This is a workaround for Oracle limitation:
        #   ORA-01795: maximum number of expressions in a list is 1000
        return 1000
    return None
# Maps each consts.SOURCE_TYPE_* value to the callable that builds its client.
# get_data_client() invokes these with the (decrypted) connection config kwargs.
CLIENT_LOOKUP = {
    consts.SOURCE_TYPE_BIGQUERY: get_bigquery_client,
    consts.SOURCE_TYPE_IMPALA: impala_connect,
    consts.SOURCE_TYPE_MYSQL: ibis.mysql.connect,
    consts.SOURCE_TYPE_ORACLE: oracle_connect,
    consts.SOURCE_TYPE_FILESYSTEM: get_pandas_client,
    consts.SOURCE_TYPE_POSTGRES: ibis.postgres.connect,
    consts.SOURCE_TYPE_REDSHIFT: redshift_connect,
    consts.SOURCE_TYPE_TERADATA: teradata_connect,
    consts.SOURCE_TYPE_MSSQL: mssql_connect,
    consts.SOURCE_TYPE_SNOWFLAKE: snowflake_connect,
    consts.SOURCE_TYPE_SPANNER: spanner_connect,
    consts.SOURCE_TYPE_SYBASE: sybase_connect,
    consts.SOURCE_TYPE_DB2: db2_connect,
    consts.SOURCE_TYPE_DB2_ZOS: db2_zos_connect,
}