Skip to content
Merged
10 changes: 10 additions & 0 deletions google/cloud/dataproc_spark_connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PermissionDenied,
)
from google.api_core.future.polling import POLLING_PREDICATE
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
from google.cloud.dataproc_spark_connect.exceptions import DataprocSparkConnectException
from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
Expand Down Expand Up @@ -456,6 +457,15 @@ def create_session_pbar():
raise DataprocSparkConnectException(
f"Error while creating Dataproc Session: {e.message}"
)
except DefaultCredentialsError as e:
stop_create_session_pbar_event.set()
if create_session_pbar_thread.is_alive():
create_session_pbar_thread.join()
DataprocSparkSession._active_s8s_session_id = None
DataprocSparkSession._active_session_uses_custom_id = False
raise DataprocSparkConnectException(
"Credentials error while creating Dataproc Session (see https://docs.cloud.google.com/docs/authentication/provide-credentials-adc for more info)"
) from e
except Exception as e:
stop_create_session_pbar_event.set()
if create_session_pbar_thread.is_alive():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_pypi_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_valid_inputs():
def test_bad_format(self):
with self.assertRaisesRegex(
InvalidRequirement,
"Expected end or semicolon \(after name and no valid version specifier\).*",
r"Expected semicolon \(after name with no version specifier\) or end",
):
PyPiArtifacts({"pypi://spacy:23"})

Expand Down
16 changes: 14 additions & 2 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ def test_display_button_with_aiplatform_installed_ipython_non_interactive(

@mock.patch(
"IPython.core.interactiveshell.InteractiveShell.initialized",
return_value=False,
return_value=True,
)
@mock.patch("IPython.display.display")
def test_display_session_link_on_creation_colab_enterprise(
Expand All @@ -1238,7 +1238,7 @@ def test_display_session_link_on_creation_colab_enterprise(

@mock.patch(
"IPython.core.interactiveshell.InteractiveShell.initialized",
return_value=False,
return_value=True,
)
@mock.patch("IPython.display.display")
def test_display_session_link_on_creation_not_colab_enterprise(
Expand Down Expand Up @@ -1487,6 +1487,18 @@ def test_create_session_without_location(self):
except DataprocSparkConnectException as e:
self.assertIn("location is not set", str(e))

def test_create_session_without_application_default_credentials(self):
"""Tests that an exception is raised when application default credentials is not provided."""
os.environ.clear()
try:
DataprocSparkSession.builder.location("test-region").projectId(
"test-project"
).getOrCreate()
except DataprocSparkConnectException as e:
self.assertIn(
"Credentials error while creating Dataproc Session", str(e)
)


class DataprocSparkConnectClientTest(unittest.TestCase):

Expand Down
Loading