Skip to content

Commit 6c85187

Browse files
committed
refactor: move session creation flow from SparkClient to backend.create_and_connect()
Signed-off-by: Shekhar Rajak <shekharrajak@live.com>
1 parent 0dbc312 commit 6c85187

File tree

3 files changed

+100
-64
lines changed

3 files changed

+100
-64
lines changed

kubeflow/spark/api/spark_client.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
from collections.abc import Iterator
1818
import logging
19-
import os
20-
import sys
2119
from typing import Optional
2220

2321
from pyspark.sql import SparkSession
@@ -29,22 +27,6 @@
2927

3028
logger = logging.getLogger(__name__)
3129

32-
_spark_debug_logging_enabled = False
33-
34-
35-
def _enable_spark_debug_logging() -> None:
36-
"""Turn on INFO logging for kubeflow.spark to stderr (for E2E debug)."""
37-
global _spark_debug_logging_enabled
38-
if _spark_debug_logging_enabled:
39-
return
40-
_spark_debug_logging_enabled = True
41-
root = logging.getLogger("kubeflow.spark")
42-
root.setLevel(logging.INFO)
43-
if not root.handlers:
44-
h = logging.StreamHandler(sys.stderr)
45-
h.setLevel(logging.INFO)
46-
root.addHandler(h)
47-
4830

