
Commit cb743a9

code review feedback
1 parent e8099d8 commit cb743a9

5 files changed: +97 -6 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -375,6 +375,12 @@ See also [`log_account_link`](#log_account_link-fixture), [`make_acc_group`](#ma
 ### `spark` fixture
 Get Databricks Connect Spark session. Requires `databricks-connect` package to be installed.
 
+To enable serverless compute, set the local environment variable `DATABRICKS_SERVERLESS_COMPUTE_ID` to `"auto"`.
+If this environment variable is set, Databricks Connect ignores the `cluster_id`.
+If `DATABRICKS_SERVERLESS_COMPUTE_ID` is set to a specific serverless cluster ID, that cluster is used instead.
+However, this is not recommended, as serverless clusters are ephemeral by design.
+See more details [here](https://docs.databricks.com/en/dev-tools/databricks-connect/cluster-config.html#configure-a-connection-to-serverless-compute).
+
 Usage:
 ```python
 def test_databricks_connect(spark):
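As a quick illustration of the serverless switch described in the README text above, a test module might look like this (a minimal sketch: it assumes workspace credentials are already configured for Databricks Connect, and the test name is hypothetical):

```python
import os

# Route Databricks Connect to serverless compute; any configured cluster_id is then ignored.
os.environ["DATABRICKS_SERVERLESS_COMPUTE_ID"] = "auto"


def test_runs_on_serverless(spark):  # `spark` is the fixture documented above
    rows = spark.sql("SELECT 1").collect()
    assert rows[0][0] == 1
```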

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "databricks-sdk>=0.30",
+    "databricks-sdk>=0.40,<0.42",
     "databricks-labs-lsql>=0.10",
     "pytest>=8.3",
 ]
@@ -77,6 +77,7 @@ dependencies = [
     "pytest-timeout~=2.3.1",
     "pytest-xdist~=3.5.0",
     "ruff~=0.3.4",
+    "databricks-connect~=15.4.3",
 ]
 
 # store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better
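To sanity-check that the tightened pins resolve together in a local environment, a quick probe such as the following can help (a sketch; it assumes only that the two packages from this diff are installed):

```python
from importlib.metadata import version

# The pins above expect databricks-sdk >=0.40,<0.42 and databricks-connect ~=15.4.3
print("databricks-sdk:", version("databricks-sdk"))
print("databricks-connect:", version("databricks-connect"))
```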

src/databricks/labs/pytester/fixtures/connect.py

Lines changed: 26 additions & 3 deletions
@@ -11,23 +11,46 @@ def spark(ws: WorkspaceClient):
     """
     Get Databricks Connect Spark session. Requires `databricks-connect` package to be installed.
 
+    To enable serverless compute, set the local environment variable `DATABRICKS_SERVERLESS_COMPUTE_ID` to `"auto"`.
+    If this environment variable is set, Databricks Connect ignores the `cluster_id`.
+    If `DATABRICKS_SERVERLESS_COMPUTE_ID` is set to a specific serverless cluster ID, that cluster is used instead.
+    However, this is not recommended, as serverless clusters are ephemeral by design.
+    See more details [here](https://docs.databricks.com/en/dev-tools/databricks-connect/cluster-config.html#configure-a-connection-to-serverless-compute).
+
     Usage:
     ```python
     def test_databricks_connect(spark):
         rows = spark.sql("SELECT 1").collect()
         assert rows[0][0] == 1
     ```
     """
-    if not ws.config.cluster_id:
-        skip("No cluster_id found in the environment")
-    ws.clusters.ensure_cluster_is_running(ws.config.cluster_id)
+    cluster_id = ws.config.cluster_id
+    serverless_cluster_id = ws.config.serverless_compute_id
+
+    if not serverless_cluster_id:
+        ensure_cluster_is_running(cluster_id, ws)
+
+    if serverless_cluster_id and serverless_cluster_id != "auto":
+        ensure_cluster_is_running(serverless_cluster_id, ws)
+
     try:
         # pylint: disable-next=import-outside-toplevel
         from databricks.connect import (  # type: ignore[import-untyped]
             DatabricksSession,
         )
 
+        if serverless_cluster_id:
+            logging.debug(f"Using serverless cluster id '{serverless_cluster_id}'")
+            return DatabricksSession.builder.serverless(True).getOrCreate()
+
+        logging.debug(f"Using cluster id '{cluster_id}'")
         return DatabricksSession.builder.sdkConfig(ws.config).getOrCreate()
     except ImportError:
         skip("Please run `pip install databricks-connect`")
     return None
+
+
+def ensure_cluster_is_running(cluster_id: str, ws: WorkspaceClient) -> None:
+    if not cluster_id:
+        skip("No cluster_id found in the environment")
+    ws.clusters.ensure_cluster_is_running(cluster_id)
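Outside the fixture, the same session-selection logic can be exercised directly (a minimal sketch, assuming `databricks-connect` and `databricks-sdk` are installed and the workspace is configured via environment variables or a config profile):

```python
import logging

from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient

ws = WorkspaceClient()

if ws.config.serverless_compute_id:
    # Set when DATABRICKS_SERVERLESS_COMPUTE_ID is "auto" or a specific serverless cluster ID
    logging.debug(f"Using serverless cluster id '{ws.config.serverless_compute_id}'")
    spark = DatabricksSession.builder.serverless(True).getOrCreate()
else:
    logging.debug(f"Using cluster id '{ws.config.cluster_id}'")
    spark = DatabricksSession.builder.sdkConfig(ws.config).getOrCreate()

print(spark.sql("SELECT 1").collect())
```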

src/databricks/labs/pytester/fixtures/ml.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def create() -> Wait[ServingEndpointDetailed]:
         model = make_model()
         endpoint = ws.serving_endpoints.create(
             endpoint_name,
-            EndpointCoreConfigInput(
+            config=EndpointCoreConfigInput(
                 served_models=[
                     ServedModelInput(
                         model_name=model.name,
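For context, the changed call reads like this in isolation (a sketch assuming databricks-sdk >=0.40, where the endpoint configuration is passed via the `config=` keyword; the helper name and literal values here are illustrative):

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput


def create_endpoint(ws: WorkspaceClient, endpoint_name: str, model_name: str):
    # `config` is now passed by keyword rather than positionally
    return ws.serving_endpoints.create(
        endpoint_name,
        config=EndpointCoreConfigInput(
            served_models=[
                ServedModelInput(
                    model_name=model_name,
                    model_version="1",
                    workload_size="Small",
                    scale_to_zero_enabled=True,
                )
            ]
        ),
    )
```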
Lines changed: 62 additions & 1 deletion
@@ -1,3 +1,64 @@
-def test_databricks_connect(spark):
+import os
+from pytest import fixture
+from databricks.connect import DatabricksSession
+
+
+@fixture
+def serverless_env():
+    os.environ['DATABRICKS_SERVERLESS_COMPUTE_ID'] = "auto"
+    yield
+    os.environ.pop('DATABRICKS_SERVERLESS_COMPUTE_ID')
+
+
+@fixture
+def debug_env_bugfix(monkeypatch, debug_env):
+    # This is a workaround to set a shared cluster
+    # TODO: Update secret vault for acceptance testing and remove the bugfix
+    monkeypatch.setitem(debug_env, "DATABRICKS_CLUSTER_ID", "1114-152544-29g1w07e")
+
+
+@fixture
+def spark_serverless_cluster_id(ws):
+    # get a new spark session with a serverless cluster outside the actual spark fixture under test
+    spark_serverless = DatabricksSession.builder.serverless(True).getOrCreate()
+    # get the cluster id from the existing serverless spark session
+    cluster_id = spark_serverless.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    ws.config.serverless_compute_id = cluster_id
+    yield cluster_id
+    spark_serverless.stop()
+
+
+def test_databricks_connect(debug_env_bugfix, ws, spark):
     rows = spark.sql("SELECT 1").collect()
     assert rows[0][0] == 1
+
+    creator = get_cluster_creator(spark, ws)
+    assert creator  # non-serverless clusters must have an assigned creator
+
+
+def test_databricks_connect_serverless(serverless_env, ws, spark):
+    rows = spark.sql("SELECT 1").collect()
+    assert rows[0][0] == 1
+
+    creator = get_cluster_creator(spark, ws)
+    assert not creator  # serverless clusters don't have an assigned creator
+
+
+def test_databricks_connect_serverless_set_cluster_id(ws, spark_serverless_cluster_id, spark):
+    rows = spark.sql("SELECT 1").collect()
+    assert rows[0][0] == 1
+
+    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    assert spark_serverless_cluster_id == cluster_id
+
+    creator = get_cluster_creator(spark, ws)
+    assert not creator  # serverless clusters don't have an assigned creator
+
+
+def get_cluster_creator(spark, ws):
+    """
+    Get the creator of the cluster that the Spark session is connected to.
+    """
+    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    creator = ws.clusters.get(cluster_id).creator_user_name
+    return creator
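As a side note on the `serverless_env` fixture above: pytest's built-in `monkeypatch` offers an equivalent that restores the environment automatically even if a test fails mid-way (a sketch; the fixture name is illustrative):

```python
from pytest import fixture


@fixture
def serverless_env_monkeypatched(monkeypatch):
    # monkeypatch undoes the change at teardown, so no explicit cleanup is needed
    monkeypatch.setenv("DATABRICKS_SERVERLESS_COMPUTE_ID", "auto")
```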
