diff --git a/dbt-bigquery/.changes/unreleased/Fixes-20250404-215640.yaml b/dbt-bigquery/.changes/unreleased/Fixes-20250404-215640.yaml new file mode 100644 index 000000000..e03382c42 --- /dev/null +++ b/dbt-bigquery/.changes/unreleased/Fixes-20250404-215640.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Add create_notebook_client +time: 2025-04-04T21:56:40.174432349Z +custom: + Author: jialuoo + Issue: "977" diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/clients.py b/dbt-bigquery/src/dbt/adapters/bigquery/clients.py index 722266240..18714efcd 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/clients.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/clients.py @@ -1,10 +1,14 @@ +from typing import Optional + from google.api_core.client_info import ClientInfo from google.api_core.client_options import ClientOptions from google.auth.exceptions import DefaultCredentialsError +from google.cloud import aiplatform_v1 from google.cloud.bigquery import Client as BigQueryClient, DEFAULT_RETRY as BQ_DEFAULT_RETRY from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient from google.cloud.storage import Client as StorageClient from google.cloud.storage.retry import DEFAULT_RETRY as GCS_DEFAULT_RETRY +from google.oauth2.credentials import Credentials as GoogleCredentials from dbt.adapters.events.logging import AdapterLogger @@ -67,3 +71,15 @@ def _create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient: def _dataproc_endpoint(credentials: BigQueryCredentials) -> str: return f"{credentials.dataproc_region}-dataproc.googleapis.com:443" + + +def create_notebook_client( + credentials: GoogleCredentials, region: Optional[str] +) -> aiplatform_v1.NotebookServiceClient: + api_endpoint = f"{region}-aiplatform.googleapis.com" + notebook_client = aiplatform_v1.NotebookServiceClient( + credentials=credentials, + client_options=ClientOptions(api_endpoint), + ) + + return notebook_client diff --git a/dbt-bigquery/src/dbt/adapters/bigquery/python_submissions.py b/dbt-bigquery/src/dbt/adapters/bigquery/python_submissions.py index 7b42c2913..1cba02ec4 100644 --- a/dbt-bigquery/src/dbt/adapters/bigquery/python_submissions.py +++ b/dbt-bigquery/src/dbt/adapters/bigquery/python_submissions.py @@ -9,6 +9,7 @@ create_dataproc_batch_controller_client, create_dataproc_job_controller_client, create_gcs_client, + create_notebook_client, ) from dbt.adapters.bigquery.credentials import ( BigQueryConnectionMethod, @@ -17,7 +18,6 @@ ) from dbt.adapters.bigquery.retry import RetryFactory from dbt.adapters.events.logging import AdapterLogger -from google.api_core.client_options import ClientOptions from google.auth.transport.requests import Request from google.cloud import aiplatform_v1 @@ -195,12 +195,7 @@ def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None self._model_name = parsed_model["alias"] self._connection_method = credentials.method self._GoogleCredentials = create_google_credentials(credentials) - - # TODO(jialuo): Add a function in clients.py for it. - self._ai_platform_client = aiplatform_v1.NotebookServiceClient( - credentials=self._GoogleCredentials, - client_options=ClientOptions(api_endpoint=f"{self._region}-aiplatform.googleapis.com"), - ) + self._notebook_client = create_notebook_client(self._GoogleCredentials, self._region) self._notebook_template_id = parsed_model["config"].get("notebook_template_id") def _py_to_ipynb(self, compiled_code: str) -> str: @@ -220,7 +215,7 @@ def _get_notebook_template_id(self) -> str: parent=f"projects/{self._project}/locations/{self._region}", filter="notebookRuntimeType = ONE_CLICK", ) - page_result = self._ai_platform_client.list_notebook_runtime_templates(request=request) + page_result = self._notebook_client.list_notebook_runtime_templates(request=request) try: # Check if a default runtime template is available and applicable. @@ -269,9 +264,7 @@ def _create_notebook_template(self) -> str: notebook_runtime_template=template, ) - operation = self._ai_platform_client.create_notebook_runtime_template( - request=create_request - ) + operation = self._notebook_client.create_notebook_runtime_template(request=create_request) response = operation.result() return self._extract_template_id(response.name) @@ -408,7 +401,7 @@ def _submit_bigframes_job( ) try: - res = self._ai_platform_client.create_notebook_execution_job(request=request).result( + res = self._notebook_client.create_notebook_execution_job(request=request).result( timeout=self._polling_retry.timeout ) except TimeoutError as timeout_error: @@ -424,4 +417,4 @@ def _submit_bigframes_job( gcs_log_uri = f"{notebook_execution_job.gcs_output_uri}/{job_id}/{self._model_name}.py" self._process_gcs_log(gcs_log_uri) - return self._ai_platform_client.get_notebook_execution_job(name=res.name) + return self._notebook_client.get_notebook_execution_job(name=res.name)