
Commit 1ab71c2

add fwd_creds flag to send credentials

1 parent 4673b82 commit 1ab71c2

2 files changed: 46 additions, 21 deletions

dask_bigquery/core.py

Lines changed: 29 additions & 16 deletions
@@ -21,7 +21,7 @@
 @contextmanager
-def bigquery_clients(project_id, credentials):
+def bigquery_clients(project_id, credentials=None):
     """This context manager is a temporary solution until there is an
     upstream solution to handle this.
     See googleapis/google-cloud-python#9457
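
Defaulting credentials to None means workers that receive no forwarded credentials let the google-cloud clients resolve Application Default Credentials on their own. A minimal sketch of that fallback, assuming the context manager constructs the two clients roughly like this (the helper name and body are illustrative, not the file's actual code):

from contextlib import contextmanager

from google.cloud import bigquery, bigquery_storage


@contextmanager
def bigquery_clients_sketch(project_id, credentials=None):
    # credentials=None makes both clients fall back to Application Default
    # Credentials, e.g. GOOGLE_APPLICATION_CREDENTIALS set on the worker.
    bq_client = bigquery.Client(project_id, credentials=credentials)
    try:
        bqs_client = bigquery_storage.BigQueryReadClient(credentials=credentials)
        yield bq_client, bqs_client
    finally:
        bq_client.close()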
@@ -73,17 +73,20 @@ def bigquery_read(
         Name of the BigQuery project.
     read_kwargs: dict
         kwargs to pass to read_rows()
-    creds: dict
-        credentials dictionary
     stream_name: str
         BigQuery Storage API Stream "name"
         NOTE: Please set if reading from Storage API without any `row_restriction`.
               https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
+    cred_token: str
+        google_auth bearer token
     """
 
-    credentials = google.oauth2.credentials.Credentials(cred_token)
+    if cred_token:
+        credentials = google.oauth2.credentials.Credentials(cred_token)
+    else:
+        credentials = None
 
-    with bigquery_clients(project_id, credentials) as (_, bqs_client):
+    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
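
Rather than shipping a pickled credentials object, the commit forwards a bare bearer-token string and rebuilds google.oauth2.credentials.Credentials from it on the worker. A round-trip sketch under that reading (the key-file path is hypothetical):

import google.auth.transport.requests
import google.oauth2.credentials
from google.oauth2 import service_account

# Client side: mint a short-lived bearer token from a service-account file.
source = service_account.Credentials.from_service_account_file(
    "key.json",  # hypothetical path
    scopes=["https://www.googleapis.com/auth/bigquery.readonly"],
)
source.refresh(google.auth.transport.requests.Request())
cred_token = source.token  # a plain string, cheap to send to workers

# Worker side: rebuild Credentials from the forwarded token. There is no
# refresh token here, so reads must finish before the token expires.
credentials = google.oauth2.credentials.Credentials(cred_token)

Forwarding only the token also keeps the service account's private key out of the task graph.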
@@ -103,6 +106,7 @@ def read_gbq(
     row_filter: str = "",
     columns: list[str] = None,
     read_kwargs: dict = None,
+    fwd_creds: bool = False,
 ):
     """Read table as dask dataframe using BigQuery Storage API via Arrow format.
     Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -121,26 +125,35 @@ def read_gbq(
         list of columns to load from the table
     read_kwargs: dict
         kwargs to pass to read_rows()
+    fwd_creds: bool
+        Set to True to forward credentials to the workers. Defaults to False.
 
     Returns
     -------
     Dask DataFrame
     """
     read_kwargs = read_kwargs or {}
 
-    creds_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
-    if creds_path is None:
-        raise ValueError("No credentials found")
+    if fwd_creds:
+        creds_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
+        if creds_path is None:
+            raise ValueError("No credentials found")
 
-    credentials = service_account.Credentials.from_service_account_file(
-        creds_path, scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
-    )
-
-    auth_req = google.auth.transport.requests.Request()
-    credentials.refresh(auth_req)
-    cred_token = credentials.token
+        credentials = service_account.Credentials.from_service_account_file(
+            creds_path, scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
+        )
 
-    with bigquery_clients(project_id, credentials) as (bq_client, bqs_client):
+        auth_req = google.auth.transport.requests.Request()
+        credentials.refresh(auth_req)
+        cred_token = credentials.token
+    else:
+        credentials = None
+        cred_token = None
+
+    with bigquery_clients(project_id, credentials=credentials) as (
+        bq_client,
+        bqs_client,
+    ):
         table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
         if table_ref.table_type == "VIEW":
             raise TypeError("Table type VIEW not supported")

dask_bigquery/tests/test_core.py

Lines changed: 17 additions & 5 deletions
@@ -51,43 +51,54 @@ def dataset(df):
     )
 
 
-def test_read_gbq(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_gbq(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
-    ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
+    ddf = read_gbq(
+        project_id=project_id,
+        dataset_id=dataset_id,
+        table_id=table_id,
+        fwd_creds=fwd_creds,
+    )
 
     assert list(ddf.columns) == ["name", "number", "idx"]
     assert ddf.npartitions == 2
     assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
 
 
-def test_read_row_filter(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_row_filter(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     ddf = read_gbq(
         project_id=project_id,
         dataset_id=dataset_id,
         table_id=table_id,
         row_filter="idx < 5",
+        fwd_creds=fwd_creds,
     )
 
     assert list(ddf.columns) == ["name", "number", "idx"]
     assert ddf.npartitions == 2
     assert assert_eq(ddf.set_index("idx").loc[:4], df.set_index("idx").loc[:4])
 
 
-def test_read_kwargs(dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_kwargs(dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     ddf = read_gbq(
         project_id=project_id,
         dataset_id=dataset_id,
         table_id=table_id,
         read_kwargs={"timeout": 1e-12},
+        fwd_creds=fwd_creds,
     )
 
     with pytest.raises(Exception, match="Deadline Exceeded"):
         ddf.compute()
 
 
-def test_read_columns(df, dataset, client):
+@pytest.mark.parametrize("fwd_creds", [False, True])
+def test_read_columns(df, dataset, fwd_creds, client):
     project_id, dataset_id, table_id = dataset
     assert df.shape[1] > 1, "Test data should have multiple columns"
 
@@ -97,5 +108,6 @@ def test_read_columns(df, dataset, client):
         dataset_id=dataset_id,
         table_id=table_id,
         columns=columns,
+        fwd_creds=fwd_creds,
     )
     assert list(ddf.columns) == columns
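
Because each test is now parametrized over fwd_creds, pytest generates two ids per function, e.g. test_read_gbq[False] and test_read_gbq[True], so both the default-credentials path and the forwarded-token path are exercised by the same assertions. To see the expanded cases locally (invocation is illustrative):

pytest dask_bigquery/tests/test_core.py -v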