4931
class SparkClient:
5032
"""Stateless Spark client for Kubeflow."""
@@ -127,48 +109,23 @@ def connect(
127109
Server port defaults to 15002 (Spark Connect gRPC). PySpark and server Spark
128110
major.minor should match; see constants and pyproject.toml [spark].
129111
"""
130-
if resources_per_executor is not None and not isinstance(resources_per_executor, dict):
131-
raise TypeError(
132-
f"resources_per_executor must be a dict, got {type(resources_per_executor)}"
133-
)
134-
if spark_conf is not None and not isinstance(spark_conf, dict):
135-
raise TypeError(f"spark_conf must be a dict, got {type(spark_conf)}")
136-
if num_executors is not None and not isinstance(num_executors, int):
137-
raise TypeError(f"num_executors must be an int, got {type(num_executors)}")
138-
if driver is not None and not isinstance(driver, Driver):
139-
raise TypeError(f"driver must be a Driver instance, got {type(driver)}")
140-
if executor is not None and not isinstance(executor, Executor):
141-
raise TypeError(f"executor must be an Executor instance, got {type(executor)}")
142-
143112
if base_url:
144113
validate_spark_connect_url(base_url)
145114
builder = SparkSession.builder.remote(base_url)
146115
if token:
147116
builder = builder.config("spark.connect.authenticate.token", token)
148117
return builder.getOrCreate()
149118

150-
if os.environ.get("SPARK_E2E_DEBUG"):
151-
_enable_spark_debug_logging()
152-
153-
info = self.backend._create_session(
119+
return self.backend.create_and_connect(
154120
num_executors=num_executors,
155121
resources_per_executor=resources_per_executor,
156122
spark_conf=spark_conf,
157123
driver=driver,
158124
executor=executor,
159125
options=options,
126+
timeout=timeout,
127+
connect_timeout=connect_timeout,
160128
)
161-
logger.info(
162-
"Created session %s/%s, waiting for ready (timeout=%ss)",
163-
info.namespace,
164-
info.name,
165-
timeout,
166-
)
167-
168-
info = self.backend._wait_for_session_ready(info.name, timeout=timeout)
169-
logger.info("Session ready, connecting (service_name=%s)", info.service_name)
170-
171-
return self.backend.connect(info, connect_timeout=connect_timeout)
172129

173130
def list_sessions(self) -> list[SparkConnectInfo]:
174131
"""List all SparkConnect sessions."""

kubeflow/spark/api/spark_client_test.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,28 +154,18 @@ class TestSparkClientConnectWithNameOption:
154154
def test_connect_with_name_option(self, spark_client, mock_backend):
    """C18: Connect passes options to backend including Name option."""
    fake_session = Mock()
    mock_backend.create_and_connect.return_value = fake_session
    requested_options = [Name("custom-session")]

    spark_client.connect(options=requested_options)

    mock_backend.create_and_connect.assert_called_once()
    # The options list must be forwarded to the backend untouched.
    assert mock_backend.create_and_connect.call_args.kwargs["options"] == requested_options
168163

169164
def test_connect_without_options_auto_generates(self, spark_client, mock_backend):
    """C19: Connect without options auto-generates name via backend."""
    fake_session = Mock()
    mock_backend.create_and_connect.return_value = fake_session

    spark_client.connect()

    mock_backend.create_and_connect.assert_called_once()
    # No options given: the client forwards None and the backend generates a name.
    assert mock_backend.create_and_connect.call_args.kwargs["options"] is None

kubeflow/spark/backends/kubernetes/backend.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import random
2323
import socket
2424
import subprocess
25+
import sys
2526
import threading
2627
import time
2728
from typing import Optional
@@ -44,6 +45,22 @@
4445

4546
logger = logging.getLogger(__name__)
4647

48+
# Guard so the debug-logging setup below runs at most once per process.
_spark_debug_logging_enabled = False


def _enable_spark_debug_logging() -> None:
    """Turn on INFO logging for kubeflow.spark to stderr (for E2E debug)."""
    global _spark_debug_logging_enabled
    if _spark_debug_logging_enabled:
        return
    _spark_debug_logging_enabled = True

    pkg_logger = logging.getLogger("kubeflow.spark")
    pkg_logger.setLevel(logging.INFO)
    # Attach a stderr handler only if nobody configured one already,
    # so we never duplicate output on repeated setup.
    if pkg_logger.handlers:
        return
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.INFO)
    pkg_logger.addHandler(stderr_handler)
63+
4764

4865
class KubernetesBackend(RuntimeBackend):
4966
"""Kubernetes backend for managing SparkConnect sessions."""
@@ -104,6 +121,20 @@ def _create_session(
104121
options: Optional[list] = None,
105122
) -> SparkConnectInfo:
106123
"""Create a new SparkConnect session (INTERNAL USE ONLY)."""
124+
# Validate input types
125+
if resources_per_executor is not None and not isinstance(resources_per_executor, dict):
126+
raise TypeError(
127+
f"resources_per_executor must be a dict, got {type(resources_per_executor)}"
128+
)
129+
if spark_conf is not None and not isinstance(spark_conf, dict):
130+
raise TypeError(f"spark_conf must be a dict, got {type(spark_conf)}")
131+
if num_executors is not None and not isinstance(num_executors, int):
132+
raise TypeError(f"num_executors must be an int, got {type(num_executors)}")
133+
if driver is not None and not isinstance(driver, Driver):
134+
raise TypeError(f"driver must be a Driver instance, got {type(driver)}")
135+
if executor is not None and not isinstance(executor, Executor):
136+
raise TypeError(f"executor must be an Executor instance, got {type(executor)}")
137+
107138
# Extract Name option if present, or auto-generate
108139
name, filtered_options = self._extract_name_option(options)
109140

@@ -503,6 +534,64 @@ def _get_or_create() -> None:
503534
)
504535
raise TimeoutError(base_msg)
505536

537+
def create_and_connect(
    self,
    num_executors: Optional[int] = None,
    resources_per_executor: Optional[dict[str, str]] = None,
    spark_conf: Optional[dict[str, str]] = None,
    driver: Optional[Driver] = None,
    executor: Optional[Executor] = None,
    options: Optional[list] = None,
    timeout: int = 300,
    connect_timeout: int = 120,
) -> SparkSession:
    """Create a SparkConnect session, wait until it is ready, and connect.

    Orchestrates the full session lifecycle in one call:
      1. create the session resource (``_create_session``),
      2. block until it reports ready (``_wait_for_session_ready``),
      3. open and return the client ``SparkSession`` (``connect``).

    Args:
        num_executors: Number of executor instances.
        resources_per_executor: Resource requirements per executor.
        spark_conf: Spark configuration properties.
        driver: Driver configuration.
        executor: Executor configuration.
        options: List of configuration options (use Name option for custom name).
        timeout: Seconds to wait for the session to become ready.
        connect_timeout: Seconds allowed for SparkSession.getOrCreate().

    Returns:
        Connected SparkSession.

    Raises:
        TimeoutError: If session creation or connection times out.
        RuntimeError: If session creation or connection fails.
    """
    # Opt-in verbose logging for E2E debugging, toggled via environment.
    if os.environ.get("SPARK_E2E_DEBUG"):
        _enable_spark_debug_logging()

    session_info = self._create_session(
        num_executors=num_executors,
        resources_per_executor=resources_per_executor,
        spark_conf=spark_conf,
        driver=driver,
        executor=executor,
        options=options,
    )
    logger.info(
        "Created session %s/%s, waiting for ready (timeout=%ss)",
        session_info.namespace,
        session_info.name,
        timeout,
    )

    # Re-read session info after readiness: service_name is populated here.
    session_info = self._wait_for_session_ready(session_info.name, timeout=timeout)
    logger.info("Session ready, connecting (service_name=%s)", session_info.service_name)

    return self.connect(session_info, connect_timeout=connect_timeout)
594+
506595
def get_session_logs(
507596
self,
508597
name: str,

0 commit comments

Comments (0)