diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f75447..cead080 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v1.12.0 + +### New Features + +- **High-concurrency Livy support** for true parallel statement execution. Each dbt thread acquires its own REPL inside one underlying Livy session via [Fabric's HC Livy API](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy) (`/highConcurrencySessions` + `/repls/{replId}/statements`). All threads in a process share a deterministic `sessionTag` derived from `(workspaceid, lakehouseid)` when `reuse_session: true`, so Fabric snap-attaches new REPLs onto the still-warm underlying session across runs — observed **3.6× wall-clock speedup** on the 2nd run of the issue's repro (442s → 122s). Singleton mode remains available via `high_concurrency: false`; the new flag defaults to `true` for Fabric mode and is a no-op in local mode. See the new "High-concurrency Livy" section in the README for the `threads > 5` cross-REPL state table (#185, #186) + +### Infrastructure + +- Refactored the Livy backend behind a new `LivyBackend` ABC with two implementations — `singleton_livy.py` (existing single-session path) and `concurrent_livy.py` (new HC path) — selected at connect time by the `high_concurrency` credential. Shared auth/header/retry/lakehouse-property helpers remain in `livysession.py`; the existing class names continue to be re-exported from there for backwards compatibility with downstream importers and the test patch surface (#186) + +--- + ## v1.11.0 ### New Features diff --git a/README.md b/README.md index 9be18ba..ed5c5a7 100644 --- a/README.md +++ b/README.md @@ -327,6 +327,7 @@ Each segment is independently backtick-quoted, so workspace names with spaces or | `reuse_session` | bool | `false` | Keep Livy sessions alive for reuse across runs | | `session_id_file` | string | `./livy-session-id.txt` | Path to file storing session ID for reuse | | `session_idle_timeout` | string | `30m` | Livy session idle timeout (e.g. `30m`, `1h`) | +| `high_concurrency` | bool | `true` | Use high-concurrency Livy API so each dbt thread gets its own REPL — see [High-concurrency Livy](#high-concurrency-livy) | | **Timeouts & Polling** | | | | | `connect_retries` | int | `1` | Number of connection retries | | `connect_timeout` | int | `10` | Connection timeout in seconds | @@ -350,6 +351,51 @@ Each segment is independently backtick-quoted, so workspace names with spaces or | **Service Principal** | `SPN` | CI/CD and automation. Uses Azure AD app registration. | `client_id`, `tenant_id`, `client_secret` | | **Fabric Notebook** | `fabric_notebook` | Running dbt inside a Fabric notebook. Uses `notebookutils.credentials`. | None (runs in Fabric runtime) | +### High-concurrency Livy + +By default the adapter uses Fabric's [high-concurrency Livy API](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy) +(`high_concurrency: true`). Each dbt thread acquires its own HC session — and therefore its own REPL — inside a single underlying Livy session +shared via a deterministic `sessionTag` derived from `(workspaceid, lakehouseid)`. Statements from different REPLs execute in +parallel inside the same Spark application, so increasing `threads` buys us throughput. + +When `reuse_session: true`, the underlying Livy session also stays warm between dbt invocations (until Fabric's +`spark.livy.session.idle.timeout` elapses), so the next run skips Spark cold-start entirely. + +Set `high_concurrency: false` to fall back to the single-session-per-process mode, where one Livy session +serves every thread and statements queue FIFO inside — useful as an escape hatch +when debugging any problems with the high-concurrency API. + +Fabric packs up to **5 REPLs onto one underlying Livy session** (see the +["Limits"](https://learn.microsoft.com/en-us/fabric/data-engineering/high-concurrency-livy#key-concepts) +note in the Microsoft Learn HC Livy docs). With `threads > 5`, dbt still +works correctly — Fabric simply spins up a second underlying Livy session +to host the 6th REPL onwards, and the same `sessionTag` makes future +acquires snap-attach to whichever underlying session has room. + +What that means in practice: + +| Property | Shared across underlying sessions? | +| ----------------------------------------------------- | ---------------------------------- | +| OneLake Delta tables (dbt model outputs) | Yes — same lakehouse storage | +| Catalog / metastore (`SELECT FROM `) | Yes — same Fabric catalog | +| Temp views (`CREATE TEMPORARY VIEW ...`) | No — REPL/session-local | +| Session-level Spark configs (`SET spark.sql.X = ...`) | No | +| Cached datasets / UDFs / broadcast vars | No | + +Because dbt-fabricspark materializations always write permanent Delta / +MLV objects, model-to-model `ref`s resolve correctly regardless of which +underlying session produced or consumes the table. Macros that depend on +session-local state (temp views, in-session configs) are the only ones +that could surprise — none ship with this adapter today. + +Cost tradeoff: each additional underlying Livy session is a separate +Spark cluster billed for the duration of the run plus the +`spark.livy.session.idle.timeout` afterwards. Keep `threads ≤ 5` for the +cheapest profile; raise it only when the extra parallelism beats the +extra compute spend. + +High-concurrency has no effect in local mode as this is a Fabric specific construct. + ### Materialized Lake Views [Materialized lake views](https://learn.microsoft.com/en-us/fabric/data-engineering/materialized-lake-views/overview-materialized-lake-view) are a Fabric-native construct that materializes a SQL query as a Delta table in your lakehouse, with automatic lineage-based refresh managed by Fabric. diff --git a/src/dbt/adapters/fabricspark/__version__.py b/src/dbt/adapters/fabricspark/__version__.py index b6c3033..134ed00 100644 --- a/src/dbt/adapters/fabricspark/__version__.py +++ b/src/dbt/adapters/fabricspark/__version__.py @@ -1 +1 @@ -version = "1.11.0" +version = "1.12.0" diff --git a/src/dbt/adapters/fabricspark/concurrent_livy.py b/src/dbt/adapters/fabricspark/concurrent_livy.py new file mode 100644 index 0000000..1456b78 --- /dev/null +++ b/src/dbt/adapters/fabricspark/concurrent_livy.py @@ -0,0 +1,768 @@ +from __future__ import annotations + +import atexit +import datetime as dt +import hashlib +import json +import re +import threading +import time +import uuid +from types import TracebackType +from typing import Any, Optional + +import requests +from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError +from dbt_common.utils.encoding import DECIMALS + +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions import FailedToConnectError +from dbt.adapters.fabricspark import livysession as _livy_helpers +from dbt.adapters.fabricspark.credentials import FabricSparkCredentials +from dbt.adapters.fabricspark.livy_backend import LivyBackend +from dbt.adapters.fabricspark.shortcuts import ShortcutClient + +logger = AdapterLogger("Microsoft Fabric-Spark") +NUMBERS = DECIMALS + (int, float) + +# HC sessions whose state transitions through these values have not yet +# produced sessionId/replId; keep polling until state leaves the set. +_ACQUIRING_STATES = frozenset({"NotStarted", "starting", "AcquiringHighConcurrencySession"}) +_TERMINAL_BAD_STATES = frozenset({"Dead", "Killed", "Failed", "Error"}) + + +_active_sessions_lock = threading.Lock() +# All in-flight HighConcurrencySession instances across every dbt thread. +# Used by the atexit handler to DELETE each HC id on process exit so REPL +# slots free up promptly instead of waiting for Fabric's idle reaper. +_active_sessions: "set[HighConcurrencySession]" = set() + + +_session_tag_lock = threading.Lock() +# Deterministic tag per (workspaceid, lakehouseid) when reuse_session is true, +# uuid per process otherwise. Cached at module scope so every per-thread +# manager generates the same tag and Fabric packs every acquire onto the +# same underlying Livy session. +_session_tags: dict[tuple[str, str, bool], str] = {} + + +_shortcuts_done_lock = threading.Lock() +# Process-level guard so OneLake shortcuts are created exactly once per +# (workspaceid, lakehouseid) even when multiple threads acquire HC sessions +# in parallel. +_shortcuts_done: "set[tuple[str, str]]" = set() + + +def _get_headers(credentials: FabricSparkCredentials, tokenPrint: bool = False) -> dict[str, str]: + return _livy_helpers.get_headers(credentials, tokenPrint) + + +def _parse_retry_after(response: requests.Response) -> float: + return _livy_helpers._parse_retry_after(response) + + +def derive_session_tag(credentials: FabricSparkCredentials) -> str: + """Return the sessionTag used by all HC acquires from this process. + + When ``reuse_session`` is true: a deterministic hash of + ``(workspaceid, lakehouseid)`` so successive dbt invocations get packed + onto the same underlying Livy session while it's still warm. Different + profiles targeting the same workspace+lakehouse intentionally collide on + the same tag — they share a Spark cluster, which is the cheapest outcome. + + When ``reuse_session`` is false: a fresh uuid the first time we're asked + in this process, cached thereafter so every per-thread manager sees the + same tag. + """ + key = (credentials.workspaceid or "", credentials.lakehouseid or "", credentials.reuse_session) + with _session_tag_lock: + if key in _session_tags: + return _session_tags[key] + if credentials.reuse_session: + material = f"{credentials.workspaceid}|{credentials.lakehouseid}" + digest = hashlib.sha256(material.encode("utf-8")).hexdigest()[:24] + tag = f"dbt-fabricspark-{digest}" + else: + tag = f"dbt-fabricspark-{uuid.uuid4().hex}" + _session_tags[key] = tag + return tag + + +class HighConcurrencySession: + """Owns the lifecycle of one HC session (= one REPL). + + One instance per dbt thread. Acquires via ``POST /highConcurrencySessions``, + polls until Fabric reports ``Idle`` (which means the underlying Livy + session is up and a REPL has been allocated), then exposes the + ``sessionId`` (underlying Livy id) and ``replId`` for statement + submission. + """ + + def __init__(self, credentials: FabricSparkCredentials, spark_config: dict[str, Any]): + self.credential = credentials + self.spark_config = spark_config + self.connect_url = credentials.lakehouse_endpoint + self.session_tag = derive_session_tag(credentials) + self.hc_id: Optional[str] = None + self.session_id: Optional[str] = None + self.repl_id: Optional[str] = None + self.is_new_session_required = True + # Instance-level flag set by retry helpers when a 404 indicates the + # REPL is gone. Read by HighConcurrencyCursor before submitting the + # next statement so it can transparently re-acquire. + self.is_dead = False + self._lock = threading.Lock() + + def __enter__(self) -> HighConcurrencySession: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + return True + + # ---- acquire --------------------------------------------------------- + + def acquire(self) -> None: + """POST /highConcurrencySessions then poll until Idle. + + On success, ``self.hc_id``, ``self.session_id`` and ``self.repl_id`` + are all populated and the REPL is ready for statement submission. + """ + payload = self._build_acquire_payload() + url = self.connect_url + "/highConcurrencySessions" + logger.debug(f"Acquiring HC session (sessionTag={self.session_tag})") + + response = None + max_retries = 5 + for attempt in range(max_retries): + try: + response = requests.post( + url, + data=json.dumps(payload), + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + if response.status_code in (200, 201, 202): + break + # Fabric returns 404 transiently after a lakehouse is + # provisioned before the Livy endpoint is fully wired. + if attempt < max_retries - 1 and ( + response.status_code == 404 or response.status_code >= 500 + ): + wait = 5 * (2**attempt) + logger.warning( + f"HC acquire returned HTTP {response.status_code}, " + f"retrying in {wait}s (attempt {attempt + 1}/{max_retries})" + ) + time.sleep(wait) + continue + response.raise_for_status() + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ) as exc: + if attempt >= max_retries - 1: + raise FailedToConnectError(f"HC session acquire failed: {exc}") from exc + time.sleep(2**attempt) + + if response is None: + raise FailedToConnectError("HC acquire produced no response") + + try: + body = response.json() + except requests.exceptions.JSONDecodeError as exc: + raise FailedToConnectError( + f"HC acquire returned non-JSON response: {response.text}" + ) from exc + + self.hc_id = body.get("id") + if not self.hc_id: + raise FailedToConnectError(f"HC acquire response missing 'id': {body}") + + with _active_sessions_lock: + _active_sessions.add(self) + + self._poll_until_idle() + self.is_new_session_required = False + self.is_dead = False + logger.debug( + f"HC session ready: hc_id={self.hc_id} sessionId={self.session_id} replId={self.repl_id}" + ) + + def _build_acquire_payload(self) -> dict[str, Any]: + cfg = dict(self.spark_config) + # The HC payload accepts the same conf/numExecutors/etc. as the + # singleton /sessions POST — we just add the sessionTag. + payload: dict[str, Any] = {"sessionTag": self.session_tag} + for key in ( + "name", + "conf", + "driverMemory", + "driverCores", + "executorMemory", + "executorCores", + "numExecutors", + "jars", + "files", + "pyFiles", + "archives", + "args", + "className", + "file", + "tags", + "artifactName", + ): + if key in cfg: + payload[key] = cfg[key] + + conf = dict(payload.get("conf") or {}) + if self.credential.environmentId: + conf["spark.fabric.environment.id"] = self.credential.environmentId + if self.credential.session_idle_timeout: + conf["spark.livy.session.idle.timeout"] = self.credential.session_idle_timeout + if conf: + payload["conf"] = conf + return payload + + def _poll_until_idle(self) -> None: + deadline = time.time() + self.credential.session_start_timeout + url = self.connect_url + "/highConcurrencySessions/" + self.hc_id + + while True: + if time.time() > deadline: + raise FailedToConnectError( + f"Timeout ({self.credential.session_start_timeout}s) waiting for HC session " + f"{self.hc_id} to become Idle. Increase `session_start_timeout` in profiles.yml." + ) + try: + resp = requests.get( + url, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + body = resp.json() + except ( + requests.exceptions.RequestException, + requests.exceptions.JSONDecodeError, + ) as exc: + logger.warning( + f"Transient error polling HC session {self.hc_id}: {exc}; " + f"retrying in {self.credential.poll_wait}s" + ) + time.sleep(self.credential.poll_wait) + continue + + state = body.get("state", "") + session_id = body.get("sessionId") + repl_id = body.get("replId") + + if state in _TERMINAL_BAD_STATES: + err = body.get("fabricSessionStateInfo", {}).get("errorMessage") or state + raise FailedToConnectError(f"HC session {self.hc_id} state={state}: {err}") + + if state == "Idle" and session_id and repl_id: + self.session_id = session_id + self.repl_id = repl_id + return + + if state not in _ACQUIRING_STATES and state != "Idle": + logger.debug(f"HC session {self.hc_id} in unfamiliar state '{state}', polling on") + + time.sleep(self.credential.poll_wait) + + # ---- statement URLs -------------------------------------------------- + + def statements_url(self) -> str: + return ( + self.connect_url + + "/highConcurrencySessions/" + + self.session_id + + "/repls/" + + self.repl_id + + "/statements" + ) + + # ---- release --------------------------------------------------------- + + def delete(self) -> None: + """DELETE /highConcurrencySessions/{hc_id}; best-effort. + + Deletes only this HC id; the underlying Livy session continues to host + any other REPLs in the same packing group and is reaped by Fabric on + idle timeout. + """ + if not self.hc_id: + return + try: + res = requests.delete( + self.connect_url + "/highConcurrencySessions/" + self.hc_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + if res.status_code in (200, 202, 204, 404): + logger.debug(f"Released HC session {self.hc_id} (HTTP {res.status_code})") + else: + logger.warning(f"HC session delete returned HTTP {res.status_code}: {res.text}") + except Exception as ex: + logger.warning(f"Failed to delete HC session {self.hc_id}: {ex}") + finally: + with _active_sessions_lock: + _active_sessions.discard(self) + self.hc_id = None + self.session_id = None + self.repl_id = None + self.is_new_session_required = True + + +class HighConcurrencyCursor: + """Cursor backed by one HC REPL. Mirrors :class:`LivyCursor`'s surface. + + The HC statement-result payload uses the same JSON envelope as singleton + Livy (``output.data.application/json.{schema,data}``), so the parsing and + fetch* helpers are intentionally aligned. + """ + + def __init__(self, credential: FabricSparkCredentials, hc_session: HighConcurrencySession): + self.credential = credential + self.connect_url = credential.lakehouse_endpoint + self.hc_session = hc_session + self._rows: Optional[list] = None + self._schema: Optional[list] = None + self._fetch_index = 0 + + def __enter__(self) -> HighConcurrencyCursor: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + self.close() + return True + + @property + def description( + self, + ) -> list[tuple[str, str, None, None, None, None, bool]]: + if self._schema is None: + return [] + return [ + ( + field["name"], + field["type"], + None, + None, + None, + None, + field["nullable"], + ) + for field in self._schema + ] + + def close(self) -> None: + self._rows = None + + # ---- submit + poll --------------------------------------------------- + + def _ensure_repl(self) -> None: + """Re-acquire this thread's HC session if it was marked dead. + + Called before every statement submit so that 404s on a stale REPL + recover transparently. Only acts when ``is_dead`` or + ``is_new_session_required`` is set. + """ + if self.hc_session.is_dead or self.hc_session.is_new_session_required: + logger.debug("HC REPL marked stale — re-acquiring") + self.hc_session.acquire() + + def _submit(self, code: str) -> requests.Response: + self._ensure_repl() + url = self.hc_session.statements_url() + data = {"code": code, "kind": "sql"} + logger.debug(f"Submitted: {data} {url}") + + max_retries = 5 + res = None + for attempt in range(max_retries): + try: + res = requests.post( + url, + data=json.dumps(data), + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + except ( + requests.exceptions.SSLError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ChunkedEncodingError, + ) as exc: + if attempt >= max_retries - 1: + raise DbtRuntimeError( + f"HC statement submit failed after {max_retries} retries: {exc}" + ) + wait = 2**attempt + logger.debug( + f"HC statement submit got transient network error " + f"({type(exc).__name__}), retrying in {wait}s" + ) + time.sleep(wait) + continue + if res.status_code == 429: + retry_after = _parse_retry_after(res) + wait = max(retry_after, 2**attempt) + logger.debug(f"HC statement submit got HTTP 429, retrying in {wait:.0f}s") + time.sleep(wait) + continue + if res.status_code < 500: + break + if attempt < max_retries - 1: + wait = 2**attempt + logger.debug( + f"HC statement submit got HTTP {res.status_code}, retrying in {wait}s" + ) + time.sleep(wait) + + if res.status_code >= 400: + if res.status_code == 404: + # The REPL or underlying session is gone — flag this thread's + # HC session for re-acquisition; the next add_query retry on + # the dbt side will rebuild it transparently. + self.hc_session.is_dead = True + self.hc_session.is_new_session_required = True + logger.debug("HC statement submit returned 404 — flagging REPL for re-acquire") + raise DbtRuntimeError( + f"HC statement submit failed (HTTP {res.status_code}): {res.text}" + ) + + body = res.json() + if "id" not in body: + raise DbtRuntimeError( + f"HC statement submit returned unexpected response (missing 'id'): {body}" + ) + return res + + def _poll(self, submit_response: requests.Response) -> dict: + body = submit_response.json() + statement_id = repr(body["id"]) + url = self.hc_session.statements_url() + "/" + statement_id + + deadline = ( + (time.time() + self.credential.statement_timeout) + if self.credential.statement_timeout > 0 + else None + ) + consecutive_failures = 0 + max_poll_retries = 30 + poll_interval = 0.3 + poll_cap = max(self.credential.poll_statement_wait * 3, 1.5) + not_found_retries = 0 + max_not_found_retries = 20 + + while True: + if deadline is not None and time.time() > deadline: + raise DbtDatabaseError( + f"Timeout ({self.credential.statement_timeout}s) waiting for HC statement " + f"{statement_id}. Increase `statement_timeout` in profiles.yml." + ) + try: + resp = requests.get( + url, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + except ( + requests.exceptions.SSLError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ChunkedEncodingError, + ) as exc: + consecutive_failures += 1 + if consecutive_failures > max_poll_retries: + raise DbtRuntimeError( + f"HC statement poll failed after {max_poll_retries} retries: {exc}" + ) + wait = min(2 ** (consecutive_failures - 1), 30) + logger.debug(f"HC statement poll got transient error, retrying in {wait}s") + time.sleep(wait) + continue + if resp.status_code == 429: + consecutive_failures += 1 + retry_after = _parse_retry_after(resp) + wait = max(retry_after, 2 ** (consecutive_failures - 1)) + logger.debug(f"HC statement poll got HTTP 429, retrying in {wait:.0f}s") + time.sleep(wait) + if consecutive_failures > max_poll_retries: + raise DbtRuntimeError( + f"HC statement poll failed after {max_poll_retries} retries (HTTP 429)" + ) + continue + if resp.status_code >= 500: + consecutive_failures += 1 + if consecutive_failures <= max_poll_retries: + wait = 2 ** (consecutive_failures - 1) + logger.debug( + f"HC statement poll got HTTP {resp.status_code}, retrying in {wait}s" + ) + time.sleep(wait) + continue + raise DbtRuntimeError( + f"HC statement poll failed after {max_poll_retries} retries " + f"(HTTP {resp.status_code}): {resp.text}" + ) + if resp.status_code == 404 and not_found_retries < max_not_found_retries: + not_found_retries += 1 + wait = min(0.3 * (2.0 ** (not_found_retries - 1)), 5.0) + logger.debug( + f"HC statement poll got HTTP 404, retrying in {wait:.2f}s " + f"(not-found {not_found_retries}/{max_not_found_retries})" + ) + time.sleep(wait) + continue + if resp.status_code >= 400: + if resp.status_code == 404: + self.hc_session.is_dead = True + self.hc_session.is_new_session_required = True + raise DbtRuntimeError( + f"HC statement poll failed (HTTP {resp.status_code}): {resp.text}" + ) + consecutive_failures = 0 + + body = resp.json() + if "state" not in body: + raise DbtRuntimeError( + f"HC statement poll returned unexpected response (missing 'state'): {body}" + ) + + if body["state"] == "available": + return body + if body["state"] in ("error", "cancelled", "cancelling"): + error_msg = body.get("output", {}).get("evalue", "Unknown error") + raise DbtDatabaseError( + f"Statement {statement_id} failed with state '{body['state']}': {error_msg}" + ) + time.sleep(poll_interval) + poll_interval = min(poll_interval * 1.5, poll_cap) + + @staticmethod + def _strip_block_comments(sql: str) -> str: + return re.sub(r"\s*/\*(.|\n)*?\*/\s*", "\n", sql, re.DOTALL).strip() + + def execute(self, sql: str, *parameters: Any) -> None: + if len(parameters) > 0: + sql = sql % parameters + self._fetch_index = 0 + + code = self._strip_block_comments(sql) + result = self._poll(self._submit(code)) + logger.debug(result) + + output = result.get("output", {}) + if output.get("status") == "ok": + data = output.get("data", {}) + payload = data.get("application/json") + if isinstance(payload, dict) and "data" in payload: + self._rows = payload["data"] + self._schema = payload.get("schema", {}).get("fields", []) + else: + # DDL / DML or unexpected envelope — produce an empty result set + self._rows = [] + self._schema = [] + else: + self._rows = None + self._schema = None + raise DbtDatabaseError( + "Error while executing query: " + output.get("evalue", "") + ) + + def fetchall(self): + return self._rows + + def fetchmany(self, size=None): + if self._rows is None: + return None + if size is None: + return self._rows + return self._rows[:size] + + def fetchone(self): + if self._rows is not None and self._fetch_index < len(self._rows): + row = self._rows[self._fetch_index] + self._fetch_index += 1 + return row + return None + + +class HighConcurrencyConnection: + """DB-API-shaped connection backed by a single HC REPL.""" + + def __init__(self, credentials: FabricSparkCredentials, hc_session: HighConcurrencySession): + self.credential = credentials + self.connect_url = credentials.lakehouse_endpoint + self.hc_session = hc_session + self._cursor = HighConcurrencyCursor(credentials, hc_session) + + def get_session_id(self) -> Optional[str]: + return self.hc_session.session_id + + def get_headers(self) -> dict[str, str]: + return _get_headers(self.credential, False) + + def get_connect_url(self) -> str: + return self.connect_url + + def cursor(self) -> HighConcurrencyCursor: + return self._cursor + + def close(self) -> None: + logger.debug("HC Connection.close()") + self._cursor.close() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + self.close() + return True + + +def _maybe_create_shortcuts(credentials: FabricSparkCredentials) -> None: + """Create OneLake shortcuts once per process per (workspace, lakehouse).""" + if not credentials.create_shortcuts: + return + key = (credentials.workspaceid or "", credentials.lakehouseid or "") + with _shortcuts_done_lock: + if key in _shortcuts_done: + return + _shortcuts_done.add(key) + + # Force a header build so the module-level accessToken is populated + # before instantiating ShortcutClient. + _ = _get_headers(credentials, False) + + try: + shortcut_client = ShortcutClient( + _livy_helpers.accessToken.token, + credentials.workspaceid, + credentials.lakehouseid, + credentials.endpoint, + ) + shortcut_client.create_shortcuts(credentials.shortcuts_json_str) + except Exception as ex: + logger.error(f"Unable to create shortcuts: {ex}") + + +class HighConcurrencySessionManager(LivyBackend): + """Per-dbt-thread backend. One instance owns one HC session = one REPL. + + Acquires lazily on the first :meth:`connect` call; cleanup happens in + :meth:`disconnect` (called explicitly by `connections.cleanup_all` or via + the module-level atexit handler). + """ + + def __init__(self) -> None: + self._hc_session: Optional[HighConcurrencySession] = None + self._connection: Optional[HighConcurrencyConnection] = None + + def connect(self, credentials: FabricSparkCredentials) -> HighConcurrencyConnection: # type: ignore[override] + if self._hc_session is None or self._hc_session.is_new_session_required: + self._hc_session = HighConcurrencySession(credentials, credentials.spark_config) + self._hc_session.acquire() + _maybe_create_shortcuts(credentials) + self._connection = HighConcurrencyConnection(credentials, self._hc_session) + return self._connection # type: ignore[return-value] + + def disconnect(self) -> None: # type: ignore[override] + """Release this thread's HC id. The underlying Livy session lives on.""" + if self._hc_session is not None: + self._hc_session.delete() + self._hc_session = None + self._connection = None + + +class HighConcurrencyConnectionWrapper(object): + """DB-API connection wrapper used by ``FabricSparkConnectionManager``. + + Surface is intentionally identical to + :class:`dbt.adapters.fabricspark.singleton_livy.LivySessionConnectionWrapper` + so the rest of the SQL connection manager doesn't know which backend + produced the handle. + """ + + def __init__(self, handle: HighConcurrencyConnection): + self.handle = handle + self._cursor: Optional[HighConcurrencyCursor] = None + + def cursor(self) -> HighConcurrencyConnectionWrapper: + self._cursor = self.handle.cursor() + return self + + def cancel(self): + logger.debug("NotImplemented: cancel") + + def close(self): + self.handle.close() + + def rollback(self, *args, **kwargs): + logger.debug("NotImplemented: rollback") + + def fetchall(self): + return self._cursor.fetchall() + + def fetchmany(self, size=None): + return self._cursor.fetchmany(size) + + def fetchone(self): + return self._cursor.fetchone() + + def execute(self, sql, bindings=None): + if sql.strip().endswith(";"): + sql = sql.strip()[:-1] + if bindings is None: + self._cursor.execute(sql) + else: + bindings = [self._fix_binding(b) for b in bindings] + self._cursor.execute(sql, *bindings) + + @property + def description(self): + return self._cursor.description + + @classmethod + def _fix_binding(cls, value) -> float | str: + if isinstance(value, NUMBERS): + return float(value) + elif isinstance(value, dt.datetime): + return f"'{value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}'" + elif value is None: + return "''" + else: + escaped = str(value).replace("'", "\\'") + return f"'{escaped}'" + + +def _atexit_cleanup_hc() -> None: + """DELETE every still-active HC session on process exit. + + Iterates ``_active_sessions`` rather than relying on + ``connection_managers`` in ``connections.py``, which can be cleared by + ``cleanup_all`` before exit. + """ + with _active_sessions_lock: + sessions = list(_active_sessions) + for s in sessions: + try: + s.delete() + except Exception as ex: + logger.debug(f"atexit HC delete failed for {s.hc_id}: {ex}") + + +atexit.register(_atexit_cleanup_hc) diff --git a/src/dbt/adapters/fabricspark/connections.py b/src/dbt/adapters/fabricspark/connections.py index 26a11dc..8c0d10e 100644 --- a/src/dbt/adapters/fabricspark/connections.py +++ b/src/dbt/adapters/fabricspark/connections.py @@ -27,6 +27,10 @@ from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus from dbt.adapters.exceptions import FailedToConnectError +from dbt.adapters.fabricspark.concurrent_livy import ( + HighConcurrencyConnectionWrapper, + HighConcurrencySessionManager, +) from dbt.adapters.fabricspark.livysession import ( LivySessionConnectionWrapper, LivySessionManager, @@ -171,10 +175,16 @@ def open(cls, connection: Connection) -> Connection: try: if creds.method == FabricSparkConnectionMethod.LIVY: thread_id = cls.get_thread_identifier() + use_hc = creds.high_concurrency and not creds.is_local_mode if thread_id not in cls.connection_managers: - cls.connection_managers[thread_id] = LivySessionManager() - handle = LivySessionConnectionWrapper( - cls.connection_managers[thread_id].connect(creds) + cls.connection_managers[thread_id] = ( + HighConcurrencySessionManager() if use_hc else LivySessionManager() + ) + raw_handle = cls.connection_managers[thread_id].connect(creds) + handle = ( + HighConcurrencyConnectionWrapper(raw_handle) + if use_hc + else LivySessionConnectionWrapper(raw_handle) ) connection.state = ConnectionState.OPEN @@ -241,7 +251,17 @@ def cleanup_all(self) -> None: Connections must persist because the Livy session is shared. Sessions are deleted on process exit via an atexit handler registered in LivySessionManager. + + For HC backends, however, each per-thread manager owns its own HC + session id; releasing those promptly here frees REPL slots so the + underlying Livy session can host new acquirers without bumping into + the 5-REPL packing cap. """ + for manager in self.connection_managers.values(): + try: + manager.disconnect() + except Exception as ex: + logger.debug(f"connection manager disconnect raised: {ex}") self.connection_managers.clear() @classmethod diff --git a/src/dbt/adapters/fabricspark/credentials.py b/src/dbt/adapters/fabricspark/credentials.py index eaff5ba..259ecfb 100644 --- a/src/dbt/adapters/fabricspark/credentials.py +++ b/src/dbt/adapters/fabricspark/credentials.py @@ -68,6 +68,15 @@ class FabricSparkCredentials(Credentials): reuse_session: bool = False # When True, Fabric sessions are kept alive and reused across runs session_idle_timeout: str = "30m" # Livy session idle timeout (e.g. "30m", "1h") + # High-concurrency Livy. When True (default), each dbt thread acquires + # its own REPL inside a single underlying Livy session shared via a + # deterministic sessionTag. Statements from different REPLs execute in + # parallel inside the Spark application. When False, falls back to the + # legacy single-session-per-process behaviour where statements queue + # FIFO inside the default Spark scheduling pool. + # Has no effect in local mode (livy_mode=local). + high_concurrency: bool = True + # Livy session stability settings http_timeout: int = 120 # seconds for each HTTP request to Fabric API session_start_timeout: int = 600 # max seconds to wait for session start (10 min) diff --git a/src/dbt/adapters/fabricspark/livy_backend.py b/src/dbt/adapters/fabricspark/livy_backend.py new file mode 100644 index 0000000..ba1332f --- /dev/null +++ b/src/dbt/adapters/fabricspark/livy_backend.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from dbt.adapters.fabricspark.credentials import FabricSparkCredentials + + +class LivyBackend(ABC): + """Pluggable Livy backend. + + Two implementations live in this package: + + - :class:`dbt.adapters.fabricspark.singleton_livy.LivySessionManager` — + one Livy session per process; statements run sequentially inside that + session's single interpreter. + - :class:`dbt.adapters.fabricspark.concurrent_livy.HighConcurrencySessionManager` — + one HC session (= one REPL) per dbt thread, all sharing one underlying + Livy session via a deterministic ``sessionTag``. Different REPLs run in + parallel inside the same Spark application. + + Selection is driven by ``FabricSparkCredentials.high_concurrency``. + ``open()`` in :mod:`connections` instantiates one backend per thread and + calls :meth:`connect` to obtain a DB-API-shaped connection wrapper. + """ + + @abstractmethod + def connect(self, credentials: FabricSparkCredentials) -> Any: + """Acquire (or reuse) a Livy session/REPL and return a connection handle. + + The returned object must expose ``cursor()`` and ``close()`` methods + plus the cursor surface used by the SQL connection manager + (``execute``, ``fetchall``, ``fetchmany``, ``fetchone``, ``description``). + """ + + @abstractmethod + def disconnect(self) -> None: + """Release backend-owned resources for this instance. + + Singleton mode keeps the underlying Livy session alive when + ``reuse_session`` is true; HC mode always deletes its per-thread HC + session so the REPL slot frees up immediately. + """ diff --git a/src/dbt/adapters/fabricspark/livysession.py b/src/dbt/adapters/fabricspark/livysession.py index 76f684c..3638daa 100644 --- a/src/dbt/adapters/fabricspark/livysession.py +++ b/src/dbt/adapters/fabricspark/livysession.py @@ -1,6 +1,5 @@ from __future__ import annotations -import atexit import datetime as dt import importlib import json @@ -8,23 +7,17 @@ import re import threading import time -from types import TracebackType from typing import Any, Optional import requests from azure.core.credentials import AccessToken, TokenCredential from azure.identity import AzureCliCredential, ClientSecretCredential -from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError -from dbt_common.utils.encoding import DECIMALS -from requests.models import Response +from dbt_common.exceptions import DbtRuntimeError from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.exceptions import FailedToConnectError from dbt.adapters.fabricspark.credentials import FabricSparkCredentials -from dbt.adapters.fabricspark.shortcuts import ShortcutClient logger = AdapterLogger("Microsoft Fabric-Spark") -NUMBERS = DECIMALS + (int, float) livysession_credentials: FabricSparkCredentials @@ -34,9 +27,6 @@ FABRIC_NOTEBOOK_CREDENTIAL_SCOPE = "pbi" accessToken: AccessToken = None -# Global lock to ensure thread-safe session creation/reuse -_session_lock = threading.Lock() - # Global lock to ensure thread-safe token refresh _token_lock = threading.Lock() @@ -475,1132 +465,28 @@ def get_lakehouse_properties(credentials: FabricSparkCredentials) -> dict: return {} -class LivySession: - def __init__(self, credentials: FabricSparkCredentials): - self.credential = credentials - self.connect_url = credentials.lakehouse_endpoint - self.session_id = None - self.is_new_session_required = True - self.is_local_mode = credentials.is_local_mode - - def __enter__(self) -> LivySession: - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, - ) -> bool: - return True - - def try_reuse_session(self, session_id: str) -> bool: - """Try to reuse an existing session by ID. - - Checks if the session exists in Livy and is in a usable state. - - Parameters - ---------- - session_id : str - The session ID to try to reuse. - - Returns - ------- - bool - True if session was successfully reused, False otherwise. - """ - try: - logger.debug(f"Attempting to reuse existing session: {session_id}") - self.session_id = session_id - - # Check if session exists and is valid - res = requests.get( - self.connect_url + "/sessions/" + session_id, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - - # If session doesn't exist (404 or other error), return False - if res.status_code != 200: - logger.debug(f"Session {session_id} not found (status: {res.status_code})") - self.session_id = None - return False - - res_json = res.json() - - # Check session state - invalid_states = ["dead", "shutting_down", "killed", "error", "not_found"] - - if self.is_local_mode: - current_state = res_json.get("state", "dead") - top_level_state = current_state - else: - # Fabric mode: check both top-level state and livyInfo - # When session is starting, livyInfo may not exist yet - top_level_state = res_json.get("state", "") - livy_info = res_json.get("livyInfo", {}) - current_state = livy_info.get("currentState", "") - - # If livyInfo doesn't exist yet but top-level state shows starting, it's still valid - if not current_state and top_level_state in ("starting", "not_started"): - current_state = top_level_state - - if current_state in invalid_states: - logger.debug(f"Session {session_id} is in invalid state: {current_state}") - self.session_id = None - return False - - # Check if session is idle (ready to use) or starting - if self.is_local_mode: - if current_state == "idle": - logger.info(f"Successfully reusing existing Livy session: {session_id}") - self.is_new_session_required = False - return True - elif current_state in ("starting", "not_started", "busy"): - # Wait for session to become idle - logger.debug(f"Session {session_id} is {current_state}, waiting...") - self._wait_for_existing_session(session_id) - logger.info(f"Successfully reusing existing Livy session: {session_id}") - self.is_new_session_required = False - return True - else: - if current_state == "idle": - logger.info(f"Successfully reusing existing Livy session: {session_id}") - self.is_new_session_required = False - return True - elif current_state in ("starting", "not_started", "busy") or top_level_state in ( - "starting", - "not_started", - ): - logger.debug( - f"Session {session_id} is {current_state} (top: {top_level_state}), waiting..." - ) - self._wait_for_existing_session(session_id) - logger.info(f"Successfully reusing existing Livy session: {session_id}") - self.is_new_session_required = False - return True - - logger.debug(f"Session {session_id} in unexpected state: {current_state}") - self.session_id = None - return False - - except requests.exceptions.RequestException as ex: - logger.debug(f"Error checking session {session_id}: {ex}") - self.session_id = None - return False - except Exception as ex: - logger.debug(f"Unexpected error reusing session {session_id}: {ex}") - self.session_id = None - return False - - def _wait_for_existing_session(self, session_id: str) -> None: - """Wait for an existing session to become idle.""" - deadline = time.time() + self.credential.session_start_timeout - - while time.time() < deadline: - res = requests.get( - self.connect_url + "/sessions/" + session_id, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ).json() - - if self.is_local_mode: - state = res.get("state", "") - if state == "idle": - return - elif state in ("dead", "error", "killed"): - raise FailedToConnectError(f"Session {session_id} died while waiting") - else: - logger.debug(f"Session {session_id} is {state}, waiting...") - else: - # Fabric mode: check both top-level state and livyInfo - top_level_state = res.get("state", "") - livy_info = res.get("livyInfo", {}) - livy_state = livy_info.get("currentState", "") - - if livy_state == "idle": - return - elif livy_state in ("dead", "error", "killed") or top_level_state in ( - "dead", - "error", - "killed", - ): - raise FailedToConnectError(f"Session {session_id} died while waiting") - else: - # Session still starting or in transition - logger.debug( - f"Session {session_id} state: top={top_level_state}, livy={livy_state}, waiting..." - ) - - time.sleep(self.credential.poll_wait) - - raise FailedToConnectError( - f"Timeout ({self.credential.session_start_timeout}s) waiting for session {session_id} to become idle" - ) - - def create_session(self, spark_config) -> str: - # Create sessions with retry for transient 404 and 5xx errors. - # Fabric sometimes returns 404 on the /sessions endpoint right after - # a lakehouse is provisioned before the Livy feature becomes available. - response = None - logger.debug("Creating Livy session (this may take a few minutes)") - - # For local Livy, we need to use "kind" parameter instead of "name" - if self.is_local_mode: - # Local Livy expects {"kind": "sql"} or {"kind": "spark"} - session_data = {"kind": "sql"} - if "kind" in spark_config: - session_data["kind"] = spark_config["kind"] - else: - session_data = spark_config - - max_create_retries = 5 - for attempt in range(max_create_retries): - try: - response = requests.post( - self.connect_url + "/sessions", - data=json.dumps(session_data), - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - if response.status_code in (200, 201, 202): - logger.debug("Initiated Livy Session...") - break # Success — exit retry loop - # Retry on 404 (Fabric Livy endpoint transiently unavailable) - # and 5xx server errors; give up on the last attempt. - if attempt < max_create_retries - 1 and ( - response.status_code == 404 or response.status_code >= 500 - ): - wait_time = 5 * (2**attempt) # 5, 10, 20, 40 s - logger.warning( - f"Livy session create returned HTTP {response.status_code}, " - f"retrying in {wait_time}s (attempt {attempt + 1}/{max_create_retries})" - ) - time.sleep(wait_time) - continue - response.raise_for_status() - except requests.exceptions.ConnectionError as c_err: - err_detail = c_err.response.json() if c_err.response else str(c_err) - raise Exception("Connection Error :", err_detail) - except requests.exceptions.HTTPError as h_err: - err_detail = h_err.response.json() if h_err.response else str(h_err) - raise Exception("Http Error: ", err_detail) - except requests.exceptions.Timeout as t_err: - err_detail = t_err.response.json() if t_err.response else str(t_err) - raise Exception("Timeout Error: ", err_detail) - except requests.exceptions.RequestException as a_err: - err_detail = a_err.response.json() if a_err.response else str(a_err) - raise Exception("Authorization Error: ", err_detail) - except Exception as ex: - raise Exception(ex) from ex - - if response is None: - raise Exception("Invalid response from Livy server") - - self.session_id = None - try: - self.session_id = str(response.json()["id"]) - except requests.exceptions.JSONDecodeError as json_err: - raise Exception("Json decode error to get session_id") from json_err - - # Wait for the session to start - self.wait_for_session_start() - - logger.debug("Livy session created successfully") - return self.session_id - - def wait_for_session_start(self) -> None: - """Wait for the Livy session to reach the 'idle' state.""" - deadline = time.time() + self.credential.session_start_timeout - while True: - if time.time() > deadline: - raise FailedToConnectError( - f"Timeout ({self.credential.session_start_timeout}s) waiting for session " - f"{self.session_id} to start. Increase `session_start_timeout` in profiles.yml." - ) - try: - response = requests.get( - self.connect_url + "/sessions/" + self.session_id, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - res = response.json() - except ( - requests.exceptions.RequestException, - requests.exceptions.JSONDecodeError, - ) as exc: - # Transient network error or non-JSON response (e.g. 429/5xx from Fabric - # under heavy load). Log and retry on the next poll interval rather than - # crashing the whole session-start wait. - logger.warning( - f"Transient error polling session {self.session_id} status: {exc}; " - f"will retry in {self.credential.poll_wait}s" - ) - time.sleep(self.credential.poll_wait) - continue - - # Local Livy uses "state" directly, Fabric uses "livyInfo.currentState" - if self.is_local_mode: - state = res.get("state", "") - if state in ("starting", "not_started"): - time.sleep(self.credential.poll_wait) - elif state == "idle": - logger.debug(f"New livy session id is: {self.session_id}, {res}") - self.is_new_session_required = False - break - elif state in ("dead", "error"): - logger.error("ERROR, cannot create a livy session") - raise FailedToConnectError("failed to connect") - else: - # Fabric Livy: check top-level state first - # When session is starting, "livyInfo" may not exist yet - top_level_state = res.get("state", "") - livy_info = res.get("livyInfo", {}) - livy_state = livy_info.get("currentState", "") - - if top_level_state in ("starting", "not_started"): - # Session still starting, continue polling - logger.debug(f"Session {self.session_id} is {top_level_state}, waiting...") - time.sleep(self.credential.poll_wait) - elif livy_state == "idle": - logger.debug(f"New livy session id is: {self.session_id}, {res}") - self.is_new_session_required = False - break - elif livy_state == "dead" or top_level_state == "dead": - logger.error("ERROR, cannot create a livy session") - raise FailedToConnectError("failed to connect") - else: - # Unknown state, keep waiting (could be transitioning) - logger.debug( - f"Session {self.session_id} in state: top={top_level_state}, livy={livy_state}, waiting..." - ) - time.sleep(self.credential.poll_wait) - - def delete_session(self) -> None: - try: - # delete the session_id - res = requests.delete( - self.connect_url + "/sessions/" + self.session_id, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - if res.status_code == 200: - logger.debug(f"Closed the livy session: {self.session_id}") - else: - res.raise_for_status() - - except Exception as ex: - logger.error(f"Unable to close the livy session {self.session_id}, error: {ex}") - - def is_valid_session(self) -> bool: - if self.session_id is None: - logger.error("Session ID is None") - return False - try: - res = requests.get( - self.connect_url + "/sessions/" + self.session_id, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ).json() - except Exception as ex: - logger.debug(f"is_valid_session HTTP error: {ex}") - return False - - # we can reuse the session so long as it is not dead, killed, or being shut down - invalid_states = ["dead", "shutting_down", "killed", "error"] - - # Local Livy uses "state" directly, Fabric uses "livyInfo.currentState" - if self.is_local_mode: - current_state = res.get("state", "dead") - else: - # Fabric mode: check both top-level state and livyInfo - # When session is starting, livyInfo may not exist yet - top_level_state = res.get("state", "") - livy_info = res.get("livyInfo", {}) - current_state = livy_info.get("currentState", "") - - # If livyInfo doesn't exist yet but top-level state is valid, use that - if not current_state: - current_state = top_level_state if top_level_state else "dead" - - return current_state not in invalid_states - - -# cursor object - wrapped for livy API -class LivyCursor: - """ - Mock a pyodbc cursor. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor - """ - - def __init__(self, credential, livy_session) -> None: - self._rows = None - self._schema = None - self._fetch_index = 0 - self.credential = credential - self.connect_url = credential.lakehouse_endpoint - self.session_id = livy_session.session_id - self.livy_session = livy_session - self.is_local_mode = credential.is_local_mode - - def __enter__(self) -> LivyCursor: - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, - ) -> bool: - self.close() - return True - - @property - def description( - self, - ) -> list[tuple[str, str, None, None, None, None, bool]]: - """ - Get the description. - - Returns - ------- - out : list[tuple[str, str, None, None, None, None, bool]] - The description. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#description - """ - if self._schema is None: - description = list() - else: - description = [ - ( - field["name"], - field["type"], # field['dataType'], - None, - None, - None, - None, - field["nullable"], - ) - for field in self._schema - ] - return description - - def close(self) -> None: - """ - Close the connection. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#close - """ - self._rows = None - - def _submitLivyCode(self, code) -> Response: - if self.livy_session.is_new_session_required: - LivySessionManager.connect(self.credential) - # connect() may replace livy_global_session with a new LivySession - # object; update our reference so session_id reflects the new session. - self.livy_session = LivySessionManager.livy_global_session - self.session_id = self.livy_session.session_id - - # Submit code with retry for transient 5xx and 429 (rate-limit) errors - data = {"code": code, "kind": "sql"} - url = self.connect_url + "/sessions/" + self.session_id + "/statements" - logger.debug(f"Submitted: {data} {url}") - - max_retries = 5 - res = None - for attempt in range(max_retries): - try: - res = requests.post( - url, - data=json.dumps(data), - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - except ( - requests.exceptions.SSLError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.ChunkedEncodingError, - ) as exc: - if attempt >= max_retries - 1: - raise DbtRuntimeError( - f"Livy statement submit failed after {max_retries} retries: {exc}" - ) - wait_time = 2**attempt * 1 - logger.debug( - f"Livy statement submit got transient network error " - f"({type(exc).__name__}: {exc}), retrying in {wait_time}s " - f"(attempt {attempt + 1}/{max_retries})" - ) - time.sleep(wait_time) - continue - if res.status_code == 429: - retry_after = _parse_retry_after(res) - wait_time = max(retry_after, 2**attempt * 1) - logger.debug( - f"Livy statement submit got HTTP 429, " - f"retrying in {wait_time:.0f}s (attempt {attempt + 1}/{max_retries})" - ) - time.sleep(wait_time) - continue - if res.status_code < 500: - break - if attempt < max_retries - 1: - wait_time = 2**attempt * 1 # 1s, 2s, 4s, 8s - logger.debug( - f"Livy statement submit got HTTP {res.status_code}, " - f"retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})" - ) - time.sleep(wait_time) - - if res.status_code >= 400: - # A 404 on submit means the Livy session is gone; flag for reconnect - # so the next _execute_query_with_retry attempt gets a fresh session. - if res.status_code == 404 and LivySessionManager.livy_global_session is not None: - LivySessionManager.livy_global_session.is_new_session_required = True - logger.debug("Livy statement submit returned 404 — flagging session for reconnect") - raise DbtRuntimeError( - f"Livy statement submit failed (HTTP {res.status_code}): {res.text}" - ) - json_body = res.json() - if "id" not in json_body: - raise DbtRuntimeError( - f"Livy statement submit returned unexpected response (missing 'id'): {json_body}" - ) - return res - - def _getLivySQL(self, sql) -> str: - # Comment, what is going on?! - # The following code is actually injecting SQL to pyspark object for executing it via the Livy session - over an HTTP post request. - # Basically, it is like code inside a code. As a result the strings passed here in 'escapedSQL' variable are unescapted and interpreted on the server side. - # This may have repurcursions of code injection not only as SQL, but also arbritary Python code. An alternate way safer way to acheive this is still unknown. - # TODO: since the above code is not changed to sending direct SQL to the livy backend, client side string escaping is probably not needed - - code = re.sub(r"\s*/\*(.|\n)*?\*/\s*", "\n", sql, re.DOTALL).strip() - return code - - def _getLivyResult(self, res_obj) -> Response: - json_res = res_obj.json() - statement_id = repr(json_res["id"]) - url = self.connect_url + "/sessions/" + self.session_id + "/statements/" + statement_id - # statement_timeout == 0 means no timeout (poll indefinitely), matching - # the pre-1.9.5 behavior where long-running models were never interrupted. - deadline = ( - (time.time() + self.credential.statement_timeout) - if self.credential.statement_timeout > 0 - else None - ) - consecutive_failures = 0 - max_poll_retries = 30 - # Adaptive polling: start small so quick statements don't sit idle, grow - # to a cap so slow statements don't hammer the server. Initial value is - # intentionally not *too* small — Fabric Livy sometimes returns 404 for a - # just-submitted statement id that has not yet registered on the server - # (handled by the 404 retry block below). - _poll_interval = 0.3 - _poll_interval_cap = max(self.credential.poll_statement_wait * 3, 1.5) - # 404 can appear transiently right after submit before the statement id - # is registered, or when the Fabric Livy service briefly loses track of - # the session/statement. Retry with exponential backoff before giving up. - not_found_retries = 0 - max_not_found_retries = 20 - while True: - if deadline is not None and time.time() > deadline: - raise DbtDatabaseError( - f"Timeout ({self.credential.statement_timeout}s) waiting for statement " - f"{statement_id} to complete. Increase `statement_timeout` in profiles.yml." - ) - try: - poll_res = requests.get( - url, - headers=get_headers(self.credential, False), - timeout=self.credential.http_timeout, - ) - except ( - requests.exceptions.SSLError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - requests.exceptions.ChunkedEncodingError, - ) as exc: - consecutive_failures += 1 - if consecutive_failures > max_poll_retries: - raise DbtRuntimeError( - f"Livy statement poll failed after {max_poll_retries} retries " - f"({type(exc).__name__}: {exc})" - ) - wait_time = min(2 ** (consecutive_failures - 1), 30) - logger.debug( - f"Livy statement poll got transient network error " - f"({type(exc).__name__}: {exc}), retrying in {wait_time}s " - f"(attempt {consecutive_failures}/{max_poll_retries})" - ) - time.sleep(wait_time) - continue - if poll_res.status_code == 429: - consecutive_failures += 1 - retry_after = _parse_retry_after(poll_res) - wait_time = max(retry_after, 2 ** (consecutive_failures - 1) * 1) - logger.debug( - f"Livy statement poll got HTTP 429, " - f"retrying in {wait_time:.0f}s (attempt {consecutive_failures}/{max_poll_retries})" - ) - time.sleep(wait_time) - if consecutive_failures > max_poll_retries: - raise DbtRuntimeError( - f"Livy statement poll failed after {max_poll_retries} retries " - f"(HTTP 429): {poll_res.text}" - ) - continue - if poll_res.status_code >= 500: - consecutive_failures += 1 - if consecutive_failures <= max_poll_retries: - wait_time = 2 ** (consecutive_failures - 1) * 1 # 1s, 2s, 4s, ... - logger.debug( - f"Livy statement poll got HTTP {poll_res.status_code}, " - f"retrying in {wait_time}s (attempt {consecutive_failures}/{max_poll_retries})" - ) - time.sleep(wait_time) - continue - raise DbtRuntimeError( - f"Livy statement poll failed after {max_poll_retries} retries " - f"(HTTP {poll_res.status_code}): {poll_res.text}" - ) - if poll_res.status_code == 404 and not_found_retries < max_not_found_retries: - # Statement id not yet visible on the server; back off briefly and retry. - not_found_retries += 1 - wait_time = min(0.3 * (2.0 ** (not_found_retries - 1)), 5.0) - logger.debug( - f"Livy statement poll got HTTP 404, retrying in {wait_time:.2f}s " - f"(not-found attempt {not_found_retries}/{max_not_found_retries})" - ) - time.sleep(wait_time) - continue - if poll_res.status_code >= 400: - # A 404 that survived all not-found retries means the session (not - # just the statement) is gone. Flag for reconnect so that the outer - # _execute_query_with_retry creates a fresh session on the next attempt. - if ( - poll_res.status_code == 404 - and LivySessionManager.livy_global_session is not None - ): - LivySessionManager.livy_global_session.is_new_session_required = True - logger.debug( - "Livy statement poll exhausted 404 retries — flagging session for reconnect" - ) - raise DbtRuntimeError( - f"Livy statement poll failed (HTTP {poll_res.status_code}): {poll_res.text}" - ) - consecutive_failures = 0 - res = poll_res.json() - if "state" not in res: - raise DbtRuntimeError( - f"Livy statement poll returned unexpected response (missing 'state'): {res}" - ) - - if res["state"] == "available": - return res - elif res["state"] in ("error", "cancelled", "cancelling"): - error_msg = res.get("output", {}).get("evalue", "Unknown error") - raise DbtDatabaseError( - f"Statement {statement_id} failed with state '{res['state']}': {error_msg}" - ) - time.sleep(_poll_interval) - _poll_interval = min(_poll_interval * 1.5, _poll_interval_cap) - - def execute(self, sql: str, *parameters: Any) -> None: - """ - Execute a sql statement. - - Parameters - ---------- - sql : str - Execute a sql statement. - *parameters : Any - The parameters. - - Raises - ------ - NotImplementedError - If there are parameters given. We do not format sql statements. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#executesql-parameters - """ - if len(parameters) > 0: - sql = sql % parameters - - # Reset fetch position for the new query - self._fetch_index = 0 - - # TODO: handle parameterised sql - - res = self._getLivyResult(self._submitLivyCode(self._getLivySQL(sql))) - logger.debug(res) - if res["output"]["status"] == "ok": - # Local and Fabric Livy have different output structures - if self.is_local_mode: - # Local Livy returns data in "text/plain" or "application/json" format - output_data = res["output"].get("data", {}) - if "application/json" in output_data: - values = output_data["application/json"] - if isinstance(values, dict) and "data" in values: - self._rows = values["data"] - self._schema = values.get("schema", {}).get("fields", []) - elif isinstance(values, list): - # Direct list of results - self._rows = values - self._schema = [] - else: - self._rows = [] - self._schema = [] - elif "text/plain" in output_data: - # Text output - parse if possible - self._rows = [] - self._schema = [] - else: - self._rows = [] - self._schema = [] - else: - # Fabric Livy format - values = res["output"]["data"]["application/json"] - if len(values) >= 1: - self._rows = values["data"] # values[0]['values'] - self._schema = values["schema"]["fields"] # values[0]['schema'] - else: - self._rows = [] - self._schema = [] - else: - self._rows = None - self._schema = None - - raise DbtDatabaseError("Error while executing query: " + res["output"]["evalue"]) - - def fetchall(self): - """ - Fetch all data. - - Returns - ------- - out : list() | None - The rows. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#fetchall - """ - return self._rows - - def fetchmany(self, size=None): - """ - Fetch up to *size* rows. - - Fabric's Livy statement-result API returns the entire result set in - one JSON response — there is no server-side cursor or streaming - primitive. The full result set is therefore already materialised in - ``self._rows`` before this method is called. Slicing locally is - faithful to the actual underlying behaviour. - - Parameters - ---------- - size : int | None - Maximum number of rows to return. When ``None`` all rows are - returned (equivalent to ``fetchall``). - - Returns - ------- - out : list | None - Up to *size* rows, or ``None`` if the query produced no result - set. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#fetchmany - """ - if self._rows is None: - return None - if size is None: - return self._rows - return self._rows[:size] - - def fetchone(self): - """ - Fetch the first output. - - Returns - ------- - out : one row | None - The first row. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#fetchone - """ - - if self._rows is not None and self._fetch_index < len(self._rows): - row = self._rows[self._fetch_index] - self._fetch_index += 1 - else: - row = None - - return row - - -class LivyConnection: - """ - Mock a pyodbc connection. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Connection - """ - - def __init__(self, credentials, livy_session) -> None: - self.credential: FabricSparkCredentials = credentials - self.connect_url = credentials.lakehouse_endpoint - self.session_id = livy_session.session_id - - self._cursor = LivyCursor(self.credential, livy_session) - - def get_session_id(self) -> str: - return self.session_id - - def get_headers(self) -> dict[str, str]: - return get_headers(self.credential, False) - - def get_connect_url(self) -> str: - return self.connect_url - - def cursor(self) -> LivyCursor: - """ - Get a cursor. - - Returns - ------- - out : Cursor - The cursor. - """ - return self._cursor - - def close(self) -> None: - """ - Close the connection. - - Source - ------ - https://github.com/mkleehammer/pyodbc/wiki/Cursor#close - """ - logger.debug("Connection.close()") - self._cursor.close() - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, - ) -> bool: - self.close() - return True - - -def _atexit_cleanup() -> None: - """Delete the Fabric Livy session on process exit. - - Local-mode sessions are kept alive for reuse across runs. - """ - LivySessionManager.disconnect() - - -atexit.register(_atexit_cleanup) - - -# TODO: How to authenticate -class LivySessionManager: - livy_global_session = None - - @staticmethod - def connect(credentials: FabricSparkCredentials) -> LivyConnection: - """Connect to a Livy session. - - For local mode: reuses existing sessions via session ID file persistence. - For Fabric mode: always creates a new session. - - This method is thread-safe and uses a lock to prevent race conditions - when multiple threads attempt to create sessions simultaneously. - """ - with _session_lock: - spark_config = credentials.spark_config - - if credentials.is_local_mode: - LivySessionManager._connect_local(credentials, spark_config) - else: - LivySessionManager._connect_fabric(credentials, spark_config) - - livyConnection = LivyConnection(credentials, LivySessionManager.livy_global_session) - return livyConnection - - @staticmethod - def _connect_local(credentials: FabricSparkCredentials, spark_config) -> None: - """Connect in local mode with session file reuse. - - Local mode persists the Livy session ID to a file so that subsequent - dbt invocations can reuse the same session instead of creating a new one. - - Connection strategy (in order): - 1. Reuse the in-memory session if it's still valid and ready. - 2. Read the session ID from the persisted file and try to reattach. - Skip if the file contains the same ID we already hold (it already - failed validity above, so retrying would be redundant). - 3. Create a brand-new session and persist its ID to the file. - """ - session_file_path = credentials.resolved_session_id_file - session = LivySessionManager.livy_global_session - - # 1. Fast path: reuse current in-memory session if it's valid and idle - if ( - session is not None - and session.is_valid_session() - and not session.is_new_session_required - ): - logger.debug(f"Reusing session: {session.session_id}") - return - - # Ensure we have a LivySession instance to work with - if session is None: - session = LivySession(credentials) - LivySessionManager.livy_global_session = session - - # 2. Try to reattach to an existing session persisted in the file. - # Skip if the file holds the same session ID we already have — - # that session was just found invalid above, no point retrying. - existing_session_id = read_session_id_from_file(session_file_path) - if existing_session_id and existing_session_id != session.session_id: - if session.try_reuse_session(existing_session_id): - logger.debug(f"Reused session from file: {existing_session_id}") - return - - # 3. No reusable session available — create a new one and persist its ID - LivySessionManager._create_and_persist_session(spark_config, session_file_path) - - @staticmethod - def _connect_fabric(credentials: FabricSparkCredentials, spark_config) -> None: - """Connect in Fabric mode. - - When reuse_session is False (default): - Creates a new session each time unless there is already a valid, - ready session in memory. Session is deleted at exit. - - When reuse_session is True: - Reuses existing sessions via session ID file persistence, similar - to local mode. Session is kept alive at exit for reuse by subsequent - dbt runs. Fabric will auto-kill it after the configured idle timeout. - - After session creation, any configured OneLake shortcuts are also created. - """ - if credentials.reuse_session: - LivySessionManager._connect_fabric_reuse(credentials, spark_config) - else: - LivySessionManager._connect_fabric_fresh(credentials, spark_config) - - @staticmethod - def _connect_fabric_fresh(credentials: FabricSparkCredentials, spark_config) -> None: - """Connect in Fabric mode — always creates a new session.""" - session = LivySessionManager.livy_global_session - needs_new_session = ( - session is None or not session.is_valid_session() or session.is_new_session_required - ) - - if not needs_new_session: - logger.debug(f"Reusing session: {session.session_id}") - return - - LivySessionManager._create_fabric_session(credentials, spark_config) - - @staticmethod - def _connect_fabric_reuse(credentials: FabricSparkCredentials, spark_config) -> None: - """Connect in Fabric mode with session reuse across runs. - - Connection strategy (same as local mode): - 1. Reuse the in-memory session if it's still valid and ready. - 2. Read the session ID from the persisted file and try to reattach. - 3. Create a brand-new session and persist its ID to the file. - """ - session_file_path = credentials.resolved_session_id_file - session = LivySessionManager.livy_global_session - - # 1. Fast path: reuse current in-memory session if it's valid and idle - if ( - session is not None - and session.is_valid_session() - and not session.is_new_session_required - ): - logger.debug(f"Reusing Fabric session: {session.session_id}") - return - - # Ensure we have a LivySession instance to work with - if session is None: - session = LivySession(credentials) - LivySessionManager.livy_global_session = session - - # 2. Try to reattach to an existing session persisted in the file. - existing_session_id = read_session_id_from_file(session_file_path) - if existing_session_id and existing_session_id != session.session_id: - if session.try_reuse_session(existing_session_id): - logger.info(f"Reused existing Fabric session from file: {existing_session_id}") - return - - # 3. No reusable session — create a new one and persist its ID - LivySessionManager._create_fabric_session(credentials, spark_config) - write_session_id_to_file( - session_file_path, - LivySessionManager.livy_global_session.session_id, - ) - - @staticmethod - def _create_fabric_session(credentials: FabricSparkCredentials, spark_config) -> None: - """Create a new Fabric Livy session and set up shortcuts.""" - LivySessionManager.livy_global_session = LivySession(credentials) - - # Inject environmentId into spark_config if configured - if credentials.environmentId: - spark_config = { - **spark_config, - "conf": { - **spark_config.get("conf", {}), - "spark.fabric.environment.id": credentials.environmentId, - }, - } - logger.debug(f"Using Fabric Environment: {credentials.environmentId}") - - # Inject session idle timeout into spark_config - if credentials.session_idle_timeout: - spark_config = { - **spark_config, - "conf": { - **spark_config.get("conf", {}), - "spark.livy.session.idle.timeout": credentials.session_idle_timeout, - }, - } - logger.debug(f"Session idle timeout: {credentials.session_idle_timeout}") - - LivySessionManager.livy_global_session.create_session(spark_config) - LivySessionManager.livy_global_session.is_new_session_required = False - - # Create OneLake shortcuts if configured - if credentials.create_shortcuts: - try: - shortcut_client = ShortcutClient( - accessToken.token, - credentials.workspaceid, - credentials.lakehouseid, - credentials.endpoint, - ) - shortcut_client.create_shortcuts(credentials.shortcuts_json_str) - except Exception as ex: - logger.error(f"Unable to create shortcuts: {ex}") - - @staticmethod - def _create_and_persist_session(spark_config, session_file_path: str) -> None: - """Create a new session and write the session ID to file (local mode only).""" - LivySessionManager.livy_global_session.create_session(spark_config) - LivySessionManager.livy_global_session.is_new_session_required = False - write_session_id_to_file( - session_file_path, LivySessionManager.livy_global_session.session_id - ) - - @staticmethod - def disconnect() -> None: - """Disconnect from the session manager. - - - Local mode: keeps the Livy session alive for reuse. - - Fabric mode with reuse_session=True: keeps session alive for reuse. - - Fabric mode with reuse_session=False: deletes the session. - - This method is thread-safe. - """ - with _session_lock: - if LivySessionManager.livy_global_session is None: - logger.debug("No session to disconnect") - return - - session = LivySessionManager.livy_global_session - session_id = session.session_id - - if session.is_local_mode or session.credential.reuse_session: - # Local mode or Fabric reuse mode: keep the session alive - logger.debug( - f"Disconnecting from session manager (session {session_id} kept alive for reuse)" - ) - else: - # Fabric mode: delete the session since it won't be reused - logger.debug(f"Deleting Fabric Livy session: {session_id}") - session.delete_session() - - # Reset the local reference in both cases - LivySessionManager.livy_global_session = None - - -class LivySessionConnectionWrapper(object): - """Connection wrapper for the livy sessoin connection method.""" - - def __init__(self, handle): - self.handle = handle - self._cursor = None - - def cursor(self) -> LivySessionConnectionWrapper: - self._cursor = self.handle.cursor() - return self - - def cancel(self): - logger.debug("NotImplemented: cancel") - - def close(self): - self.handle.close() - - def rollback(self, *args, **kwargs): - logger.debug("NotImplemented: rollback") - - def fetchall(self): - return self._cursor.fetchall() - - def fetchmany(self, size=None): - return self._cursor.fetchmany(size) - - def fetchone(self): - return self._cursor.fetchone() - - def execute(self, sql, bindings=None): - if sql.strip().endswith(";"): - sql = sql.strip()[:-1] - - if bindings is None: - self._cursor.execute(sql) - else: - bindings = [self._fix_binding(binding) for binding in bindings] - self._cursor.execute(sql, *bindings) - - @property - def description(self): - return self._cursor.description - - @classmethod - def _fix_binding(cls, value) -> float | str: - """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" - if isinstance(value, NUMBERS): - return float(value) - elif isinstance(value, dt.datetime): - return f"'{value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}'" - elif value is None: - return "''" - else: - escaped = str(value).replace("'", "\\'") - return f"'{escaped}'" +from dbt.adapters.fabricspark.singleton_livy import ( # noqa: E402 + LivyConnection, + LivyCursor, + LivySession, + LivySessionConnectionWrapper, + LivySessionManager, +) + +__all__ = [ + "LivyConnection", + "LivyCursor", + "LivySession", + "LivySessionConnectionWrapper", + "LivySessionManager", + "get_cli_access_token", + "get_default_access_token", + "get_fabric_notebook_access_token", + "get_headers", + "get_lakehouse_properties", + "get_sp_access_token", + "get_token_credential_access_token", + "is_token_refresh_necessary", + "read_session_id_from_file", + "write_session_id_to_file", +] diff --git a/src/dbt/adapters/fabricspark/singleton_livy.py b/src/dbt/adapters/fabricspark/singleton_livy.py new file mode 100644 index 0000000..b768024 --- /dev/null +++ b/src/dbt/adapters/fabricspark/singleton_livy.py @@ -0,0 +1,974 @@ +from __future__ import annotations + +import atexit +import datetime as dt +import json +import re +import threading +import time +from types import TracebackType +from typing import Any, Optional + +import requests +from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError +from dbt_common.utils.encoding import DECIMALS +from requests.models import Response + +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions import FailedToConnectError +from dbt.adapters.fabricspark import livysession as _livy_helpers +from dbt.adapters.fabricspark.credentials import FabricSparkCredentials +from dbt.adapters.fabricspark.livy_backend import LivyBackend +from dbt.adapters.fabricspark.shortcuts import ShortcutClient + +logger = AdapterLogger("Microsoft Fabric-Spark") +NUMBERS = DECIMALS + (int, float) + +_session_lock = threading.Lock() + + +def _get_headers(credentials: FabricSparkCredentials, tokenPrint: bool = False) -> dict[str, str]: + return _livy_helpers.get_headers(credentials, tokenPrint) + + +def _parse_retry_after(response: requests.Response) -> float: + return _livy_helpers._parse_retry_after(response) + + +class LivySession: + def __init__(self, credentials: FabricSparkCredentials): + self.credential = credentials + self.connect_url = credentials.lakehouse_endpoint + self.session_id = None + self.is_new_session_required = True + self.is_local_mode = credentials.is_local_mode + + def __enter__(self) -> LivySession: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + return True + + def try_reuse_session(self, session_id: str) -> bool: + """Try to reuse an existing session by ID. + + Checks if the session exists in Livy and is in a usable state. + + Parameters + ---------- + session_id : str + The session ID to try to reuse. + + Returns + ------- + bool + True if session was successfully reused, False otherwise. + """ + try: + logger.debug(f"Attempting to reuse existing session: {session_id}") + self.session_id = session_id + + res = requests.get( + self.connect_url + "/sessions/" + session_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + + if res.status_code != 200: + logger.debug(f"Session {session_id} not found (status: {res.status_code})") + self.session_id = None + return False + + res_json = res.json() + + invalid_states = ["dead", "shutting_down", "killed", "error", "not_found"] + + if self.is_local_mode: + current_state = res_json.get("state", "dead") + top_level_state = current_state + else: + top_level_state = res_json.get("state", "") + livy_info = res_json.get("livyInfo", {}) + current_state = livy_info.get("currentState", "") + + if not current_state and top_level_state in ("starting", "not_started"): + current_state = top_level_state + + if current_state in invalid_states: + logger.debug(f"Session {session_id} is in invalid state: {current_state}") + self.session_id = None + return False + + if self.is_local_mode: + if current_state == "idle": + logger.info(f"Successfully reusing existing Livy session: {session_id}") + self.is_new_session_required = False + return True + elif current_state in ("starting", "not_started", "busy"): + logger.debug(f"Session {session_id} is {current_state}, waiting...") + self._wait_for_existing_session(session_id) + logger.info(f"Successfully reusing existing Livy session: {session_id}") + self.is_new_session_required = False + return True + else: + if current_state == "idle": + logger.info(f"Successfully reusing existing Livy session: {session_id}") + self.is_new_session_required = False + return True + elif current_state in ("starting", "not_started", "busy") or top_level_state in ( + "starting", + "not_started", + ): + logger.debug( + f"Session {session_id} is {current_state} (top: {top_level_state}), waiting..." + ) + self._wait_for_existing_session(session_id) + logger.info(f"Successfully reusing existing Livy session: {session_id}") + self.is_new_session_required = False + return True + + logger.debug(f"Session {session_id} in unexpected state: {current_state}") + self.session_id = None + return False + + except requests.exceptions.RequestException as ex: + logger.debug(f"Error checking session {session_id}: {ex}") + self.session_id = None + return False + except Exception as ex: + logger.debug(f"Unexpected error reusing session {session_id}: {ex}") + self.session_id = None + return False + + def _wait_for_existing_session(self, session_id: str) -> None: + """Wait for an existing session to become idle.""" + deadline = time.time() + self.credential.session_start_timeout + + while time.time() < deadline: + res = requests.get( + self.connect_url + "/sessions/" + session_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ).json() + + if self.is_local_mode: + state = res.get("state", "") + if state == "idle": + return + elif state in ("dead", "error", "killed"): + raise FailedToConnectError(f"Session {session_id} died while waiting") + else: + logger.debug(f"Session {session_id} is {state}, waiting...") + else: + top_level_state = res.get("state", "") + livy_info = res.get("livyInfo", {}) + livy_state = livy_info.get("currentState", "") + + if livy_state == "idle": + return + elif livy_state in ("dead", "error", "killed") or top_level_state in ( + "dead", + "error", + "killed", + ): + raise FailedToConnectError(f"Session {session_id} died while waiting") + else: + logger.debug( + f"Session {session_id} state: top={top_level_state}, livy={livy_state}, waiting..." + ) + + time.sleep(self.credential.poll_wait) + + raise FailedToConnectError( + f"Timeout ({self.credential.session_start_timeout}s) waiting for session {session_id} to become idle" + ) + + def create_session(self, spark_config) -> str: + # Fabric Livy returns 404 transiently right after a lakehouse is + # provisioned, before the Livy feature is fully wired up. + response = None + logger.debug("Creating Livy session (this may take a few minutes)") + + if self.is_local_mode: + session_data = {"kind": "sql"} + if "kind" in spark_config: + session_data["kind"] = spark_config["kind"] + else: + session_data = spark_config + + max_create_retries = 5 + for attempt in range(max_create_retries): + try: + response = requests.post( + self.connect_url + "/sessions", + data=json.dumps(session_data), + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + if response.status_code in (200, 201, 202): + logger.debug("Initiated Livy Session...") + break + if attempt < max_create_retries - 1 and ( + response.status_code == 404 or response.status_code >= 500 + ): + wait_time = 5 * (2**attempt) + logger.warning( + f"Livy session create returned HTTP {response.status_code}, " + f"retrying in {wait_time}s (attempt {attempt + 1}/{max_create_retries})" + ) + time.sleep(wait_time) + continue + response.raise_for_status() + except requests.exceptions.ConnectionError as c_err: + err_detail = c_err.response.json() if c_err.response else str(c_err) + raise Exception("Connection Error :", err_detail) + except requests.exceptions.HTTPError as h_err: + err_detail = h_err.response.json() if h_err.response else str(h_err) + raise Exception("Http Error: ", err_detail) + except requests.exceptions.Timeout as t_err: + err_detail = t_err.response.json() if t_err.response else str(t_err) + raise Exception("Timeout Error: ", err_detail) + except requests.exceptions.RequestException as a_err: + err_detail = a_err.response.json() if a_err.response else str(a_err) + raise Exception("Authorization Error: ", err_detail) + except Exception as ex: + raise Exception(ex) from ex + + if response is None: + raise Exception("Invalid response from Livy server") + + self.session_id = None + try: + self.session_id = str(response.json()["id"]) + except requests.exceptions.JSONDecodeError as json_err: + raise Exception("Json decode error to get session_id") from json_err + + self.wait_for_session_start() + + logger.debug("Livy session created successfully") + return self.session_id + + def wait_for_session_start(self) -> None: + """Wait for the Livy session to reach the 'idle' state.""" + deadline = time.time() + self.credential.session_start_timeout + while True: + if time.time() > deadline: + raise FailedToConnectError( + f"Timeout ({self.credential.session_start_timeout}s) waiting for session " + f"{self.session_id} to start. Increase `session_start_timeout` in profiles.yml." + ) + try: + response = requests.get( + self.connect_url + "/sessions/" + self.session_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + res = response.json() + except ( + requests.exceptions.RequestException, + requests.exceptions.JSONDecodeError, + ) as exc: + logger.warning( + f"Transient error polling session {self.session_id} status: {exc}; " + f"will retry in {self.credential.poll_wait}s" + ) + time.sleep(self.credential.poll_wait) + continue + + if self.is_local_mode: + state = res.get("state", "") + if state in ("starting", "not_started"): + time.sleep(self.credential.poll_wait) + elif state == "idle": + logger.debug(f"New livy session id is: {self.session_id}, {res}") + self.is_new_session_required = False + break + elif state in ("dead", "error"): + logger.error("ERROR, cannot create a livy session") + raise FailedToConnectError("failed to connect") + else: + top_level_state = res.get("state", "") + livy_info = res.get("livyInfo", {}) + livy_state = livy_info.get("currentState", "") + + if top_level_state in ("starting", "not_started"): + logger.debug(f"Session {self.session_id} is {top_level_state}, waiting...") + time.sleep(self.credential.poll_wait) + elif livy_state == "idle": + logger.debug(f"New livy session id is: {self.session_id}, {res}") + self.is_new_session_required = False + break + elif livy_state == "dead" or top_level_state == "dead": + logger.error("ERROR, cannot create a livy session") + raise FailedToConnectError("failed to connect") + else: + logger.debug( + f"Session {self.session_id} in state: top={top_level_state}, livy={livy_state}, waiting..." + ) + time.sleep(self.credential.poll_wait) + + def delete_session(self) -> None: + try: + res = requests.delete( + self.connect_url + "/sessions/" + self.session_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + if res.status_code == 200: + logger.debug(f"Closed the livy session: {self.session_id}") + else: + res.raise_for_status() + + except Exception as ex: + logger.error(f"Unable to close the livy session {self.session_id}, error: {ex}") + + def is_valid_session(self) -> bool: + if self.session_id is None: + logger.error("Session ID is None") + return False + try: + res = requests.get( + self.connect_url + "/sessions/" + self.session_id, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ).json() + except Exception as ex: + logger.debug(f"is_valid_session HTTP error: {ex}") + return False + + invalid_states = ["dead", "shutting_down", "killed", "error"] + + if self.is_local_mode: + current_state = res.get("state", "dead") + else: + top_level_state = res.get("state", "") + livy_info = res.get("livyInfo", {}) + current_state = livy_info.get("currentState", "") + + if not current_state: + current_state = top_level_state if top_level_state else "dead" + + return current_state not in invalid_states + + +class LivyCursor: + """Mock a pyodbc cursor. + + Source: https://github.com/mkleehammer/pyodbc/wiki/Cursor + """ + + def __init__(self, credential, livy_session) -> None: + self._rows = None + self._schema = None + self._fetch_index = 0 + self.credential = credential + self.connect_url = credential.lakehouse_endpoint + self.session_id = livy_session.session_id + self.livy_session = livy_session + self.is_local_mode = credential.is_local_mode + + def __enter__(self) -> LivyCursor: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + self.close() + return True + + @property + def description( + self, + ) -> list[tuple[str, str, None, None, None, None, bool]]: + if self._schema is None: + description = list() + else: + description = [ + ( + field["name"], + field["type"], + None, + None, + None, + None, + field["nullable"], + ) + for field in self._schema + ] + return description + + def close(self) -> None: + self._rows = None + + def _submitLivyCode(self, code) -> Response: + if self.livy_session.is_new_session_required: + LivySessionManager._connect_impl(self.credential) + # connect() may swap in a new LivySession; resync our reference. + self.livy_session = LivySessionManager.livy_global_session + self.session_id = self.livy_session.session_id + + data = {"code": code, "kind": "sql"} + url = self.connect_url + "/sessions/" + self.session_id + "/statements" + logger.debug(f"Submitted: {data} {url}") + + max_retries = 5 + res = None + for attempt in range(max_retries): + try: + res = requests.post( + url, + data=json.dumps(data), + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + except ( + requests.exceptions.SSLError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ChunkedEncodingError, + ) as exc: + if attempt >= max_retries - 1: + raise DbtRuntimeError( + f"Livy statement submit failed after {max_retries} retries: {exc}" + ) + wait_time = 2**attempt * 1 + logger.debug( + f"Livy statement submit got transient network error " + f"({type(exc).__name__}: {exc}), retrying in {wait_time}s " + f"(attempt {attempt + 1}/{max_retries})" + ) + time.sleep(wait_time) + continue + if res.status_code == 429: + retry_after = _parse_retry_after(res) + wait_time = max(retry_after, 2**attempt * 1) + logger.debug( + f"Livy statement submit got HTTP 429, " + f"retrying in {wait_time:.0f}s (attempt {attempt + 1}/{max_retries})" + ) + time.sleep(wait_time) + continue + if res.status_code < 500: + break + if attempt < max_retries - 1: + wait_time = 2**attempt * 1 + logger.debug( + f"Livy statement submit got HTTP {res.status_code}, " + f"retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})" + ) + time.sleep(wait_time) + + if res.status_code >= 400: + if res.status_code == 404 and LivySessionManager.livy_global_session is not None: + LivySessionManager.livy_global_session.is_new_session_required = True + logger.debug("Livy statement submit returned 404 — flagging session for reconnect") + raise DbtRuntimeError( + f"Livy statement submit failed (HTTP {res.status_code}): {res.text}" + ) + json_body = res.json() + if "id" not in json_body: + raise DbtRuntimeError( + f"Livy statement submit returned unexpected response (missing 'id'): {json_body}" + ) + return res + + def _getLivySQL(self, sql) -> str: + # The Livy SQL submit path interpolates this string into a code block + # for the server-side interpreter, so embedded /* ... */ comments are + # stripped here before submission. Client-side escaping is unnecessary + # because submission uses POST JSON, not URL encoding. + code = re.sub(r"\s*/\*(.|\n)*?\*/\s*", "\n", sql, re.DOTALL).strip() + return code + + def _getLivyResult(self, res_obj) -> Response: + json_res = res_obj.json() + statement_id = repr(json_res["id"]) + url = self.connect_url + "/sessions/" + self.session_id + "/statements/" + statement_id + deadline = ( + (time.time() + self.credential.statement_timeout) + if self.credential.statement_timeout > 0 + else None + ) + consecutive_failures = 0 + max_poll_retries = 30 + _poll_interval = 0.3 + _poll_interval_cap = max(self.credential.poll_statement_wait * 3, 1.5) + # 404 can appear transiently right after submit before the statement id + # is registered, or when the Fabric Livy service briefly loses track of + # the session/statement. Retry with exponential backoff before giving up. + not_found_retries = 0 + max_not_found_retries = 20 + while True: + if deadline is not None and time.time() > deadline: + raise DbtDatabaseError( + f"Timeout ({self.credential.statement_timeout}s) waiting for statement " + f"{statement_id} to complete. Increase `statement_timeout` in profiles.yml." + ) + try: + poll_res = requests.get( + url, + headers=_get_headers(self.credential, False), + timeout=self.credential.http_timeout, + ) + except ( + requests.exceptions.SSLError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ChunkedEncodingError, + ) as exc: + consecutive_failures += 1 + if consecutive_failures > max_poll_retries: + raise DbtRuntimeError( + f"Livy statement poll failed after {max_poll_retries} retries " + f"({type(exc).__name__}: {exc})" + ) + wait_time = min(2 ** (consecutive_failures - 1), 30) + logger.debug( + f"Livy statement poll got transient network error " + f"({type(exc).__name__}: {exc}), retrying in {wait_time}s " + f"(attempt {consecutive_failures}/{max_poll_retries})" + ) + time.sleep(wait_time) + continue + if poll_res.status_code == 429: + consecutive_failures += 1 + retry_after = _parse_retry_after(poll_res) + wait_time = max(retry_after, 2 ** (consecutive_failures - 1) * 1) + logger.debug( + f"Livy statement poll got HTTP 429, " + f"retrying in {wait_time:.0f}s (attempt {consecutive_failures}/{max_poll_retries})" + ) + time.sleep(wait_time) + if consecutive_failures > max_poll_retries: + raise DbtRuntimeError( + f"Livy statement poll failed after {max_poll_retries} retries " + f"(HTTP 429): {poll_res.text}" + ) + continue + if poll_res.status_code >= 500: + consecutive_failures += 1 + if consecutive_failures <= max_poll_retries: + wait_time = 2 ** (consecutive_failures - 1) * 1 + logger.debug( + f"Livy statement poll got HTTP {poll_res.status_code}, " + f"retrying in {wait_time}s (attempt {consecutive_failures}/{max_poll_retries})" + ) + time.sleep(wait_time) + continue + raise DbtRuntimeError( + f"Livy statement poll failed after {max_poll_retries} retries " + f"(HTTP {poll_res.status_code}): {poll_res.text}" + ) + if poll_res.status_code == 404 and not_found_retries < max_not_found_retries: + not_found_retries += 1 + wait_time = min(0.3 * (2.0 ** (not_found_retries - 1)), 5.0) + logger.debug( + f"Livy statement poll got HTTP 404, retrying in {wait_time:.2f}s " + f"(not-found attempt {not_found_retries}/{max_not_found_retries})" + ) + time.sleep(wait_time) + continue + if poll_res.status_code >= 400: + if ( + poll_res.status_code == 404 + and LivySessionManager.livy_global_session is not None + ): + LivySessionManager.livy_global_session.is_new_session_required = True + logger.debug( + "Livy statement poll exhausted 404 retries — flagging session for reconnect" + ) + raise DbtRuntimeError( + f"Livy statement poll failed (HTTP {poll_res.status_code}): {poll_res.text}" + ) + consecutive_failures = 0 + res = poll_res.json() + if "state" not in res: + raise DbtRuntimeError( + f"Livy statement poll returned unexpected response (missing 'state'): {res}" + ) + + if res["state"] == "available": + return res + elif res["state"] in ("error", "cancelled", "cancelling"): + error_msg = res.get("output", {}).get("evalue", "Unknown error") + raise DbtDatabaseError( + f"Statement {statement_id} failed with state '{res['state']}': {error_msg}" + ) + time.sleep(_poll_interval) + _poll_interval = min(_poll_interval * 1.5, _poll_interval_cap) + + def execute(self, sql: str, *parameters: Any) -> None: + if len(parameters) > 0: + sql = sql % parameters + + # Reset fetch position for the new query + self._fetch_index = 0 + + res = self._getLivyResult(self._submitLivyCode(self._getLivySQL(sql))) + logger.debug(res) + if res["output"]["status"] == "ok": + if self.is_local_mode: + output_data = res["output"].get("data", {}) + if "application/json" in output_data: + values = output_data["application/json"] + if isinstance(values, dict) and "data" in values: + self._rows = values["data"] + self._schema = values.get("schema", {}).get("fields", []) + elif isinstance(values, list): + self._rows = values + self._schema = [] + else: + self._rows = [] + self._schema = [] + elif "text/plain" in output_data: + self._rows = [] + self._schema = [] + else: + self._rows = [] + self._schema = [] + else: + values = res["output"]["data"]["application/json"] + if len(values) >= 1: + self._rows = values["data"] + self._schema = values["schema"]["fields"] + else: + self._rows = [] + self._schema = [] + else: + self._rows = None + self._schema = None + + raise DbtDatabaseError("Error while executing query: " + res["output"]["evalue"]) + + def fetchall(self): + return self._rows + + def fetchmany(self, size=None): + """Fabric's Livy statement-result API returns the entire result set in + one JSON response — there is no server-side cursor or streaming + primitive. The full result set is therefore already materialised in + ``self._rows`` before this method is called. Slicing locally is + faithful to the actual underlying behaviour. + """ + if self._rows is None: + return None + if size is None: + return self._rows + return self._rows[:size] + + def fetchone(self): + if self._rows is not None and self._fetch_index < len(self._rows): + row = self._rows[self._fetch_index] + self._fetch_index += 1 + else: + row = None + + return row + + +class LivyConnection: + """Mock a pyodbc connection. + + Source: https://github.com/mkleehammer/pyodbc/wiki/Connection + """ + + def __init__(self, credentials, livy_session) -> None: + self.credential: FabricSparkCredentials = credentials + self.connect_url = credentials.lakehouse_endpoint + self.session_id = livy_session.session_id + + self._cursor = LivyCursor(self.credential, livy_session) + + def get_session_id(self) -> str: + return self.session_id + + def get_headers(self) -> dict[str, str]: + return _get_headers(self.credential, False) + + def get_connect_url(self) -> str: + return self.connect_url + + def cursor(self) -> LivyCursor: + return self._cursor + + def close(self) -> None: + logger.debug("Connection.close()") + self._cursor.close() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: Exception | None, + exc_tb: TracebackType | None, + ) -> bool: + self.close() + return True + + +def _atexit_cleanup() -> None: + """Delete the Fabric Livy session on process exit. + + Local-mode sessions are kept alive for reuse across runs. + """ + LivySessionManager._disconnect_impl() + + +atexit.register(_atexit_cleanup) + + +class LivySessionManager(LivyBackend): + livy_global_session: Optional[LivySession] = None + + @classmethod + def connect(cls, credentials: FabricSparkCredentials) -> LivyConnection: # type: ignore[override] + return cls._connect_impl(credentials) + + @classmethod + def disconnect(cls) -> None: # type: ignore[override] + cls._disconnect_impl() + + @staticmethod + def _connect_impl(credentials: FabricSparkCredentials) -> LivyConnection: + """Singleton Livy session — one per process, shared across threads. + + This is the legacy code path preserved verbatim from the pre-HC + adapter. Statements submitted to one Livy session execute inside its + single interpreter context and are queued FIFO inside the default + Spark scheduling pool. + """ + with _session_lock: + spark_config = credentials.spark_config + + if credentials.is_local_mode: + LivySessionManager._connect_local(credentials, spark_config) + else: + LivySessionManager._connect_fabric(credentials, spark_config) + + livyConnection = LivyConnection(credentials, LivySessionManager.livy_global_session) + return livyConnection + + @staticmethod + def _connect_local(credentials: FabricSparkCredentials, spark_config) -> None: + """Local mode connection with session-file reuse. + + Strategy: + 1. Reuse the in-memory session if valid and ready. + 2. Read the persisted session ID and try to reattach. + 3. Create a brand-new session and persist its ID. + """ + session_file_path = credentials.resolved_session_id_file + session = LivySessionManager.livy_global_session + + if ( + session is not None + and session.is_valid_session() + and not session.is_new_session_required + ): + logger.debug(f"Reusing session: {session.session_id}") + return + + if session is None: + session = LivySession(credentials) + LivySessionManager.livy_global_session = session + + existing_session_id = _livy_helpers.read_session_id_from_file(session_file_path) + if existing_session_id and existing_session_id != session.session_id: + if session.try_reuse_session(existing_session_id): + logger.debug(f"Reused session from file: {existing_session_id}") + return + + LivySessionManager._create_and_persist_session(spark_config, session_file_path) + + @staticmethod + def _connect_fabric(credentials: FabricSparkCredentials, spark_config) -> None: + if credentials.reuse_session: + LivySessionManager._connect_fabric_reuse(credentials, spark_config) + else: + LivySessionManager._connect_fabric_fresh(credentials, spark_config) + + @staticmethod + def _connect_fabric_fresh(credentials: FabricSparkCredentials, spark_config) -> None: + """Always create a new session unless a valid one is already in memory.""" + session = LivySessionManager.livy_global_session + needs_new_session = ( + session is None or not session.is_valid_session() or session.is_new_session_required + ) + + if not needs_new_session: + logger.debug(f"Reusing session: {session.session_id}") + return + + LivySessionManager._create_fabric_session(credentials, spark_config) + + @staticmethod + def _connect_fabric_reuse(credentials: FabricSparkCredentials, spark_config) -> None: + """Same strategy as local mode: in-memory > file > create.""" + session_file_path = credentials.resolved_session_id_file + session = LivySessionManager.livy_global_session + + if ( + session is not None + and session.is_valid_session() + and not session.is_new_session_required + ): + logger.debug(f"Reusing Fabric session: {session.session_id}") + return + + if session is None: + session = LivySession(credentials) + LivySessionManager.livy_global_session = session + + existing_session_id = _livy_helpers.read_session_id_from_file(session_file_path) + if existing_session_id and existing_session_id != session.session_id: + if session.try_reuse_session(existing_session_id): + logger.info(f"Reused existing Fabric session from file: {existing_session_id}") + return + + LivySessionManager._create_fabric_session(credentials, spark_config) + _livy_helpers.write_session_id_to_file( + session_file_path, + LivySessionManager.livy_global_session.session_id, + ) + + @staticmethod + def _create_fabric_session(credentials: FabricSparkCredentials, spark_config) -> None: + LivySessionManager.livy_global_session = LivySession(credentials) + + if credentials.environmentId: + spark_config = { + **spark_config, + "conf": { + **spark_config.get("conf", {}), + "spark.fabric.environment.id": credentials.environmentId, + }, + } + logger.debug(f"Using Fabric Environment: {credentials.environmentId}") + + if credentials.session_idle_timeout: + spark_config = { + **spark_config, + "conf": { + **spark_config.get("conf", {}), + "spark.livy.session.idle.timeout": credentials.session_idle_timeout, + }, + } + logger.debug(f"Session idle timeout: {credentials.session_idle_timeout}") + + LivySessionManager.livy_global_session.create_session(spark_config) + LivySessionManager.livy_global_session.is_new_session_required = False + + if credentials.create_shortcuts: + try: + shortcut_client = ShortcutClient( + _livy_helpers.accessToken.token, + credentials.workspaceid, + credentials.lakehouseid, + credentials.endpoint, + ) + shortcut_client.create_shortcuts(credentials.shortcuts_json_str) + except Exception as ex: + logger.error(f"Unable to create shortcuts: {ex}") + + @staticmethod + def _create_and_persist_session(spark_config, session_file_path: str) -> None: + LivySessionManager.livy_global_session.create_session(spark_config) + LivySessionManager.livy_global_session.is_new_session_required = False + _livy_helpers.write_session_id_to_file( + session_file_path, LivySessionManager.livy_global_session.session_id + ) + + @staticmethod + def _disconnect_impl() -> None: + """Disconnect from the session manager. + + - Local mode: keeps the Livy session alive for reuse. + - Fabric mode with reuse_session=True: keeps session alive for reuse. + - Fabric mode with reuse_session=False: deletes the session. + """ + with _session_lock: + if LivySessionManager.livy_global_session is None: + logger.debug("No session to disconnect") + return + + session = LivySessionManager.livy_global_session + session_id = session.session_id + + if session.is_local_mode or session.credential.reuse_session: + logger.debug( + f"Disconnecting from session manager (session {session_id} kept alive for reuse)" + ) + else: + logger.debug(f"Deleting Fabric Livy session: {session_id}") + session.delete_session() + + LivySessionManager.livy_global_session = None + + # Aliases preserved for explicit class-level invocation patterns. + connect_static = staticmethod(_connect_impl) # type: ignore[assignment] + disconnect_static = staticmethod(_disconnect_impl) # type: ignore[assignment] + + +class LivySessionConnectionWrapper(object): + """Connection wrapper for the livy session connection method.""" + + def __init__(self, handle): + self.handle = handle + self._cursor = None + + def cursor(self) -> LivySessionConnectionWrapper: + self._cursor = self.handle.cursor() + return self + + def cancel(self): + logger.debug("NotImplemented: cancel") + + def close(self): + self.handle.close() + + def rollback(self, *args, **kwargs): + logger.debug("NotImplemented: rollback") + + def fetchall(self): + return self._cursor.fetchall() + + def fetchmany(self, size=None): + return self._cursor.fetchmany(size) + + def fetchone(self): + return self._cursor.fetchone() + + def execute(self, sql, bindings=None): + if sql.strip().endswith(";"): + sql = sql.strip()[:-1] + + if bindings is None: + self._cursor.execute(sql) + else: + bindings = [self._fix_binding(binding) for binding in bindings] + self._cursor.execute(sql, *bindings) + + @property + def description(self): + return self._cursor.description + + @classmethod + def _fix_binding(cls, value) -> float | str: + """Convert complex datatypes to primitives that can be loaded by + the Spark driver""" + if isinstance(value, NUMBERS): + return float(value) + elif isinstance(value, dt.datetime): + return f"'{value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}'" + elif value is None: + return "''" + else: + escaped = str(value).replace("'", "\\'") + return f"'{escaped}'" diff --git a/src/dbt/include/fabricspark/profile_template.yml b/src/dbt/include/fabricspark/profile_template.yml index 4cbcfde..8ba3107 100644 --- a/src/dbt/include/fabricspark/profile_template.yml +++ b/src/dbt/include/fabricspark/profile_template.yml @@ -33,6 +33,10 @@ prompts: type: 'int' session_id_file: hint: Optional path to file storing Livy session ID for session reuse. Defaults to ./livy-session-id.txt if not specified. + high_concurrency: + hint: When true (default), each dbt thread gets its own REPL inside one underlying Livy session via Fabric's high-concurrency Livy API — statements from different threads run in parallel. When false, falls back to the legacy single-session-per-process mode where statements execute FIFO inside the default Spark scheduling pool. Has no effect in local mode. + type: 'bool' + default: true schema: hint: 'default schema that dbt will build objects in' threads: diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 74fbc66..b0dea47 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -340,6 +340,17 @@ def dbt_profile_target(request, workspace_id, api_endpoint, schema_mode): "livy_mode": os.getenv("LIVY_MODE", "fabric"), "reuse_session": True, "session_idle_timeout": "60m", + # High-concurrency Livy is disabled for the functional suite even though + # the user-facing default is True. The pytest-xdist test framework spawns + # multiple worker processes that pre-attach to specific Livy session IDs + # written to file by the orchestrator. The HC API has no "attach to this + # underlying session" parameter (only a packing-hint sessionTag), and + # the docs explicitly note that rapid concurrent POSTs to acquire HC + # sessions with the same tag can create multiple underlying Livy + # sessions instead of packing — overwhelming Fabric capacity in CI. + # Re-enabling here would require a test-infra refactor outside the + # scope of the HC support work itself. + "high_concurrency": False, "spark_config": { "name": f"dbt-test-{lakehouse_name}", "tags": { diff --git a/tests/functional/fixtures/ws2_seed/profiles.yml b/tests/functional/fixtures/ws2_seed/profiles.yml index e85768a..b48d1a6 100644 --- a/tests/functional/fixtures/ws2_seed/profiles.yml +++ b/tests/functional/fixtures/ws2_seed/profiles.yml @@ -12,6 +12,7 @@ ws2_seed_fixture: schema: dbo threads: 1 reuse_session: false + high_concurrency: false retry_all: true connect_retries: 3 connect_timeout: 15 diff --git a/tests/unit/test_concurrent_livy.py b/tests/unit/test_concurrent_livy.py new file mode 100644 index 0000000..5562792 --- /dev/null +++ b/tests/unit/test_concurrent_livy.py @@ -0,0 +1,411 @@ +"""Tests for the high-concurrency Livy backend. + +Mocked-HTTP coverage of the HC lifecycle: +- ``derive_session_tag`` returns the same value across managers when reuse_session + is true, and is uuid-stable per process when reuse_session is false. +- ``HighConcurrencySession.acquire`` follows the documented state machine: + POST returns NotStarted, GET polls through AcquiringHighConcurrencySession, + GET returns Idle with sessionId+replId. +- ``HighConcurrencyCursor.execute`` POSTs to ``/repls/{replId}/statements``, + polls until ``state == available``, and parses Fabric's standard + ``output.data.application/json.{schema,data}`` envelope. +- ``HighConcurrencySessionManager.disconnect`` DELETEs the HC id. +- The HC session manager is registered as a :class:`LivyBackend`. +- 404 on submit flags the REPL for re-acquire. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.fabricspark import concurrent_livy +from dbt.adapters.fabricspark.concurrent_livy import ( + HighConcurrencyConnection, + HighConcurrencyConnectionWrapper, + HighConcurrencyCursor, + HighConcurrencySession, + HighConcurrencySessionManager, + derive_session_tag, +) +from dbt.adapters.fabricspark.credentials import FabricSparkCredentials +from dbt.adapters.fabricspark.livy_backend import LivyBackend + + +def _make_creds(reuse_session: bool = False, **overrides) -> FabricSparkCredentials: + base = dict( + method="livy", + livy_mode="fabric", + authentication="CLI", + workspaceid="1de8390c-9aca-4790-bee8-72049109c0f4", + lakehouseid="8c5bc260-bc3a-4898-9ada-01e433d461ba", + lakehouse="tests", + endpoint="https://api.fabric.microsoft.com/v1", + spark_config={"name": "test-session", "numExecutors": 4}, + reuse_session=reuse_session, + session_start_timeout=10, + statement_timeout=30, + poll_wait=0, + poll_statement_wait=0, + ) + base.update(overrides) + return FabricSparkCredentials(**base) + + +def _mock_response(status_code: int, json_body=None, text: str = "") -> MagicMock: + resp = MagicMock() + resp.status_code = status_code + resp.text = text + if json_body is not None: + resp.json.return_value = json_body + return resp + + +@pytest.fixture(autouse=True) +def _reset_module_state(): + """Reset module-level caches between tests so they don't bleed across cases.""" + concurrent_livy._session_tags.clear() + concurrent_livy._active_sessions.clear() + concurrent_livy._shortcuts_done.clear() + yield + concurrent_livy._session_tags.clear() + concurrent_livy._active_sessions.clear() + concurrent_livy._shortcuts_done.clear() + + +# --------------------------------------------------------------------------- # +# derive_session_tag # +# --------------------------------------------------------------------------- # + + +class TestDeriveSessionTag: + def test_reuse_session_true_returns_deterministic_hash(self): + creds = _make_creds(reuse_session=True) + tag1 = derive_session_tag(creds) + tag2 = derive_session_tag(creds) + assert tag1 == tag2 + # Hash content includes the workspace+lakehouse pair. + assert tag1.startswith("dbt-fabricspark-") + + def test_reuse_session_true_same_pair_yields_same_tag_across_creds(self): + a = _make_creds(reuse_session=True) + b = _make_creds(reuse_session=True) + # Two credential objects targeting the same lakehouse must hit the + # same Spark cluster, so the tag must collide. + assert derive_session_tag(a) == derive_session_tag(b) + + def test_reuse_session_true_different_lakehouse_yields_different_tag(self): + a = _make_creds( + reuse_session=True, + lakehouseid="11111111-1111-1111-1111-111111111111", + ) + # Reset so the second creds gets a fresh tag computation. + concurrent_livy._session_tags.clear() + b = _make_creds( + reuse_session=True, + lakehouseid="22222222-2222-2222-2222-222222222222", + ) + # Different lakehouses → distinct underlying Spark clusters → distinct tags. + assert derive_session_tag(a) != derive_session_tag(b) + + def test_reuse_session_false_caches_uuid_per_process(self): + creds = _make_creds(reuse_session=False) + tag1 = derive_session_tag(creds) + tag2 = derive_session_tag(creds) + # Same process, same creds → cached uuid, so every per-thread manager + # acquires onto the same underlying Livy session for this run. + assert tag1 == tag2 + assert tag1.startswith("dbt-fabricspark-") + + +# --------------------------------------------------------------------------- # +# Acquire # +# --------------------------------------------------------------------------- # + + +class TestHighConcurrencySessionAcquire: + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_happy_path(self, mock_post, mock_get, _sleep, _headers): + mock_post.return_value = _mock_response(202, {"id": "hc-1", "state": "NotStarted"}) + mock_get.side_effect = [ + _mock_response(200, {"state": "AcquiringHighConcurrencySession"}), + _mock_response( + 200, + { + "state": "Idle", + "sessionId": "livy-42", + "replId": "repl-7", + }, + ), + ] + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.acquire() + + assert hc.hc_id == "hc-1" + assert hc.session_id == "livy-42" + assert hc.repl_id == "repl-7" + assert hc.is_new_session_required is False + # POST sent sessionTag and conf + post_body = mock_post.call_args.kwargs.get("data") or mock_post.call_args[1].get("data") + assert "sessionTag" in post_body + # Session is now in the active registry so atexit will reap it. + assert hc in concurrent_livy._active_sessions + + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_terminal_dead_state_raises(self, mock_post, mock_get, _sleep, _headers): + mock_post.return_value = _mock_response(202, {"id": "hc-2", "state": "NotStarted"}) + mock_get.return_value = _mock_response( + 200, + { + "state": "Dead", + "fabricSessionStateInfo": {"errorMessage": "out of capacity"}, + }, + ) + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + with pytest.raises(Exception) as exc: + hc.acquire() + assert "Dead" in str(exc.value) or "out of capacity" in str(exc.value) + + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_404_on_post_retries_then_succeeds(self, mock_post, _sleep, _headers): + mock_post.side_effect = [ + _mock_response(404, text="livy not yet up"), + _mock_response(202, {"id": "hc-3", "state": "NotStarted"}), + ] + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + with patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") as mock_get: + mock_get.return_value = _mock_response( + 200, {"state": "Idle", "sessionId": "s", "replId": "r"} + ) + hc.acquire() + assert hc.hc_id == "hc-3" + + +# --------------------------------------------------------------------------- # +# Cursor execute # +# --------------------------------------------------------------------------- # + + +class TestHighConcurrencyCursorExecute: + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_select_returns_rows_and_schema(self, mock_post, mock_get, _sleep, _headers): + mock_post.return_value = _mock_response(200, {"id": 1, "state": "waiting"}) + mock_get.return_value = _mock_response( + 200, + { + "id": 1, + "state": "available", + "output": { + "status": "ok", + "data": { + "application/json": { + "schema": { + "fields": [{"name": "version", "type": "string", "nullable": True}] + }, + "data": [["3.5.5"]], + } + }, + }, + }, + ) + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.hc_id = "hc-x" + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + + cursor = HighConcurrencyCursor(creds, hc) + cursor.execute("SELECT version()") + + assert cursor.fetchall() == [["3.5.5"]] + assert cursor.fetchone() == ["3.5.5"] + assert cursor.fetchone() is None + assert cursor.description[0][0] == "version" + + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_ddl_returns_empty_result(self, mock_post, mock_get, _sleep, _headers): + mock_post.return_value = _mock_response(200, {"id": 1, "state": "waiting"}) + # Fabric returns an envelope without `data` for DDL statements. + mock_get.return_value = _mock_response( + 200, + {"id": 1, "state": "available", "output": {"status": "ok", "data": {}}}, + ) + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + + cursor = HighConcurrencyCursor(creds, hc) + cursor.execute("CREATE TABLE foo (a int)") + assert cursor.fetchall() == [] + + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_404_on_submit_marks_repl_dead(self, mock_post, _sleep, _headers): + mock_post.return_value = _mock_response(404, text="repl gone") + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + + cursor = HighConcurrencyCursor(creds, hc) + with pytest.raises(Exception): + cursor.execute("SELECT 1") + assert hc.is_dead is True + assert hc.is_new_session_required is True + + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.time.sleep") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.get") + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.post") + def test_statement_error_raises(self, mock_post, mock_get, _sleep, _headers): + mock_post.return_value = _mock_response(200, {"id": 1, "state": "waiting"}) + mock_get.return_value = _mock_response( + 200, + { + "id": 1, + "state": "error", + "output": {"status": "error", "evalue": "table not found"}, + }, + ) + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + + cursor = HighConcurrencyCursor(creds, hc) + with pytest.raises(Exception) as exc: + cursor.execute("SELECT * FROM nope") + assert "table not found" in str(exc.value) + + +# --------------------------------------------------------------------------- # +# Delete / disconnect # +# --------------------------------------------------------------------------- # + + +class TestHighConcurrencyDelete: + @patch("dbt.adapters.fabricspark.concurrent_livy._get_headers", return_value={}) + @patch("dbt.adapters.fabricspark.concurrent_livy.requests.delete") + def test_delete_calls_api_and_clears_state(self, mock_delete, _headers): + mock_delete.return_value = _mock_response(200) + + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.hc_id = "hc-del" + concurrent_livy._active_sessions.add(hc) + + hc.delete() + + mock_delete.assert_called_once() + assert hc.hc_id is None + assert hc.session_id is None + assert hc.repl_id is None + assert hc not in concurrent_livy._active_sessions + + +# --------------------------------------------------------------------------- # +# Manager lifecycle # +# --------------------------------------------------------------------------- # + + +class TestHighConcurrencySessionManager: + def test_satisfies_livy_backend_abc(self): + mgr = HighConcurrencySessionManager() + assert isinstance(mgr, LivyBackend) + # Both methods are required by the ABC and must be callable. + assert callable(mgr.connect) + assert callable(mgr.disconnect) + + @patch("dbt.adapters.fabricspark.concurrent_livy._maybe_create_shortcuts") + def test_connect_acquires_once(self, _shortcuts): + def _fake_acquire(self): + # Mimic real acquire — set the flag so the manager's healthy-fast-path triggers. + self.is_new_session_required = False + self.session_id = "s" + self.repl_id = "r" + + with patch.object(HighConcurrencySession, "acquire", _fake_acquire): + creds = _make_creds() + mgr = HighConcurrencySessionManager() + conn1 = mgr.connect(creds) + conn2 = mgr.connect(creds) + assert conn1 is conn2 + assert isinstance(conn1, HighConcurrencyConnection) + + @patch("dbt.adapters.fabricspark.concurrent_livy._maybe_create_shortcuts") + @patch.object(HighConcurrencySession, "delete") + @patch.object(HighConcurrencySession, "acquire") + def test_disconnect_releases_hc(self, _acquire, mock_delete, _shortcuts): + creds = _make_creds() + mgr = HighConcurrencySessionManager() + mgr.connect(creds) + mgr.disconnect() + mock_delete.assert_called_once() + assert mgr._hc_session is None + + +# --------------------------------------------------------------------------- # +# Connection wrapper # +# --------------------------------------------------------------------------- # + + +class TestHighConcurrencyConnectionWrapper: + def test_wrapper_delegates_to_cursor(self): + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + conn = HighConcurrencyConnection(creds, hc) + wrapper = HighConcurrencyConnectionWrapper(conn) + + cursor = wrapper.cursor() + assert cursor is wrapper + # The cursor returned by the wrapper must expose execute/fetch* surface. + assert hasattr(wrapper, "execute") + assert hasattr(wrapper, "fetchall") + assert hasattr(wrapper, "fetchmany") + assert hasattr(wrapper, "fetchone") + + def test_execute_strips_trailing_semicolon(self): + creds = _make_creds() + hc = HighConcurrencySession(creds, creds.spark_config) + hc.session_id = "s" + hc.repl_id = "r" + hc.is_new_session_required = False + conn = HighConcurrencyConnection(creds, hc) + wrapper = HighConcurrencyConnectionWrapper(conn) + wrapper.cursor() + + with patch.object(HighConcurrencyCursor, "execute") as mock_exec: + wrapper.execute("SELECT 1;") + mock_exec.assert_called_once_with("SELECT 1")