Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions src/dbt/adapters/fabricspark/livysession.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import threading
import time
from base64 import urlsafe_b64decode
from typing import Any, Optional

import requests
Expand Down Expand Up @@ -182,13 +183,44 @@ def get_default_access_token(credentials: FabricSparkCredentials) -> AccessToken
out : AccessToken
The access token.
"""
expires_on = 1845972874
derived_expiry = _extract_expiry_from_jwt(credentials.accessToken)
expires_on = derived_expiry or int(time.time())

if derived_expiry is None:
logger.debug(
"Could not derive token expiry from credentials.accessToken; "
"forcing immediate refresh for int_tests auth."
)

# Create an AccessToken instance
accessToken = AccessToken(token=credentials.accessToken, expires_on=expires_on)
return accessToken


def _extract_expiry_from_jwt(token: Optional[str]) -> Optional[int]:
"""Best-effort extraction of JWT `exp` claim."""
if not token:
return None

try:
parts = token.split(".")
if len(parts) < 2:
return None

payload = parts[1]
payload_with_padding = payload + "=" * (-len(payload) % 4)
claims = json.loads(urlsafe_b64decode(payload_with_padding))
exp = claims.get("exp")
if isinstance(exp, int) and exp > 0:
return exp
if isinstance(exp, float) and exp > 0:
return int(exp)
except (json.JSONDecodeError, TypeError, UnicodeDecodeError, ValueError):
logger.debug("Unable to parse JWT expiry from credentials.accessToken.")
return None

return None


def _load_custom_credential(credentials: FabricSparkCredentials) -> Any:
"""
Import and instantiate the user-supplied TokenCredential.
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_livysession.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests for livysession module, focusing on local vs Fabric mode routing."""

import base64
import datetime as dt
import json
import os
import tempfile
from decimal import Decimal
Expand All @@ -16,6 +18,7 @@
LivyCursor,
LivySession,
LivySessionManager,
get_default_access_token,
get_headers,
read_session_id_from_file,
write_session_id_to_file,
Expand Down Expand Up @@ -107,6 +110,44 @@ def test_init_fabric_mode(self):
assert "lakehouses/8c5bc260-bc3a-4898-9ada-01e433d461ba" in session.connect_url


class TestIntTestsAuthTokenExpiry:
@staticmethod
def _make_jwt_with_exp(expiry: int) -> str:
header = base64.urlsafe_b64encode(
json.dumps({"alg": "none", "typ": "JWT"}).encode()
).decode()
payload = base64.urlsafe_b64encode(json.dumps({"exp": expiry}).encode()).decode()
return f"{header.rstrip('=')}.{payload.rstrip('=')}.signature"

@staticmethod
def _make_int_tests_credentials(token: str) -> FabricSparkCredentials:
return FabricSparkCredentials(
method="livy",
livy_mode="fabric",
authentication="int_tests",
accessToken=token,
workspaceid="1de8390c-9aca-4790-bee8-72049109c0f4",
lakehouseid="8c5bc260-bc3a-4898-9ada-01e433d461ba",
lakehouse="tests",
spark_config={"name": "test-session"},
)

def test_default_access_token_uses_jwt_exp_claim(self):
credentials = self._make_int_tests_credentials(self._make_jwt_with_exp(424242))

token = get_default_access_token(credentials)

assert token.expires_on == 424242

def test_default_access_token_without_jwt_exp_forces_immediate_refresh(self):
credentials = self._make_int_tests_credentials("plain-test-token")

with patch("dbt.adapters.fabricspark.livysession.time.time", return_value=1712345678):
token = get_default_access_token(credentials)

assert token.expires_on == 1712345678


class TestCreateSessionRetry:
"""Tests for retry logic in LivySession.create_session()."""

Expand Down