|
22 | 22 | import random |
23 | 23 | import socket |
24 | 24 | import subprocess |
| 25 | +import sys |
25 | 26 | import threading |
26 | 27 | import time |
27 | 28 | from typing import Optional |
|
44 | 45 |
|
45 | 46 | logger = logging.getLogger(__name__) |
46 | 47 |
|
# One-shot guard so repeated calls do not stack handlers or redo setup.
_spark_debug_logging_enabled = False


def _enable_spark_debug_logging() -> None:
    """Enable INFO-level logging for the ``kubeflow.spark`` logger on stderr.

    Intended for E2E debugging. Safe to call any number of times: only the
    first invocation has any effect, and a stream handler is attached only
    if the logger has none yet.
    """
    global _spark_debug_logging_enabled
    if not _spark_debug_logging_enabled:
        _spark_debug_logging_enabled = True
        pkg_logger = logging.getLogger("kubeflow.spark")
        pkg_logger.setLevel(logging.INFO)
        if not pkg_logger.handlers:
            stderr_handler = logging.StreamHandler(sys.stderr)
            stderr_handler.setLevel(logging.INFO)
            pkg_logger.addHandler(stderr_handler)
| 63 | + |
47 | 64 |
|
48 | 65 | class KubernetesBackend(RuntimeBackend): |
49 | 66 | """Kubernetes backend for managing SparkConnect sessions.""" |
@@ -104,6 +121,20 @@ def _create_session( |
104 | 121 | options: Optional[list] = None, |
105 | 122 | ) -> SparkConnectInfo: |
106 | 123 | """Create a new SparkConnect session (INTERNAL USE ONLY).""" |
| 124 | + # Validate input types |
| 125 | + if resources_per_executor is not None and not isinstance(resources_per_executor, dict): |
| 126 | + raise TypeError( |
| 127 | + f"resources_per_executor must be a dict, got {type(resources_per_executor)}" |
| 128 | + ) |
| 129 | + if spark_conf is not None and not isinstance(spark_conf, dict): |
| 130 | + raise TypeError(f"spark_conf must be a dict, got {type(spark_conf)}") |
| 131 | + if num_executors is not None and not isinstance(num_executors, int): |
| 132 | + raise TypeError(f"num_executors must be an int, got {type(num_executors)}") |
| 133 | + if driver is not None and not isinstance(driver, Driver): |
| 134 | + raise TypeError(f"driver must be a Driver instance, got {type(driver)}") |
| 135 | + if executor is not None and not isinstance(executor, Executor): |
| 136 | + raise TypeError(f"executor must be an Executor instance, got {type(executor)}") |
| 137 | + |
107 | 138 | # Extract Name option if present, or auto-generate |
108 | 139 | name, filtered_options = self._extract_name_option(options) |
109 | 140 |
|
@@ -503,6 +534,64 @@ def _get_or_create() -> None: |
503 | 534 | ) |
504 | 535 | raise TimeoutError(base_msg) |
505 | 536 |
|
| 537 | + def create_and_connect( |
| 538 | + self, |
| 539 | + num_executors: Optional[int] = None, |
| 540 | + resources_per_executor: Optional[dict[str, str]] = None, |
| 541 | + spark_conf: Optional[dict[str, str]] = None, |
| 542 | + driver: Optional[Driver] = None, |
| 543 | + executor: Optional[Executor] = None, |
| 544 | + options: Optional[list] = None, |
| 545 | + timeout: int = 300, |
| 546 | + connect_timeout: int = 120, |
| 547 | + ) -> SparkSession: |
| 548 | + """Create a new SparkConnect session and connect to it. |
| 549 | +
|
| 550 | + This method handles the full session lifecycle: |
| 551 | + 1. Creates a new session via _create_session |
| 552 | + 2. Waits for session to become ready |
| 553 | + 3. Connects to the session and returns SparkSession |
| 554 | +
|
| 555 | + Args: |
| 556 | + num_executors: Number of executor instances. |
| 557 | + resources_per_executor: Resource requirements per executor. |
| 558 | + spark_conf: Spark configuration properties. |
| 559 | + driver: Driver configuration. |
| 560 | + executor: Executor configuration. |
| 561 | + options: List of configuration options (use Name option for custom name). |
| 562 | + timeout: Timeout in seconds to wait for session ready. |
| 563 | + connect_timeout: Timeout in seconds for SparkSession.getOrCreate(). |
| 564 | +
|
| 565 | + Returns: |
| 566 | + Connected SparkSession. |
| 567 | +
|
| 568 | + Raises: |
| 569 | + TimeoutError: If session creation or connection times out. |
| 570 | + RuntimeError: If session creation or connection fails. |
| 571 | + """ |
| 572 | + if os.environ.get("SPARK_E2E_DEBUG"): |
| 573 | + _enable_spark_debug_logging() |
| 574 | + |
| 575 | + info = self._create_session( |
| 576 | + num_executors=num_executors, |
| 577 | + resources_per_executor=resources_per_executor, |
| 578 | + spark_conf=spark_conf, |
| 579 | + driver=driver, |
| 580 | + executor=executor, |
| 581 | + options=options, |
| 582 | + ) |
| 583 | + logger.info( |
| 584 | + "Created session %s/%s, waiting for ready (timeout=%ss)", |
| 585 | + info.namespace, |
| 586 | + info.name, |
| 587 | + timeout, |
| 588 | + ) |
| 589 | + |
| 590 | + info = self._wait_for_session_ready(info.name, timeout=timeout) |
| 591 | + logger.info("Session ready, connecting (service_name=%s)", info.service_name) |
| 592 | + |
| 593 | + return self.connect(info, connect_timeout=connect_timeout) |
| 594 | + |
506 | 595 | def get_session_logs( |
507 | 596 | self, |
508 | 597 | name: str, |
|
0 commit comments