
Commit cefd79e

Added serverless support to spark fixture (#91)
Extend spark fixture to support Serverless compute.

### Linked issues

Resolves #90

### Tests

- [x] manually tested
- [ ] added unit tests
- [x] added integration tests
- [ ] verified on staging environment (screenshot attached)
1 parent e8099d8 commit cefd79e

5 files changed: +95 −6 lines changed


README.md

Lines changed: 6 additions & 0 deletions
@@ -375,6 +375,12 @@ See also [`log_account_link`](#log_account_link-fixture), [`make_acc_group`](#ma
 ### `spark` fixture
 Get Databricks Connect Spark session. Requires `databricks-connect` package to be installed.

+To enable serverless set the local environment variable `DATABRICKS_SERVERLESS_COMPUTE_ID` to `"auto"`.
+If this environment variable is set, Databricks Connect ignores the cluster_id.
+If `DATABRICKS_SERVERLESS_COMPUTE_ID` is set to a specific serverless cluster ID, that cluster will be used instead.
+However, this is not recommended, as serverless clusters are ephemeral by design.
+See more details [here](https://docs.databricks.com/en/dev-tools/databricks-connect/cluster-config.html#configure-a-connection-to-serverless-compute).
+
 Usage:
 ```python
 def test_databricks_connect(spark):
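
For illustration (not part of this commit), a minimal sketch of how a consuming test module might opt a single test into serverless via the environment variable documented above; the `serverless` fixture and test name are hypothetical, while `spark` is the fixture this README section describes:

```python
import pytest


@pytest.fixture
def serverless(monkeypatch):
    # Setting the variable before the `spark` fixture resolves routes
    # Databricks Connect to serverless compute for this test only.
    monkeypatch.setenv("DATABRICKS_SERVERLESS_COMPUTE_ID", "auto")


def test_select_one_on_serverless(serverless, spark):
    rows = spark.sql("SELECT 1").collect()
    assert rows[0][0] == 1
```

Listing `serverless` before `spark` in the test signature matters: the variable has to be set before the Spark session is created.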

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,7 @@ classifiers = [
 ]

 dependencies = [
-    "databricks-sdk>=0.30",
+    "databricks-sdk>=0.40,<0.42",
     "databricks-labs-lsql>=0.10",
     "pytest>=8.3",
 ]
@@ -77,6 +77,7 @@ dependencies = [
     "pytest-timeout~=2.3.1",
     "pytest-xdist~=3.5.0",
     "ruff~=0.3.4",
+    "databricks-connect~=15.4.3",
 ]

 # store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better
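
As a quick, commit-external sanity check, the new pins can be compared against what is actually installed; `importlib.metadata` is standard library, and the expected ranges are copied from the diff above:

```python
from importlib.metadata import version

# Expected per pyproject.toml: databricks-sdk >=0.40,<0.42 and
# databricks-connect ~=15.4.3 (the latter is a dev dependency).
print("databricks-sdk:", version("databricks-sdk"))
print("databricks-connect:", version("databricks-connect"))
```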

src/databricks/labs/pytester/fixtures/connect.py

Lines changed: 26 additions & 3 deletions
@@ -11,23 +11,46 @@ def spark(ws: WorkspaceClient):
     """
     Get Databricks Connect Spark session. Requires `databricks-connect` package to be installed.

+    To enable serverless set the local environment variable `DATABRICKS_SERVERLESS_COMPUTE_ID` to `"auto"`.
+    If this environment variable is set, Databricks Connect ignores the cluster_id.
+    If `DATABRICKS_SERVERLESS_COMPUTE_ID` is set to a specific serverless cluster ID, that cluster will be used instead.
+    However, this is not recommended, as serverless clusters are ephemeral by design.
+    See more details [here](https://docs.databricks.com/en/dev-tools/databricks-connect/cluster-config.html#configure-a-connection-to-serverless-compute).
+
     Usage:
     ```python
     def test_databricks_connect(spark):
         rows = spark.sql("SELECT 1").collect()
         assert rows[0][0] == 1
     ```
     """
-    if not ws.config.cluster_id:
-        skip("No cluster_id found in the environment")
-    ws.clusters.ensure_cluster_is_running(ws.config.cluster_id)
+    cluster_id = ws.config.cluster_id
+    serverless_cluster_id = ws.config.serverless_compute_id
+
+    if not serverless_cluster_id:
+        ensure_cluster_is_running(cluster_id, ws)
+
+    if serverless_cluster_id and serverless_cluster_id != "auto":
+        ensure_cluster_is_running(serverless_cluster_id, ws)
+
     try:
         # pylint: disable-next=import-outside-toplevel
         from databricks.connect import (  # type: ignore[import-untyped]
             DatabricksSession,
         )

+        if serverless_cluster_id:
+            logging.debug(f"Using serverless cluster id '{serverless_cluster_id}'")
+            return DatabricksSession.builder.serverless(True).getOrCreate()
+
+        logging.debug(f"Using cluster id '{cluster_id}'")
         return DatabricksSession.builder.sdkConfig(ws.config).getOrCreate()
     except ImportError:
         skip("Please run `pip install databricks-connect`")
         return None
+
+
+def ensure_cluster_is_running(cluster_id: str, ws: WorkspaceClient) -> None:
+    if not cluster_id:
+        skip("No cluster_id found in the environment")
+    ws.clusters.ensure_cluster_is_running(cluster_id)
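
Outside the fixture, the same selection logic can be sketched as a standalone helper; this assumes `databricks-connect` and `databricks-sdk` are installed, reuses only the calls that appear in the diff (`serverless(True)` and `sdkConfig(...)`), and the helper name is illustrative:

```python
from databricks.connect import DatabricksSession
from databricks.sdk.core import Config


def connect_session(cfg: Config):
    # Serverless takes precedence: when DATABRICKS_SERVERLESS_COMPUTE_ID is set
    # (i.e. serverless_compute_id is populated), Databricks Connect ignores cfg.cluster_id.
    if cfg.serverless_compute_id:
        return DatabricksSession.builder.serverless(True).getOrCreate()
    return DatabricksSession.builder.sdkConfig(cfg).getOrCreate()
```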

src/databricks/labs/pytester/fixtures/ml.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def create() -> Wait[ServingEndpointDetailed]:
         model = make_model()
         endpoint = ws.serving_endpoints.create(
             endpoint_name,
-            EndpointCoreConfigInput(
+            config=EndpointCoreConfigInput(
                 served_models=[
                     ServedModelInput(
                         model_name=model.name,
Lines changed: 60 additions & 1 deletion
@@ -1,3 +1,62 @@
-def test_databricks_connect(spark):
+import os
+from pytest import fixture
+from pyspark.sql.session import SparkSession
+from databricks.connect import DatabricksSession
+from databricks.sdk import WorkspaceClient
+
+
+@fixture
+def serverless_env():
+    os.environ['DATABRICKS_SERVERLESS_COMPUTE_ID'] = "auto"
+    yield
+    os.environ.pop('DATABRICKS_SERVERLESS_COMPUTE_ID')
+
+
+@fixture
+def debug_env_bugfix(monkeypatch, debug_env):
+    # This is a workaround to set shared cluster
+    # TODO: Update secret vault for acceptance testing and remove the bugfix
+    monkeypatch.setitem(debug_env, "DATABRICKS_CLUSTER_ID", "1114-152544-29g1w07e")
+
+
+@fixture
+def spark_serverless_cluster_id(ws):
+    # get new spark session with serverless cluster outside the actual spark fixture under test
+    spark_serverless = DatabricksSession.builder.serverless(True).getOrCreate()
+    # get cluster id from the existing serverless spark session
+    cluster_id = spark_serverless.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    ws.config.serverless_compute_id = cluster_id
+    yield cluster_id
+    spark_serverless.stop()
+
+
+def test_databricks_connect(debug_env_bugfix, ws, spark):
     rows = spark.sql("SELECT 1").collect()
     assert rows[0][0] == 1
+    assert not is_serverless_cluster(spark, ws)
+
+
+def test_databricks_connect_serverless(serverless_env, ws, spark):
+    rows = spark.sql("SELECT 1").collect()
+    assert rows[0][0] == 1
+    assert is_serverless_cluster(spark, ws)
+
+
+def test_databricks_connect_serverless_set_cluster_id(ws, spark_serverless_cluster_id, spark):
+    rows = spark.sql("SELECT 1").collect()
+    assert rows[0][0] == 1
+
+    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    assert spark_serverless_cluster_id == cluster_id
+    assert is_serverless_cluster(spark, ws)
+
+
+def is_serverless_cluster(spark: SparkSession, ws: WorkspaceClient) -> bool:
+    """
+    Check if the current cluster used is serverless.
+    """
+    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
+    if not cluster_id:
+        raise ValueError("clusterId usage tag does not exist")
+    creator = ws.clusters.get(cluster_id).creator_user_name
+    return not creator  # serverless clusters don't have assigned creator
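
For reference, the creator-based heuristic from `is_serverless_cluster` also works outside pytest; a minimal standalone sketch under the assumption that Databricks Connect is already configured, using only calls that appear in the diff above:

```python
from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient


def current_cluster_is_serverless() -> bool:
    spark = DatabricksSession.builder.getOrCreate()
    ws = WorkspaceClient()
    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
    if not cluster_id:
        raise ValueError("clusterId usage tag does not exist")
    # Serverless clusters are not created by a user, so creator_user_name is empty.
    return not ws.clusters.get(cluster_id).creator_user_name
```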
