diff --git a/docs/comparison-dbt-fabricspark.md b/docs/comparison-dbt-fabricspark.md index d7ca1ca5..17da7699 100644 --- a/docs/comparison-dbt-fabricspark.md +++ b/docs/comparison-dbt-fabricspark.md @@ -75,8 +75,10 @@ Notable differences: | Feature | dbt-fabric-samdebruyn | microsoft/dbt-fabricspark | |---|---|---| +| **[High-concurrency Livy](lakehouse.md#high-concurrency-livy)** | Yes (HC-only, instance-based lifecycle) | Yes (default on, `atexit` cleanup) | | **Session creation** | `FabricApiClient` singleton | `LivySessionManager` with static globals | -| **Session reuse** | By session name | Via `session_id_file` + `reuse_session` flag | +| **Session reuse** | Deterministic session tag (HC) | Via `session_id_file` + `reuse_session` flag (singleton) / deterministic session tag (HC) | +| **HC session cleanup** | Connection manager `close()` path | `atexit` handler (fragile — see [Code quality](#code-quality)) | | **Polling interval** | Fixed 3 seconds | Adaptive (configurable) | | **Session idle timeout** | 15 min default | 30 min default, configurable | | **Local Livy mode** | No | Yes (`livy_mode: local`) | @@ -190,7 +192,7 @@ This adapter uses proper instance-based encapsulation: `FabricTokenProvider` (pe ### atexit handler for session cleanup -The upstream registers an `atexit` handler at module import time (`livysession.py` lines 1314-1322) to delete Livy sessions on process exit. This is fragile: `atexit` handlers run in undefined order, logging/network may already be torn down, and merely importing the module registers the handler even if no session was created. +The upstream registers `atexit` handlers at module import time (in both `singleton_livy.py` and `concurrent_livy.py`) to delete Livy sessions and HC sessions on process exit. This is fragile: `atexit` handlers run in undefined order, logging/network may already be torn down, and merely importing the module registers the handler even if no session was created. 
The HC implementation adds a second `atexit` handler with a global `_active_sessions` set, compounding the global mutable state problem. This adapter manages session lifecycle through dbt's normal connection manager `close()` path. diff --git a/docs/lakehouse.md b/docs/lakehouse.md index 8d700618..9a5ad014 100644 --- a/docs/lakehouse.md +++ b/docs/lakehouse.md @@ -45,37 +45,36 @@ The FabricSpark adapter does not use the [`host`](configuration.md#host) option ## How it works -The FabricSpark adapter executes all SQL through Fabric Livy sessions. Here is the execution flow: +The FabricSpark adapter executes all SQL through Fabric's [high-concurrency Livy API](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy?WT.mc_id=MVP_310840). Each dbt thread gets its own REPL inside a shared underlying Livy session. Here is the execution flow: ```mermaid sequenceDiagram participant dbt participant Adapter - participant Livy API + participant HC Livy API participant Spark Session dbt->>Adapter: Compiled Spark SQL - Adapter->>Livy API: GET /sessions (find existing session) - alt Session exists - Livy API-->>Adapter: Session ID - else No session - Adapter->>Livy API: POST /sessions (create new) - Livy API-->>Adapter: Session ID - Note over Adapter,Spark Session: Session startup: 1-5 minutes + Adapter->>HC Livy API: POST /highConcurrencySessions (acquire REPL) + alt Underlying session exists (warm) + HC Livy API-->>Adapter: HC session ID + REPL ID + else No underlying session + Note over HC Livy API,Spark Session: Spark startup + HC Livy API-->>Adapter: HC session ID + REPL ID end - Adapter->>Livy API: POST /sessions/{id}/statements - Livy API->>Spark Session: Execute Spark SQL + Adapter->>HC Livy API: POST /highConcurrencySessions/{id}/repls/{replId}/statements + HC Livy API->>Spark Session: Execute Spark SQL (in REPL) loop Poll every 3 seconds - Adapter->>Livy API: GET /statements/{id} - Livy API-->>Adapter: Status + results (when done) + Adapter->>HC 
Livy API: GET /highConcurrencySessions/{id}/repls/{replId}/statements/{stmtId} + HC Livy API-->>Adapter: Status + results (when done) end Adapter-->>dbt: Parsed results ``` Key technical details: -- **Session reuse** -- All statements in a dbt run share the same Livy session (named `dbt-fabric-samdebruyn` by default). This avoids the overhead of creating a new Spark session for each model. -- **Session TTL** -- Sessions are created with a TTL of 30 seconds. If the session is idle for longer than that after the dbt run finishes, Fabric will automatically clean it up. +- **One REPL per thread** -- Each dbt thread acquires its own REPL inside a shared underlying Livy session. Statements from different REPLs execute in parallel. +- **Deterministic session tag** -- The adapter computes a session tag from `(workspace_id, lakehouse_id)`. Fabric packs all REPLs with the same tag onto one underlying Livy session, enabling warm session reuse across dbt invocations. - **Polling interval** -- The adapter polls for statement completion every 3 seconds. - **Rate limiting** -- The Fabric Livy API enforces rate limits. The adapter handles HTTP 429 responses automatically using the `Retry-After` header. - **DB-API 2.0 cursor** -- Results are returned as JSON and parsed into a [PEP 249](https://peps.python.org/pep-0249/) compatible cursor, so dbt interacts with the Lakehouse the same way it interacts with any other database. @@ -125,17 +124,45 @@ SELECT [my column] FROM [my_schema].[my_table] --- +## High-concurrency Livy + +The adapter uses Fabric's [high-concurrency Livy API](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy?WT.mc_id=MVP_310840). Each dbt thread acquires its own HC session -- and therefore its own REPL -- inside a single underlying Livy session shared via a deterministic `sessionTag` derived from `(workspace_id, lakehouse_id)`. 
Statements from different REPLs execute in **parallel** inside the same Spark application, so increasing `threads` in your profile directly increases throughput. + +### Session reuse across runs + +The session tag is deterministic: every dbt invocation targeting the same workspace + lakehouse produces the same tag. Fabric snap-attaches new REPLs onto the still-warm underlying Livy session, skipping the Spark cold-start entirely on subsequent runs. + +### `threads > 5` + +Fabric packs up to **5 REPLs onto one underlying Livy session** (see the [HC Livy key concepts](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy?WT.mc_id=MVP_310840#key-concepts)). With `threads > 5`, dbt still works correctly -- Fabric spins up a second underlying Livy session to host the 6th REPL onwards. + +| Property | Shared across underlying sessions? | +| --- | --- | +| OneLake Delta tables (dbt model outputs) | Yes -- same lakehouse storage | +| Catalog / metastore (`SELECT FROM `) | Yes -- same Fabric catalog | +| Temp views (`CREATE TEMPORARY VIEW ...`) | No -- REPL/session-local | +| Session-level Spark configs (`SET spark.sql.X = ...`) | No | +| Cached datasets / UDFs / broadcast vars | No | + +Because dbt-fabricspark materializations always write permanent Delta / lake view objects, model-to-model `ref`s resolve correctly regardless of which underlying session produced or consumes the table. + +!!! note "Cost tradeoff" + + Each additional underlying Livy session is a separate Spark cluster billed for the duration of the run plus the idle timeout. Keep `threads ≤ 5` for the cheapest profile; raise it only when the extra parallelism beats the extra compute spend. + +--- + ## Performance considerations The Livy API architecture has inherent performance characteristics that are important to understand. ### Session startup -Creating a new Spark session can take **1-5 minutes**. 
The adapter reuses sessions within a run, so this overhead is paid once per `dbt run`. Subsequent runs may reuse an existing session if it is still alive. +Creating a new Spark session takes some time. The adapter reuses sessions within a run, so this overhead is paid once per `dbt run`. Subsequent runs may reuse an existing session if it is still alive. The [high-concurrency Livy](#high-concurrency-livy) session tag is deterministic, so subsequent runs can skip startup entirely by reattaching to a warm session. ### Statement execution -Each SQL statement involves multiple HTTP API calls (submit + poll). This is inherently slower than a direct database connection like the TDS protocol used by the Data Warehouse adapter. +Each SQL statement involves multiple HTTP API calls (submit + poll). This is inherently slower than a direct database connection like the TDS protocol used by the Data Warehouse adapter. Statements from different threads execute in parallel via [high-concurrency Livy](#high-concurrency-livy), significantly improving wall-clock time for multi-model runs. ### Polling overhead @@ -147,7 +174,7 @@ Fabric applies rate limits to the Livy API. The adapter handles HTTP 429 respons ### Practical impact -A dbt run with many models will be significantly slower on FabricSpark than on Fabric Data Warehouse. This is inherent to the Livy API architecture, not a limitation of the adapter. +A dbt run with many models will be significantly slower on FabricSpark than on Fabric Data Warehouse. This is inherent to the Livy API architecture, not a limitation of the adapter. [High-concurrency Livy](#high-concurrency-livy) reduces this gap by running statements in parallel. ### Recommendations @@ -194,7 +221,7 @@ See the [Python models guide](python-models.md) for writing and debugging Python - **No Spark SQL views** -- only tables and materialized lake views (Fabric lake views) are supported. 
- **No incremental merge strategy** -- the Spark SQL `MERGE` syntax in Fabric Lakehouse is not supported by the adapter. Use `append` or `insert_overwrite` instead. - **API rate limiting** -- can slow down large runs with many models. -- **Session startup time** -- 1-5 minutes for the first statement in a run. +- **Session startup time** -- creating a new Spark session adds latency to the first statement in a run. - **Data Warehouse-only features** -- [CLUSTER BY](cluster-by.md), [warehouse snapshots](warehouse-snapshots.md), and [catalog statistics](catalog-stats.md) are not available for Lakehouse. --- diff --git a/src/dbt/adapters/fabric/base_fabric_adapter.py b/src/dbt/adapters/fabric/base_fabric_adapter.py index ca6c7685..e97d8567 100644 --- a/src/dbt/adapters/fabric/base_fabric_adapter.py +++ b/src/dbt/adapters/fabric/base_fabric_adapter.py @@ -6,7 +6,7 @@ from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.fabric.fabric_livy_helper import FabricLivyHelper -from dbt.adapters.fabric.fabric_livy_session import LivySubmissionResult +from dbt.adapters.fabric.livy_result import LivySubmissionResult from dbt.adapters.fabric.purview_sync import PurviewSync, extract_syncable_models from dbt.adapters.sql.impl import SQLAdapter diff --git a/src/dbt/adapters/fabric/fabric_api_client.py b/src/dbt/adapters/fabric/fabric_api_client.py index 0af7c722..9b90c78d 100644 --- a/src/dbt/adapters/fabric/fabric_api_client.py +++ b/src/dbt/adapters/fabric/fabric_api_client.py @@ -1,5 +1,4 @@ import logging -import threading import time import urllib.parse from typing import Any, Self @@ -12,8 +11,6 @@ logger = logging.getLogger(__name__) -_livy_session_thread_lock = threading.Lock() - class FabricApiError(dbt_common.exceptions.DbtRuntimeError): def __init__(self, method: str, url: str, status_code: int, response_text: str) -> None: @@ -39,7 +36,6 @@ def __init__( self._workspace_id: str | None = None 
self._cached_warehouses: list[dict] | None = None self._cached_lakehouses: list[dict] | None = None - self._livy_session_id: str | None = None self._warehouse_snapshot_operations: dict[str, str] = {} @classmethod @@ -418,123 +414,81 @@ def get_livy_base_api_uri(self) -> str: f"/lakehouses/{lakehouse_id}/livyapi/versions/{self._LIVY_API_VERSION}" ) - def get_existing_livy_session(self) -> str | None: - """Find an active Livy session matching the configured name, or return None.""" - url = self.get_livy_base_api_uri() + "/sessions" - response = self._api_get(url) - sessions = response.json().get("items", []) - for session in sessions: - if session["name"] == self._credentials.livy_session_name and session["livyState"] in ( - "idle", - "starting", - "running", - "busy", - ): - return session["id"] - return None - - def initialize_livy_session(self) -> str: - """Create a new Livy session and wait briefly for it to start.""" - url = self.get_livy_base_api_uri() + "/sessions" - body = {"name": self._credentials.livy_session_name, "ttl": "30s"} - - max_attempts = 3 - backoff_seconds = 5 - last_exception: Exception | None = None - - for attempt in range(1, max_attempts + 1): - try: - response = self._api_post(url, body) - time.sleep(10) - return response.json()["id"] - except FabricApiError as e: - is_transient = e.status_code == 404 or 500 <= e.status_code < 600 - - if not is_transient or attempt == max_attempts: - raise - - last_exception = e - wait_time = backoff_seconds * (2 ** (attempt - 1)) - logger.warning( - f"Livy session creation returned a transient error " - f"(attempt {attempt}/{max_attempts}), retrying in {wait_time}s: {e}" - ) - time.sleep(wait_time) + def acquire_hc_session(self, session_tag: str) -> dict[str, Any]: + """POST /highConcurrencySessions to acquire an HC session (= one REPL). 
- assert last_exception is not None - raise last_exception - - def get_livy_session_id(self) -> str: - """Return the active Livy session ID, reusing an existing session or creating one. + Args: + session_tag: Deterministic tag so Fabric packs all REPLs from + the same process onto one underlying Livy session. - Thread-safe: uses a lock to prevent multiple sessions from being created - concurrently when dbt runs with multiple threads. + Returns: + The JSON response body containing at least ``id`` and ``state``. """ - if self._livy_session_id is None: - with _livy_session_thread_lock: - self._livy_session_id = ( - self.get_existing_livy_session() or self.initialize_livy_session() - ) - return self._livy_session_id - - def get_livy_session_base_uri(self) -> str: - """Build the API URI for the current Livy session.""" - return self.get_livy_base_api_uri() + f"/sessions/{self.get_livy_session_id()}" - - def get_livy_session_state(self) -> str: - """Query the current state of the Livy session (idle, busy, starting, etc.).""" - response = self._api_get(self.get_livy_session_base_uri()) - return response.json().get("state", "unknown") + url = self.get_livy_base_api_uri() + "/highConcurrencySessions" + body: dict[str, Any] = { + "sessionTag": session_tag, + "name": self._credentials.livy_session_name, + } + response = self._api_post(url, body) + return response.json() - def get_livy_statement(self, statement_id: int) -> dict[str, Any]: - """Fetch the current status and output of a Livy statement. + def get_hc_session(self, hc_id: str) -> dict[str, Any]: + """Poll the state of an HC session. - Args: - statement_id: The statement ID returned by a submit call. + Returns: + JSON with ``state``, and when idle also ``sessionId`` and ``replId``. 
""" - url = self.get_livy_session_base_uri() + f"/statements/{statement_id}" + url = self.get_livy_base_api_uri() + f"/highConcurrencySessions/{hc_id}" response = self._api_get(url) return response.json() - def submit_livy_python_statement(self, code: str) -> int: - """Submit Python code to the Livy session and return the statement ID. + def submit_hc_sql_statement(self, livy_session_id: str, repl_id: str, code: str) -> int: + """Submit a SQL statement via an HC REPL. Returns the statement ID.""" + url = ( + self.get_livy_base_api_uri() + + f"/highConcurrencySessions/{livy_session_id}" + + f"/repls/{repl_id}/statements" + ) + response = self._api_post(url, {"code": code, "kind": "sql"}) + return response.json()["id"] - Args: - code: The Python/PySpark code to execute. - """ - url = self.get_livy_session_base_uri() + "/statements" + def submit_hc_python_statement(self, livy_session_id: str, repl_id: str, code: str) -> int: + """Submit a Python statement via an HC REPL. Returns the statement ID.""" + url = ( + self.get_livy_base_api_uri() + + f"/highConcurrencySessions/{livy_session_id}" + + f"/repls/{repl_id}/statements" + ) response = self._api_post(url, {"code": code, "kind": "pyspark"}) return response.json()["id"] - def submit_livy_sql_statement(self, code: str) -> int: - """Submit SQL code to the Livy session and return the statement ID. + def get_hc_statement( + self, livy_session_id: str, repl_id: str, statement_id: int + ) -> dict[str, Any]: + """Fetch the status and output of an HC REPL statement.""" + url = ( + self.get_livy_base_api_uri() + + f"/highConcurrencySessions/{livy_session_id}" + + f"/repls/{repl_id}/statements/{statement_id}" + ) + response = self._api_get(url) + return response.json() - Args: - code: The Spark SQL code to execute. 
- """ - url = self.get_livy_session_base_uri() + "/statements" - response = self._api_post(url, {"code": code, "kind": "sql"}) - return response.json()["id"] + def cancel_hc_statement(self, livy_session_id: str, repl_id: str, statement_id: int) -> str: + """Cancel a running HC REPL statement.""" + url = ( + self.get_livy_base_api_uri() + + f"/highConcurrencySessions/{livy_session_id}" + + f"/repls/{repl_id}/statements/{statement_id}/cancel" + ) + response = self._api_post(url, {}) + return response.json()["msg"] - def delete_livy_session(self) -> None: - """Delete the current Livy session and clear the cached session ID.""" - if self._livy_session_id is None: - return - session_id = self._livy_session_id - url = self.get_livy_base_api_uri() + f"/sessions/{session_id}" + def delete_hc_session(self, hc_id: str) -> None: + """Release an HC session (REPL slot). Best-effort; ignores 404.""" + url = self.get_livy_base_api_uri() + f"/highConcurrencySessions/{hc_id}" try: self._api_delete(url) except FabricApiError as e: if e.status_code != 404: raise - self._livy_session_id = None - - def cancel_livy_statement(self, statement_id: int) -> str: - """Cancel a running Livy statement. - - Args: - statement_id: The statement ID to cancel. 
- """ - url = self.get_livy_session_base_uri() + f"/statements/{statement_id}/cancel" - response = self._api_post(url, {}) - return response.json()["msg"] diff --git a/src/dbt/adapters/fabric/fabric_hc_livy_session.py b/src/dbt/adapters/fabric/fabric_hc_livy_session.py new file mode 100644 index 00000000..7e716c67 --- /dev/null +++ b/src/dbt/adapters/fabric/fabric_hc_livy_session.py @@ -0,0 +1,289 @@ +import contextlib +import hashlib +import json +import time +from dataclasses import dataclass +from typing import Any + +import requests + +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.fabric.fabric_api_client import FabricApiClient, FabricApiError +from dbt.adapters.fabric.livy_result import LivySessionResult + +logger = AdapterLogger("fabricspark") + +_TERMINAL_BAD_STATES = frozenset({"Dead", "Killed", "Failed", "Error"}) +_TRANSIENT_EXCEPTIONS = ( + FabricApiError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ChunkedEncodingError, + json.JSONDecodeError, +) + + +def derive_session_tag(workspace_id: str, lakehouse_id: str) -> str: + """Deterministic session tag from (workspace_id, lakehouse_id). + + All dbt threads in the same process produce the same tag, so Fabric packs + their REPLs onto one underlying Livy session. Successive dbt invocations + targeting the same workspace + lakehouse also produce the same tag, letting + Fabric snap-attach new REPLs onto the still-warm session. + """ + material = f"{workspace_id}|{lakehouse_id}" + digest = hashlib.sha256(material.encode("utf-8")).hexdigest()[:24] + return f"dbt-fabricspark-{digest}" + + +@dataclass +class HCSessionState: + hc_id: str | None = None + session_id: str | None = None + repl_id: str | None = None + is_dead: bool = False + + +class HighConcurrencyLivySession: + """One HC REPL per dbt thread. 
+ + Acquires an HC session via ``POST /highConcurrencySessions``, polls until + the underlying Livy session is idle and a REPL is allocated, then submits + statements through the REPL endpoint. + + ``close()`` DELETEs this instance's HC session (REPL slot) only — the + underlying Spark session is managed by Fabric and stays alive for other + REPLs and processes. + """ + + _POLLING_INTERVAL = 3 + _MAX_CONSECUTIVE_TRANSIENT_ERRORS = 5 + _TERMINAL_STATEMENT_STATES = frozenset({"available", "error", "cancelled", "cancelling"}) + + def __init__(self, fabric_api_client: FabricApiClient) -> None: + self._fabric_api_client = fabric_api_client + self._state = HCSessionState() + self._session_tag: str | None = None + + def _get_session_tag(self) -> str: + if self._session_tag is None: + workspace_id = self._fabric_api_client.get_workspace_id() + lakehouse_id = self._fabric_api_client.get_lakehouse_id() + self._session_tag = derive_session_tag(workspace_id, lakehouse_id) + return self._session_tag + + def get_logs_url(self) -> str: + """Build the Fabric Portal URL to the Spark monitor logs for this session.""" + api_uri = self._fabric_api_client._credentials.fabric_base_api_uri + portal_host = api_uri.replace("://api.", "://app.").split("/v")[0] + lakehouse_id = self._fabric_api_client.get_lakehouse_id() + session_id = self._state.session_id or "unknown" + return f"{portal_host}/workloads/de-ds/sparkmonitor/{lakehouse_id}/{session_id}" + + # ---- acquire ----------------------------------------------------------- + + def wait_for_session_ready(self) -> None: + """Acquire an HC session and poll until the REPL is ready.""" + tag = self._get_session_tag() + logger.debug(f"Acquiring HC session (sessionTag={tag})") + + max_attempts = 3 + backoff_seconds = 5 + last_exception: Exception | None = None + + for attempt in range(1, max_attempts + 1): + try: + body = self._fabric_api_client.acquire_hc_session(tag) + break + except _TRANSIENT_EXCEPTIONS as e: + is_api_error = 
isinstance(e, FabricApiError) + if is_api_error and not (e.status_code == 404 or 500 <= e.status_code < 600): + raise + if attempt == max_attempts: + raise + last_exception = e + wait_time = backoff_seconds * (2 ** (attempt - 1)) + logger.warning( + f"HC session acquire returned a transient error " + f"(attempt {attempt}/{max_attempts}), retrying in {wait_time}s: {e}" + ) + time.sleep(wait_time) + else: + assert last_exception is not None + raise last_exception + + hc_id = body.get("id") + if not hc_id: + raise RuntimeError(f"HC acquire response missing 'id': {body}") + + self._state.hc_id = str(hc_id) + try: + self._poll_until_idle() + except Exception: + with contextlib.suppress(Exception): + self._fabric_api_client.delete_hc_session(str(hc_id)) + self._state = HCSessionState() + raise + self._state.is_dead = False + logger.debug( + f"HC session ready: hc_id={self._state.hc_id} " + f"sessionId={self._state.session_id} replId={self._state.repl_id}" + ) + + def _poll_until_idle(self) -> None: + start_time = time.time() + timeout = self._fabric_api_client._credentials.spark_session_timeout + consecutive_errors = 0 + + while True: + if time.time() - start_time >= timeout: + raise TimeoutError( + f"Timeout ({timeout}s) waiting for HC session {self._state.hc_id} " + f"to become Idle. Increase `spark_session_timeout` in profiles.yml." 
+ ) + + try: + body = self._fabric_api_client.get_hc_session(self._state.hc_id) + consecutive_errors = 0 + except _TRANSIENT_EXCEPTIONS as e: + consecutive_errors += 1 + if consecutive_errors >= self._MAX_CONSECUTIVE_TRANSIENT_ERRORS: + raise + logger.warning( + f"Transient error polling HC session {self._state.hc_id} " + f"({consecutive_errors}/{self._MAX_CONSECUTIVE_TRANSIENT_ERRORS}): {e}" + ) + time.sleep(self._POLLING_INTERVAL) + continue + + state = body.get("state", "") + + if state in _TERMINAL_BAD_STATES: + err = body.get("fabricSessionStateInfo", {}).get("errorMessage") or state + raise RuntimeError(f"HC session {self._state.hc_id} state={state}: {err}") + + if state == "Idle" and body.get("sessionId") and body.get("replId"): + self._state.session_id = str(body["sessionId"]) + self._state.repl_id = str(body["replId"]) + return + + time.sleep(self._POLLING_INTERVAL) + + def _ensure_repl(self) -> None: + """Re-acquire this thread's HC session if it was marked dead.""" + if self._state.is_dead or self._state.hc_id is None: + logger.debug("HC REPL marked stale — re-acquiring") + if self._state.hc_id is not None: + with contextlib.suppress(Exception): + self._fabric_api_client.delete_hc_session(self._state.hc_id) + self._state = HCSessionState() + self.wait_for_session_ready() + + def cancel_statement(self, statement_id: int) -> None: + """Cancel a running statement via the HC REPL endpoint.""" + assert self._state.session_id is not None + assert self._state.repl_id is not None + self._fabric_api_client.cancel_hc_statement( + self._state.session_id, self._state.repl_id, statement_id + ) + + # ---- statement execution ----------------------------------------------- + + def run_statement( + self, statement_code: str, statement_language: str, wait_for_result: bool = True + ) -> LivySessionResult | int: + """Submit a statement and optionally wait for its result. + + Same interface as ``LivySession.run_statement``. 
+ """ + self._ensure_repl() + assert self._state.session_id is not None + assert self._state.repl_id is not None + + try: + if statement_language == "sql": + statement_id = self._fabric_api_client.submit_hc_sql_statement( + self._state.session_id, self._state.repl_id, statement_code + ) + else: + statement_id = self._fabric_api_client.submit_hc_python_statement( + self._state.session_id, self._state.repl_id, statement_code + ) + except FabricApiError as e: + if e.status_code == 404: + self._state.is_dead = True + logger.debug("HC statement submit returned 404 — flagging REPL for re-acquire") + return LivySessionResult(success=False, error_message=str(e)) + + if wait_for_result: + return self.wait_and_get_statement_result(statement_id) + else: + return statement_id + + def wait_for_statement_ready(self, statement_id: int) -> dict[str, Any]: + """Poll an HC REPL statement until it reaches a terminal state.""" + assert self._state.session_id is not None + assert self._state.repl_id is not None + + start_time = time.time() + while True: + response = self._fabric_api_client.get_hc_statement( + self._state.session_id, self._state.repl_id, statement_id + ) + statement_state = response.get("state", "unknown") + if statement_state in self._TERMINAL_STATEMENT_STATES: + return response + if time.time() - start_time >= self._fabric_api_client._credentials.query_timeout: + raise TimeoutError("HC Livy statement did not become available in time.") + time.sleep(self._POLLING_INTERVAL) + + def wait_and_get_statement_result(self, statement_id: int) -> LivySessionResult: + """Wait for a statement to complete and return its result.""" + try: + response = self.wait_for_statement_ready(statement_id) + output = response.get("output", {}) + success = response["state"] == "available" and output.get("status") == "ok" + error_message = output.get("evalue") + if not success and not error_message: + error_message = f"Statement ended with state '{response.get('state')}'" + return 
LivySessionResult( + statement_id=statement_id, + success=success, + error_message=error_message, + status_code=output.get("status"), + json_data=output.get("data", {}).get("application/json", {}), + ) + except FabricApiError as e: + if e.status_code == 404: + self._state.is_dead = True + logger.debug("HC statement poll returned 404 — flagging REPL for re-acquire") + logger.error( + f"Error while waiting for HC statement to be ready. " + f"Logs URL: {self.get_logs_url()}" + ) + logger.exception(e) + return LivySessionResult( + statement_id=statement_id, success=False, error_message=str(e) + ) + except Exception as e: + logger.error( + f"Error while waiting for HC statement to be ready. " + f"Logs URL: {self.get_logs_url()}" + ) + logger.exception(e) + return LivySessionResult( + statement_id=statement_id, success=False, error_message=str(e) + ) + + # ---- cleanup ----------------------------------------------------------- + + def close(self) -> None: + """Release the HC session, freeing the REPL slot.""" + if self._state.hc_id is not None: + try: + self._fabric_api_client.delete_hc_session(self._state.hc_id) + logger.debug(f"Released HC session {self._state.hc_id}") + except Exception as ex: + logger.warning(f"Failed to delete HC session {self._state.hc_id}: {ex}") + finally: + self._state = HCSessionState() diff --git a/src/dbt/adapters/fabric/fabric_livy_helper.py b/src/dbt/adapters/fabric/fabric_livy_helper.py index 96fd3b53..fee5fbf1 100644 --- a/src/dbt/adapters/fabric/fabric_livy_helper.py +++ b/src/dbt/adapters/fabric/fabric_livy_helper.py @@ -1,3 +1,4 @@ +import threading from typing import Any from dbt_common.exceptions import DbtRuntimeError @@ -5,12 +6,14 @@ from dbt.adapters.base.impl import PythonJobHelper from dbt.adapters.fabric.fabric_api_client import FabricApiClient from dbt.adapters.fabric.fabric_credentials import FabricCredentials -from dbt.adapters.fabric.fabric_livy_session import LivySession, LivySessionResult +from 
dbt.adapters.fabric.fabric_hc_livy_session import HighConcurrencyLivySession from dbt.adapters.fabric.fabric_token_provider import FabricTokenProvider +from dbt.adapters.fabric.livy_result import LivySessionResult + +_thread_local = threading.local() class FabricLivyHelper(PythonJobHelper): - _livy_session: LivySession | None = None _sql_endpoint: str | None = None def __init__(self, parsed_model: dict, credential: FabricCredentials) -> None: @@ -18,22 +21,22 @@ def __init__(self, parsed_model: dict, credential: FabricCredentials) -> None: credential, FabricTokenProvider(credential) ) - if not self._livy_session: - self._livy_session = LivySession(fabric_api_client) + if not getattr(_thread_local, "livy_session", None): + _thread_local.livy_session = HighConcurrencyLivySession(fabric_api_client) if not self._sql_endpoint: self._sql_endpoint = fabric_api_client.get_warehouse_connection_string() def submit(self, compiled_code: str) -> Any: - assert self._livy_session is not None + livy_session: HighConcurrencyLivySession = _thread_local.livy_session assert self._sql_endpoint is not None compiled_code = compiled_code.replace("DBT_FABRIC_REPLACED_WITH_HOST", self._sql_endpoint) - result = self._livy_session.run_statement(compiled_code, "python") + result = livy_session.run_statement(compiled_code, "python") assert isinstance(result, LivySessionResult) if not result.success: raise DbtRuntimeError( f"Python statement execution failed. " - f"Logs URL: {self._livy_session.get_logs_url()}. " + f"Logs URL: {livy_session.get_logs_url()}. 
" f"Error: {result.error_message}" ) return result.to_submission_result(compiled_code) diff --git a/src/dbt/adapters/fabric/fabric_livy_session.py b/src/dbt/adapters/fabric/fabric_livy_session.py deleted file mode 100644 index ff926392..00000000 --- a/src/dbt/adapters/fabric/fabric_livy_session.py +++ /dev/null @@ -1,198 +0,0 @@ -import json -import time -from dataclasses import dataclass, field -from typing import Any - -import requests - -from dbt.adapters.base.impl import PythonSubmissionResult -from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.fabric.fabric_api_client import FabricApiClient - -logger = AdapterLogger("fabricspark") - - -@dataclass -class LivySubmissionResult(PythonSubmissionResult): - success: bool - error_message: str | None = None - - -@dataclass -class LivySessionResult: - statement_id: int = -1 - success: bool = False - error_message: str | None = None - status_code: str | None = None - json_data: dict[str, Any] | None = field(default_factory=dict) - - def to_submission_result(self, code: str) -> LivySubmissionResult: - """Convert this result to a LivySubmissionResult for the dbt adapter response. - - Args: - code: The compiled Python code that was submitted. 
- """ - return LivySubmissionResult( - run_id=str(self.statement_id), - compiled_code=code, - success=self.success, - error_message=self.error_message, - ) - - -class LivySession: - _POLLING_INTERVAL = 3 # seconds - _MAX_CONSECUTIVE_TRANSIENT_ERRORS = 5 - _FATAL_SESSION_STATES = frozenset({"dead", "killed", "error", "shutting_down"}) - _TERMINAL_STATEMENT_STATES = frozenset({"available", "error", "cancelled", "cancelling"}) - - def __init__(self, fabric_api_client: FabricApiClient) -> None: - self._fabric_api_client = fabric_api_client - - def get_logs_url(self) -> str: - """Build the Fabric Portal URL to the Spark monitor logs for this session.""" - api_uri = self._fabric_api_client._credentials.fabric_base_api_uri - portal_host = api_uri.replace("://api.", "://app.").split("/v")[0] - lakehouse_id = self._fabric_api_client.get_lakehouse_id() - session_id = self._fabric_api_client.get_livy_session_id() - return f"{portal_host}/workloads/de-ds/sparkmonitor/{lakehouse_id}/{session_id}" - - def wait_for_session_ready(self) -> None: - """Poll until the Livy session reaches the idle state. - - Raises: - TimeoutError: If the session does not become idle within - the configured ``spark_session_timeout``. 
- """ - start_time = time.time() - consecutive_errors = 0 - - while True: - try: - state = self._fabric_api_client.get_livy_session_state() - consecutive_errors = 0 - except ( - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.ChunkedEncodingError, - json.JSONDecodeError, - ) as e: - consecutive_errors += 1 - if consecutive_errors >= self._MAX_CONSECUTIVE_TRANSIENT_ERRORS: - raise - logger.warning( - f"Transient error polling Livy session state " - f"({consecutive_errors}/{self._MAX_CONSECUTIVE_TRANSIENT_ERRORS}): {e}" - ) - time.sleep(self._POLLING_INTERVAL) - continue - - if state == "idle": - return - - if state in self._FATAL_SESSION_STATES: - raise RuntimeError( - f"Livy session entered fatal state '{state}' and cannot recover." - ) - - if ( - time.time() - start_time - >= self._fabric_api_client._credentials.spark_session_timeout - ): - raise TimeoutError("Livy session did not become idle in time.") - time.sleep(self._POLLING_INTERVAL) - - def wait_for_statement_ready(self, statement_id: int) -> dict[str, Any]: - """Poll a Livy statement until it reaches a terminal state. - - Args: - statement_id: The statement ID to poll. - - Raises: - TimeoutError: If the statement does not complete within - the configured ``query_timeout``. - """ - start_time = time.time() - while True: - statement_response = self._fabric_api_client.get_livy_statement(statement_id) - statement_state = statement_response.get("state", "unknown") - if statement_state in self._TERMINAL_STATEMENT_STATES: - return statement_response - if time.time() - start_time >= self._fabric_api_client._credentials.query_timeout: - raise TimeoutError("Livy statement did not become available in time.") - time.sleep(self._POLLING_INTERVAL) - - def wait_and_get_statement_result(self, statement_id: int) -> LivySessionResult: - """Wait for a statement to complete and return its result. 
- - Unlike ``wait_for_statement_ready``, this method catches all exceptions - and returns a failed ``LivySessionResult`` instead of raising. - - Args: - statement_id: The statement ID to wait for. - """ - try: - response = self.wait_for_statement_ready(statement_id) - output = response.get("output", {}) - success = response["state"] == "available" and output.get("status") == "ok" - error_message = output.get("evalue") - if not success and not error_message: - error_message = f"Statement ended with state '{response.get('state')}'" - return LivySessionResult( - statement_id=statement_id, - success=success, - error_message=error_message, - status_code=output.get("status"), - json_data=output.get("data", {}).get("application/json", {}), - ) - except TimeoutError as e: - logger.error( - f"Timeout (> {self._fabric_api_client._credentials.query_timeout}s) while waiting " - f"for Livy statement to be ready. Logs URL: {self.get_logs_url()}" - ) - logger.exception(e) - return LivySessionResult( - statement_id=statement_id, success=False, error_message=str(e) - ) - except Exception as e: - logger.error( - f"Error while waiting for Livy statement to be ready. " - f"Logs URL: {self.get_logs_url()}" - ) - logger.exception(e) - return LivySessionResult( - statement_id=statement_id, success=False, error_message=str(e) - ) - - def run_statement( - self, statement_code: str, statement_language: str, wait_for_result: bool = True - ) -> LivySessionResult | int: - """Submit a Python or SQL statement and optionally wait for its result. - - Waits for the session to be idle before submitting. If submission fails, - returns a failed ``LivySessionResult`` instead of raising. - - Args: - statement_code: The code to execute. - statement_language: Either ``"sql"`` or ``"python"``. - wait_for_result: If True, block until the statement completes and - return a ``LivySessionResult``. If False, return the statement ID. 
- """ - try: - self.wait_for_session_ready() - func = ( - self._fabric_api_client.submit_livy_sql_statement - if statement_language == "sql" - else self._fabric_api_client.submit_livy_python_statement - ) - statement_id = func(statement_code) - except Exception as e: - logger.error( - f"Error while creating for Livy statement. Logs URL: {self.get_logs_url()}" - ) - logger.exception(e) - return LivySessionResult(success=False, error_message=str(e)) - if wait_for_result: - return self.wait_and_get_statement_result(statement_id) - else: - return statement_id diff --git a/src/dbt/adapters/fabric/livy_result.py b/src/dbt/adapters/fabric/livy_result.py new file mode 100644 index 00000000..91e37129 --- /dev/null +++ b/src/dbt/adapters/fabric/livy_result.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass, field +from typing import Any + +from dbt.adapters.base.impl import PythonSubmissionResult + + +@dataclass +class LivySubmissionResult(PythonSubmissionResult): + success: bool + error_message: str | None = None + + +@dataclass +class LivySessionResult: + statement_id: int = -1 + success: bool = False + error_message: str | None = None + status_code: str | None = None + json_data: dict[str, Any] | None = field(default_factory=dict) + + def to_submission_result(self, code: str) -> LivySubmissionResult: + return LivySubmissionResult( + run_id=str(self.statement_id), + compiled_code=code, + success=self.success, + error_message=self.error_message, + ) diff --git a/src/dbt/adapters/fabricspark/fabricspark_connection.py b/src/dbt/adapters/fabricspark/fabricspark_connection.py index 2e2dac2f..6ca1ecd8 100644 --- a/src/dbt/adapters/fabricspark/fabricspark_connection.py +++ b/src/dbt/adapters/fabricspark/fabricspark_connection.py @@ -1,23 +1,30 @@ +from __future__ import annotations + import weakref +from typing import TYPE_CHECKING from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.fabric.fabric_livy_session import LivySession from 
dbt.adapters.fabricspark.fabricspark_cursor import FabricSparkCursor +if TYPE_CHECKING: + from dbt.adapters.fabric.fabric_hc_livy_session import HighConcurrencyLivySession + logger = AdapterLogger("fabricspark") class FabricSparkConnection: """A DB-API 2.0 (PEP 249) compatible connection for Fabric Spark.""" - def __init__(self, livy_session: LivySession) -> None: - self._livy_session: LivySession | None = livy_session + def __init__(self, livy_session: HighConcurrencyLivySession) -> None: + self._livy_session: HighConcurrencyLivySession | None = livy_session self._cursors: weakref.WeakSet[FabricSparkCursor] = weakref.WeakSet() def close(self) -> None: for cursor in self._cursors: cursor.close() self._cursors.clear() + if self._livy_session is not None: + self._livy_session.close() self._livy_session = None def cancel(self) -> None: @@ -32,6 +39,6 @@ def cursor(self) -> FabricSparkCursor: self._cursors.add(cursor) return cursor - def get_livy_session(self) -> LivySession: + def get_livy_session(self) -> HighConcurrencyLivySession: assert self._livy_session is not None, "Connection is closed" return self._livy_session diff --git a/src/dbt/adapters/fabricspark/fabricspark_connection_manager.py b/src/dbt/adapters/fabricspark/fabricspark_connection_manager.py index 434eb341..a6c6576f 100644 --- a/src/dbt/adapters/fabricspark/fabricspark_connection_manager.py +++ b/src/dbt/adapters/fabricspark/fabricspark_connection_manager.py @@ -4,7 +4,7 @@ from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.fabric.base_connection_manager import BaseFabricConnectionManager -from dbt.adapters.fabric.fabric_livy_session import LivySession +from dbt.adapters.fabric.fabric_hc_livy_session import HighConcurrencyLivySession from dbt.adapters.fabricspark.fabricspark_connection import FabricSparkConnection logger = AdapterLogger("fabricspark") @@ -48,8 +48,11 @@ def open(cls, 
connection: Connection) -> Connection: logger.debug("Connection is already open, skipping open.") return connection + credentials = connection.credentials + def connect() -> FabricSparkConnection: - livy_session = LivySession(cls.get_fabric_api_client(connection.credentials)) + api_client = cls.get_fabric_api_client(credentials) + livy_session = HighConcurrencyLivySession(api_client) livy_session.wait_for_session_ready() return FabricSparkConnection(livy_session) @@ -57,7 +60,7 @@ def connect() -> FabricSparkConnection: connection, connect=connect, logger=logger, - retry_limit=connection.credentials.retries, + retry_limit=credentials.retries, retry_timeout=10, retryable_exceptions=[TimeoutError], ) diff --git a/src/dbt/adapters/fabricspark/fabricspark_cursor.py b/src/dbt/adapters/fabricspark/fabricspark_cursor.py index 992cae15..15430e4f 100644 --- a/src/dbt/adapters/fabricspark/fabricspark_cursor.py +++ b/src/dbt/adapters/fabricspark/fabricspark_cursor.py @@ -6,7 +6,8 @@ from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError -from dbt.adapters.fabric.fabric_livy_session import LivySession, LivySessionResult +from dbt.adapters.fabric.fabric_hc_livy_session import HighConcurrencyLivySession +from dbt.adapters.fabric.livy_result import LivySessionResult class FabricSparkCursor: @@ -36,7 +37,7 @@ def close(self) -> None: self._position = 0 self._statement_id = None - def get_livy_session(self) -> LivySession: + def get_livy_session(self) -> HighConcurrencyLivySession: return self.connection.get_livy_session() def __enter__(self) -> Self: @@ -142,7 +143,7 @@ def cancel(self) -> None: if self._connection is None: return if self._statement_id is not None and self._result is None: - self.get_livy_session()._fabric_api_client.cancel_livy_statement(self._statement_id) + self.get_livy_session().cancel_statement(self._statement_id) self._statement_id = None @property diff --git a/tests/unit/test_fabric_api_client.py b/tests/unit/test_fabric_api_client.py index 
b1c541f4..25927493 100644 --- a/tests/unit/test_fabric_api_client.py +++ b/tests/unit/test_fabric_api_client.py @@ -340,145 +340,6 @@ def test_raises_when_no_match(self, mock_request, client): client.get_warehouse_id() -class TestLivySessionManagement: - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_get_existing_livy_session_matches_by_name_and_state(self, mock_request, client): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response( - 200, - { - "items": [ - {"name": "other-session", "livyState": "idle", "id": "s1"}, - {"name": "test-session", "livyState": "idle", "id": "s2"}, - ] - }, - ) - - result = client.get_existing_livy_session() - assert result == "s2" - - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_get_existing_livy_session_returns_none_when_no_match(self, mock_request, client): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response( - 200, {"items": [{"name": "other-session", "livyState": "idle", "id": "s1"}]} - ) - - result = client.get_existing_livy_session() - assert result is None - - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_get_existing_livy_session_ignores_dead_sessions(self, mock_request, client): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response( - 200, - {"items": [{"name": "test-session", "livyState": "dead", "id": "s1"}]}, - ) - - result = client.get_existing_livy_session() - assert result is None - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_initialize_livy_session_returns_session_id(self, mock_request, mock_sleep, client): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response(200, {"id": "new-session-id"}) - - result = client.initialize_livy_session() - assert result == "new-session-id" - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - 
@patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_initialize_livy_session_retries_on_transient_error( - self, mock_request, mock_sleep, client - ): - client._lakehouse_id = "lh-id" - error_resp = _make_response(500, text="Server Error") - success_resp = _make_response(200, {"id": "new-session-id"}) - mock_request.side_effect = [error_resp, success_resp] - - result = client.initialize_livy_session() - assert result == "new-session-id" - assert mock_request.call_count == 2 - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_initialize_livy_session_retries_on_404(self, mock_request, mock_sleep, client): - client._lakehouse_id = "lh-id" - error_resp = _make_response(404, text="Not Found") - success_resp = _make_response(200, {"id": "new-session-id"}) - mock_request.side_effect = [error_resp, success_resp] - - result = client.initialize_livy_session() - assert result == "new-session-id" - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_initialize_livy_session_raises_on_non_transient_error( - self, mock_request, mock_sleep, client - ): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response(400, text="Bad Request") - - with pytest.raises(FabricApiError) as exc_info: - client.initialize_livy_session() - assert exc_info.value.status_code == 400 - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_initialize_livy_session_raises_after_max_attempts( - self, mock_request, mock_sleep, client - ): - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response(500, text="Server Error") - - with pytest.raises(FabricApiError): - client.initialize_livy_session() - assert mock_request.call_count == 3 - - 
@patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_get_livy_session_id_reuses_cached(self, mock_request, client): - client._livy_session_id = "existing-session" - - result = client.get_livy_session_id() - assert result == "existing-session" - mock_request.assert_not_called() - - @patch("dbt.adapters.fabric.fabric_api_client.time.sleep") - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_get_livy_session_id_creates_when_no_existing(self, mock_request, mock_sleep, client): - client._lakehouse_id = "lh-id" - no_sessions = _make_response(200, {"items": []}) - new_session = _make_response(200, {"id": "new-session-id"}) - mock_request.side_effect = [no_sessions, new_session] - - result = client.get_livy_session_id() - assert result == "new-session-id" - - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_delete_livy_session_clears_cached_id(self, mock_request, client): - client._livy_session_id = "session-to-delete" - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response(200) - - client.delete_livy_session() - - assert client._livy_session_id is None - - @patch("dbt.adapters.fabric.fabric_api_client.requests.request") - def test_delete_livy_session_tolerates_404(self, mock_request, client): - client._livy_session_id = "session-to-delete" - client._lakehouse_id = "lh-id" - mock_request.return_value = _make_response(404, text="Not Found") - - client.delete_livy_session() - assert client._livy_session_id is None - - def test_delete_livy_session_noop_when_no_session(self, client): - client._livy_session_id = None - client.delete_livy_session() - - class TestWarehouseSnapshots: @patch("dbt.adapters.fabric.fabric_api_client.requests.request") def test_get_warehouse_snapshots_filters_by_warehouse(self, mock_request, client): diff --git a/tests/unit/test_fabricspark_cursor.py b/tests/unit/test_fabricspark_cursor.py index 090d125c..955dbd44 100644 --- 
a/tests/unit/test_fabricspark_cursor.py +++ b/tests/unit/test_fabricspark_cursor.py @@ -6,7 +6,7 @@ import pytest from dbt_common.exceptions import DbtDatabaseError -from dbt.adapters.fabric.fabric_livy_session import LivySessionResult +from dbt.adapters.fabric.livy_result import LivySessionResult from dbt.adapters.fabricspark.fabricspark_cursor import FabricSparkCursor SAMPLE_FIELDS = [ @@ -89,7 +89,7 @@ def test_cancel_with_pending_statement(self): cursor.cancel() - livy_session._fabric_api_client.cancel_livy_statement.assert_called_once_with(42) + livy_session.cancel_statement.assert_called_once_with(42) assert cursor._statement_id is None def test_cancel_noop_when_no_statement(self): diff --git a/tests/unit/test_hc_livy_session.py b/tests/unit/test_hc_livy_session.py new file mode 100644 index 00000000..57901666 --- /dev/null +++ b/tests/unit/test_hc_livy_session.py @@ -0,0 +1,359 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.fabric.fabric_api_client import FabricApiError +from dbt.adapters.fabric.fabric_hc_livy_session import ( + HighConcurrencyLivySession, + derive_session_tag, +) + + +@pytest.fixture +def credentials(): + mock = MagicMock() + mock.spark_session_timeout = 60 + mock.query_timeout = 120 + mock.fabric_base_api_uri = "https://api.fabric.microsoft.com/v1" + return mock + + +@pytest.fixture +def api_client(credentials): + client = MagicMock() + client._credentials = credentials + client.get_workspace_id.return_value = "ws-123" + client.get_lakehouse_id.return_value = "lh-456" + return client + + +@pytest.fixture +def session(api_client): + return HighConcurrencyLivySession(api_client) + + +def _ready_session(session): + session._state.hc_id = "hc-1" + session._state.session_id = "sess-1" + session._state.repl_id = "repl-1" + session._state.is_dead = False + + +class TestDeriveSessionTag: + def test_deterministic(self): + tag1 = derive_session_tag("ws-123", "lh-456") + tag2 = derive_session_tag("ws-123", 
"lh-456") + assert tag1 == tag2 + + def test_different_inputs_produce_different_tags(self): + tag1 = derive_session_tag("ws-123", "lh-456") + tag2 = derive_session_tag("ws-123", "lh-789") + assert tag1 != tag2 + + def test_prefix(self): + tag = derive_session_tag("ws-123", "lh-456") + assert tag.startswith("dbt-fabricspark-") + + +class TestGetLogsUrl: + def test_builds_url_with_session_id(self, session): + session._state.session_id = "sess-42" + url = session.get_logs_url() + assert "lh-456" in url + assert "sess-42" in url + assert "app.fabric" in url + + def test_uses_unknown_when_no_session(self, session): + url = session.get_logs_url() + assert "unknown" in url + + +class TestWaitForSessionReady: + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_success_path(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.return_value = {"id": "hc-1"} + api_client.get_hc_session.return_value = { + "state": "Idle", + "sessionId": "sess-1", + "replId": "repl-1", + } + + session.wait_for_session_ready() + + assert session._state.hc_id == "hc-1" + assert session._state.session_id == "sess-1" + assert session._state.repl_id == "repl-1" + assert session._state.is_dead is False + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_retries_on_transient_acquire_error(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.side_effect = [ + FabricApiError("POST", "url", 500, "Server Error"), + {"id": "hc-1"}, + ] + api_client.get_hc_session.return_value = { + "state": "Idle", + "sessionId": "sess-1", + "replId": "repl-1", + } + + session.wait_for_session_ready() + + assert api_client.acquire_hc_session.call_count == 2 + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_raises_non_transient_acquire_error(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.side_effect = FabricApiError( + "POST", "url", 400, "Bad Request" + ) + + with 
pytest.raises(FabricApiError) as exc_info: + session.wait_for_session_ready() + assert exc_info.value.status_code == 400 + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_raises_on_missing_id(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.return_value = {"state": "Starting"} + + with pytest.raises(RuntimeError, match="missing 'id'"): + session.wait_for_session_ready() + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_cleans_up_on_poll_failure(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.return_value = {"id": "hc-leak"} + api_client.get_hc_session.return_value = {"state": "Dead"} + + with pytest.raises(RuntimeError): + session.wait_for_session_ready() + + api_client.delete_hc_session.assert_called_once_with("hc-leak") + assert session._state.hc_id is None + + +class TestPollUntilIdle: + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.time") + def test_polls_until_idle(self, mock_time, mock_sleep, session, api_client): + mock_time.side_effect = [0, 1, 2] + session._state.hc_id = "hc-1" + api_client.get_hc_session.side_effect = [ + {"state": "Starting"}, + {"state": "Idle", "sessionId": "sess-1", "replId": "repl-1"}, + ] + + session._poll_until_idle() + + assert session._state.session_id == "sess-1" + assert session._state.repl_id == "repl-1" + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.time") + def test_raises_on_timeout(self, mock_time, mock_sleep, session, api_client): + mock_time.side_effect = [0, 61] + session._state.hc_id = "hc-1" + + with pytest.raises(TimeoutError, match="spark_session_timeout"): + session._poll_until_idle() + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.time") + def test_raises_on_fatal_state(self, 
mock_time, mock_sleep, session, api_client): + mock_time.return_value = 0 + session._state.hc_id = "hc-1" + api_client.get_hc_session.return_value = { + "state": "Dead", + "fabricSessionStateInfo": {"errorMessage": "OOM"}, + } + + with pytest.raises(RuntimeError, match="OOM"): + session._poll_until_idle() + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.time") + def test_retries_transient_errors(self, mock_time, mock_sleep, session, api_client): + mock_time.side_effect = [0, 1, 2, 3] + session._state.hc_id = "hc-1" + api_client.get_hc_session.side_effect = [ + FabricApiError("GET", "url", 500, "transient"), + {"state": "Idle", "sessionId": "sess-1", "replId": "repl-1"}, + ] + + session._poll_until_idle() + + assert session._state.session_id == "sess-1" + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.time") + def test_raises_after_max_consecutive_transient_errors( + self, mock_time, mock_sleep, session, api_client + ): + mock_time.return_value = 0 + session._state.hc_id = "hc-1" + api_client.get_hc_session.side_effect = FabricApiError("GET", "url", 500, "transient") + + with pytest.raises(FabricApiError): + session._poll_until_idle() + + assert api_client.get_hc_session.call_count == 5 + + +class TestEnsureRepl: + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_noop_when_healthy(self, mock_sleep, session): + _ready_session(session) + session._ensure_repl() + session._fabric_api_client.acquire_hc_session.assert_not_called() + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_reacquires_when_dead(self, mock_sleep, session, api_client): + _ready_session(session) + session._state.is_dead = True + + api_client.acquire_hc_session.return_value = {"id": "hc-new"} + api_client.get_hc_session.return_value = { + "state": "Idle", + "sessionId": "sess-new", + "replId": 
"repl-new", + } + + session._ensure_repl() + + api_client.delete_hc_session.assert_called_once_with("hc-1") + assert session._state.hc_id == "hc-new" + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_acquires_when_no_hc_id(self, mock_sleep, session, api_client): + api_client.acquire_hc_session.return_value = {"id": "hc-first"} + api_client.get_hc_session.return_value = { + "state": "Idle", + "sessionId": "sess-1", + "replId": "repl-1", + } + + session._ensure_repl() + + assert session._state.hc_id == "hc-first" + + +class TestRunStatement: + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_submits_sql(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.submit_hc_sql_statement.return_value = 42 + api_client.get_hc_statement.return_value = { + "state": "available", + "output": {"status": "ok", "data": {"application/json": {"rows": []}}}, + } + + result = session.run_statement("SELECT 1", "sql") + + api_client.submit_hc_sql_statement.assert_called_once_with("sess-1", "repl-1", "SELECT 1") + assert result.success is True + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_submits_python(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.submit_hc_python_statement.return_value = 42 + api_client.get_hc_statement.return_value = { + "state": "available", + "output": {"status": "ok", "data": {"application/json": {}}}, + } + + result = session.run_statement("print(1)", "python") + + api_client.submit_hc_python_statement.assert_called_once_with( + "sess-1", "repl-1", "print(1)" + ) + assert result.success is True + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_returns_statement_id_when_not_waiting(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.submit_hc_sql_statement.return_value = 99 + + result = session.run_statement("SELECT 1", "sql", wait_for_result=False) + + assert result 
== 99 + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_marks_dead_on_404(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.submit_hc_sql_statement.side_effect = FabricApiError( + "POST", "url", 404, "Not Found" + ) + + result = session.run_statement("SELECT 1", "sql") + + assert result.success is False + assert session._state.is_dead is True + + +class TestWaitAndGetStatementResult: + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_success(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.get_hc_statement.return_value = { + "state": "available", + "output": { + "status": "ok", + "data": {"application/json": {"key": "value"}}, + }, + } + + result = session.wait_and_get_statement_result(42) + + assert result.success is True + assert result.statement_id == 42 + assert result.json_data == {"key": "value"} + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_error_statement(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.get_hc_statement.return_value = { + "state": "error", + "output": {"status": "error", "evalue": "division by zero"}, + } + + result = session.wait_and_get_statement_result(42) + + assert result.success is False + assert result.error_message == "division by zero" + + @patch("dbt.adapters.fabric.fabric_hc_livy_session.time.sleep") + def test_marks_dead_on_404(self, mock_sleep, session, api_client): + _ready_session(session) + api_client.get_hc_statement.side_effect = FabricApiError("GET", "url", 404, "Not Found") + + result = session.wait_and_get_statement_result(42) + + assert result.success is False + assert session._state.is_dead is True + + +class TestClose: + def test_deletes_session(self, session, api_client): + _ready_session(session) + session.close() + + api_client.delete_hc_session.assert_called_once_with("hc-1") + assert session._state.hc_id is None + + def 
test_noop_when_no_session(self, session, api_client):
+        session.close()
+        api_client.delete_hc_session.assert_not_called()
+
+    def test_resets_state_even_on_delete_failure(self, session, api_client):
+        _ready_session(session)
+        api_client.delete_hc_session.side_effect = Exception("network error")
+
+        session.close()
+
+        assert session._state.hc_id is None
+
+
+class TestCancelStatement:
+    def test_delegates_to_api_client(self, session, api_client):
+        _ready_session(session)
+        session.cancel_statement(42)
+
+        api_client.cancel_hc_statement.assert_called_once_with("sess-1", "repl-1", 42)
diff --git a/tests/unit/test_livy_session.py b/tests/unit/test_livy_session.py
index b0c6feb2..a6581c0a 100644
--- a/tests/unit/test_livy_session.py
+++ b/tests/unit/test_livy_session.py
@@ -1,428 +1,4 @@
-import itertools
-import json
-from unittest.mock import MagicMock, patch
-
-import pytest
-import requests
-
-from dbt.adapters.fabric.fabric_livy_session import LivySession, LivySessionResult
-
-
-@pytest.fixture
-def credentials():
-    mock = MagicMock()
-    mock.spark_session_timeout = 60
-    mock.query_timeout = 120
-    return mock
-
-
-@pytest.fixture
-def fabric_api_client(credentials):
-    mock = MagicMock()
-    mock._credentials = credentials
-    mock.get_lakehouse_id.return_value = "lakehouse-id"
-    mock.get_livy_session_id.return_value = "session-id"
-    return mock
-
-
-@pytest.fixture
-def session(fabric_api_client):
-    return LivySession(fabric_api_client)
-
-
-class TestGetLogsUrl:
-    @pytest.mark.parametrize(
-        ("api_uri", "expected_host"),
-        [
-            (
-                "https://api.fabric.microsoft.com/v1",
-                "https://app.fabric.microsoft.com",
-            ),
-            (
-                "https://api.msit.fabric.microsoft.com/v1",
-                "https://app.msit.fabric.microsoft.com",
-            ),
-        ],
-    )
-    def test_derives_portal_url_from_base_api_uri(
-        self, api_uri, expected_host, session, credentials
-    ):
-        credentials.fabric_base_api_uri = api_uri
-
-        url = session.get_logs_url()
-
-        assert url == f"{expected_host}/workloads/de-ds/sparkmonitor/lakehouse-id/session-id"
-
-
-class TestWaitForSessionReady:
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_returns_immediately_when_idle(self, mock_sleep, session, fabric_api_client):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-
-        session.wait_for_session_ready()
-
-        fabric_api_client.get_livy_session_state.assert_called_once()
-        mock_sleep.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_polls_through_non_idle_states(
-        self, mock_sleep, mock_time, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.side_effect = ["starting", "busy", "idle"]
-        mock_time.side_effect = itertools.count(0, 10)
-
-        session.wait_for_session_ready()
-
-        assert fabric_api_client.get_livy_session_state.call_count == 3
-        assert mock_sleep.call_count == 2
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_raises_timeout_error_when_session_timeout_exceeded(
-        self, mock_sleep, mock_time, session, fabric_api_client, credentials
-    ):
-        credentials.spark_session_timeout = 30
-        fabric_api_client.get_livy_session_state.return_value = "starting"
-        mock_time.side_effect = itertools.chain([0], itertools.repeat(31))
-
-        with pytest.raises(TimeoutError, match="did not become idle"):
-            session.wait_for_session_ready()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_tolerates_transient_errors_below_threshold(
-        self, mock_sleep, mock_time, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.side_effect = [
-            requests.exceptions.ConnectionError("conn refused"),
-            requests.exceptions.Timeout("timed out"),
-            requests.exceptions.ChunkedEncodingError("chunked"),
-            json.JSONDecodeError("bad json", "", 0),
-            "idle",
-        ]
-        mock_time.side_effect = itertools.count(0, 5)
-
-        session.wait_for_session_ready()
-
-        assert fabric_api_client.get_livy_session_state.call_count == 5
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_reraises_after_max_consecutive_transient_errors(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.side_effect = requests.exceptions.ConnectionError(
-            "conn refused"
-        )
-
-        with pytest.raises(requests.exceptions.ConnectionError):
-            session.wait_for_session_ready()
-
-        assert fabric_api_client.get_livy_session_state.call_count == 5
-
-    @pytest.mark.parametrize("fatal_state", ["dead", "killed", "error", "shutting_down"])
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_raises_immediately_on_fatal_session_state(
-        self, mock_sleep, fatal_state, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = fatal_state
-
-        with pytest.raises(RuntimeError, match=f"fatal state '{fatal_state}'"):
-            session.wait_for_session_ready()
-
-        fabric_api_client.get_livy_session_state.assert_called_once()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_raises_when_session_transitions_to_fatal_state(
-        self, mock_sleep, mock_time, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.side_effect = ["starting", "busy", "dead"]
-        mock_time.side_effect = itertools.count(0, 10)
-
-        with pytest.raises(RuntimeError, match="fatal state 'dead'"):
-            session.wait_for_session_ready()
-
-        assert fabric_api_client.get_livy_session_state.call_count == 3
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_resets_error_counter_on_success(
-        self, mock_sleep, mock_time, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.side_effect = [
-            requests.exceptions.ConnectionError("err1"),
-            requests.exceptions.ConnectionError("err2"),
-            requests.exceptions.ConnectionError("err3"),
-            requests.exceptions.ConnectionError("err4"),
-            "starting",
-            requests.exceptions.ConnectionError("err5"),
-            requests.exceptions.ConnectionError("err6"),
-            requests.exceptions.ConnectionError("err7"),
-            requests.exceptions.ConnectionError("err8"),
-            "idle",
-        ]
-        mock_time.side_effect = itertools.count(0, 5)
-
-        session.wait_for_session_ready()
-
-        assert fabric_api_client.get_livy_session_state.call_count == 10
-
-
-class TestWaitForStatementReady:
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_returns_when_state_is_available(self, mock_sleep, session, fabric_api_client):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {"status": "ok", "data": {}},
-        }
-
-        result = session.wait_for_statement_ready(42)
-
-        assert result["state"] == "available"
-        mock_sleep.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_returns_when_state_is_error(self, mock_sleep, session, fabric_api_client):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "error",
-            "output": {"status": "error", "evalue": "something went wrong"},
-        }
-
-        result = session.wait_for_statement_ready(42)
-
-        assert result["state"] == "error"
-        mock_sleep.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_polls_through_non_terminal_states(
-        self, mock_sleep, mock_time, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.side_effect = [
-            {"state": "waiting"},
-            {"state": "running"},
-            {"state": "available", "output": {"status": "ok"}},
-        ]
-        mock_time.side_effect = itertools.count(0, 10)
-
-        result = session.wait_for_statement_ready(42)
-
-        assert result["state"] == "available"
-        assert mock_sleep.call_count == 2
-
-    @pytest.mark.parametrize("terminal_state", ["cancelled", "cancelling"])
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_returns_when_state_is_cancelled(
-        self, mock_sleep, terminal_state, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": terminal_state,
-            "output": {"status": "error", "evalue": "statement was cancelled"},
-        }
-
-        result = session.wait_for_statement_ready(42)
-
-        assert result["state"] == terminal_state
-        mock_sleep.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_raises_timeout_error_when_query_timeout_exceeded(
-        self, mock_sleep, mock_time, session, fabric_api_client, credentials
-    ):
-        credentials.query_timeout = 60
-        fabric_api_client.get_livy_statement.return_value = {"state": "running"}
-        mock_time.side_effect = itertools.chain([0], itertools.repeat(61))
-
-        with pytest.raises(TimeoutError, match="did not become available"):
-            session.wait_for_statement_ready(42)
-
-
-class TestWaitAndGetStatementResult:
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_available_with_status_ok_returns_success(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {
-                "status": "ok",
-                "data": {"application/json": {"key": "value"}},
-            },
-        }
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is True
-        assert result.statement_id == 7
-        assert result.status_code == "ok"
-        assert result.json_data == {"key": "value"}
-        assert result.error_message is None
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_available_with_status_not_ok_returns_failure(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {
-                "status": "error",
-                "evalue": "NameError: name 'x' is not defined",
-            },
-        }
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is False
-        assert result.error_message == "NameError: name 'x' is not defined"
-        assert result.status_code == "error"
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_error_state_returns_failure(self, mock_sleep, session, fabric_api_client):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "error",
-            "output": {
-                "status": "error",
-                "evalue": "session crashed",
-            },
-        }
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is False
-        assert result.error_message == "session crashed"
-
-    @pytest.mark.parametrize("terminal_state", ["cancelled", "cancelling"])
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_cancelled_without_evalue_uses_state_as_error_message(
-        self, mock_sleep, terminal_state, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": terminal_state,
-            "output": {"status": "error"},
-        }
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is False
-        assert terminal_state in result.error_message
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.time")
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_catches_timeout_error_and_returns_failed_result(
-        self, mock_sleep, mock_time, session, fabric_api_client, credentials
-    ):
-        credentials.query_timeout = 10
-        fabric_api_client.get_livy_statement.return_value = {"state": "running"}
-        mock_time.side_effect = itertools.chain([0], itertools.repeat(11))
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is False
-        assert result.statement_id == 7
-        assert "did not become available" in result.error_message
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_catches_generic_exception_and_returns_failed_result(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_statement.side_effect = RuntimeError("unexpected failure")
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.success is False
-        assert result.statement_id == 7
-        assert "unexpected failure" in result.error_message
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_extracts_json_data_from_output(self, mock_sleep, session, fabric_api_client):
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {
-                "status": "ok",
-                "data": {"application/json": {"rows": [1, 2, 3], "schema": "test"}},
-            },
-        }
-
-        result = session.wait_and_get_statement_result(7)
-
-        assert result.json_data == {"rows": [1, 2, 3], "schema": "test"}
-
-
-class TestRunStatement:
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_dispatches_sql_to_submit_livy_sql_statement(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-        fabric_api_client.submit_livy_sql_statement.return_value = 10
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {"status": "ok", "data": {}},
-        }
-
-        session.run_statement("SELECT 1", "sql", wait_for_result=True)
-
-        fabric_api_client.submit_livy_sql_statement.assert_called_once_with("SELECT 1")
-        fabric_api_client.submit_livy_python_statement.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_dispatches_python_to_submit_livy_python_statement(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-        fabric_api_client.submit_livy_python_statement.return_value = 11
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {"status": "ok", "data": {}},
-        }
-
-        session.run_statement("print('hello')", "python", wait_for_result=True)
-
-        fabric_api_client.submit_livy_python_statement.assert_called_once_with("print('hello')")
-        fabric_api_client.submit_livy_sql_statement.assert_not_called()
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_wait_for_result_true_returns_livy_session_result(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-        fabric_api_client.submit_livy_sql_statement.return_value = 10
-        fabric_api_client.get_livy_statement.return_value = {
-            "state": "available",
-            "output": {"status": "ok", "data": {}},
-        }
-
-        result = session.run_statement("SELECT 1", "sql", wait_for_result=True)
-
-        assert isinstance(result, LivySessionResult)
-        assert result.success is True
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_wait_for_result_false_returns_statement_id(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-        fabric_api_client.submit_livy_sql_statement.return_value = 10
-
-        result = session.run_statement("SELECT 1", "sql", wait_for_result=False)
-
-        assert result == 10
-        assert isinstance(result, int)
-
-    @patch("dbt.adapters.fabric.fabric_livy_session.time.sleep")
-    def test_returns_failed_result_on_submission_error(
-        self, mock_sleep, session, fabric_api_client
-    ):
-        fabric_api_client.get_livy_session_state.return_value = "idle"
-        fabric_api_client.submit_livy_sql_statement.side_effect = RuntimeError("API down")
-
-        result = session.run_statement("SELECT 1", "sql", wait_for_result=True)
-
-        assert isinstance(result, LivySessionResult)
-        assert result.success is False
-        assert "API down" in result.error_message
+from dbt.adapters.fabric.livy_result import LivySessionResult
 
 
 class TestLivySessionResultToSubmissionResult: