Skip to content

Commit 3a0169a

Browse files
committed
code review feedback
1 parent bb6ba4f commit 3a0169a

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

src/databricks/labs/pytester/fixtures/connect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_databricks_connect(spark):
5050
return None
5151

5252

53-
def ensure_cluster_is_running(cluster_id, ws):
53+
def ensure_cluster_is_running(cluster_id: str, ws: WorkspaceClient) -> None:
5454
if not cluster_id:
5555
skip("No cluster_id found in the environment")
5656
ws.clusters.ensure_cluster_is_running(cluster_id)

tests/integration/fixtures/test_connect.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ def debug_env_bugfix(monkeypatch, debug_env):
1717
monkeypatch.setitem(debug_env, "DATABRICKS_CLUSTER_ID", "1114-152544-29g1w07e")
1818

1919

20+
@fixture
21+
def spark_serverless_cluster_id(ws):
22+
# get new spark session with serverless cluster outside the actual spark fixture under test
23+
spark_serverless = DatabricksSession.builder.serverless(True).getOrCreate()
24+
# get cluster id from the existing serverless spark session
25+
cluster_id = spark_serverless.conf.get("spark.databricks.clusterUsageTags.clusterId")
26+
ws.config.serverless_compute_id = cluster_id
27+
yield cluster_id
28+
spark_serverless.stop()
29+
30+
2031
def test_databricks_connect(debug_env_bugfix, ws, spark):
2132
rows = spark.sql("SELECT 1").collect()
2233
assert rows[0][0] == 1
@@ -33,20 +44,14 @@ def test_databricks_connect_serverless(serverless_env, ws, spark):
3344
assert not creator # serverless clusters don't have assigned creator
3445

3546

36-
def test_databricks_connect_serverless_set_cluster_id(serverless_env, ws, request):
37-
# get spark session to retrieve serverless cluster id
38-
spark_serverless = DatabricksSession.builder.serverless(True).getOrCreate()
39-
cluster_id = spark_serverless.conf.get("spark.databricks.clusterUsageTags.clusterId")
40-
ws.config.serverless_compute_id = cluster_id
41-
42-
# get a new spark session for serverless from the spark fixture using provided cluster id
43-
spark_serverless_new = request.getfixturevalue("spark")
44-
45-
rows = spark_serverless_new.sql("SELECT 1").collect()
47+
def test_databricks_connect_serverless_set_cluster_id(ws, spark_serverless_cluster_id, spark):
48+
rows = spark.sql("SELECT 1").collect()
4649
assert rows[0][0] == 1
4750

48-
assert spark_serverless_new.conf.get("spark.databricks.clusterUsageTags.clusterId") == cluster_id
49-
creator = get_cluster_creator(spark_serverless_new, ws)
51+
cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
52+
assert spark_serverless_cluster_id == cluster_id
53+
54+
creator = get_cluster_creator(spark, ws)
5055
assert not creator # serverless clusters don't have assigned creator
5156

5257

0 commit comments

Comments (0)