diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4803d70..22fcfdb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,10 +96,10 @@ jobs: run: uv sync --extra dev - name: Run Black - run: uv run black --check mlflow_descope_auth tests + run: uv run black --check mlflow_descope_auth tests examples - name: Run Ruff - run: uv run ruff check mlflow_descope_auth tests + run: uv run ruff check mlflow_descope_auth tests examples - name: Run mypy run: uv run mypy mlflow_descope_auth diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 6347fdf..f4a534d 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,182 +2,155 @@ ## How MLflow Descope Auth Works -This plugin uses **MLflow's standard plugin system** to add Descope authentication to MLflow clients. It does NOT wrap or modify the MLflow server - it extends the client-side behavior. +This plugin uses **MLflow's `mlflow.app` entry point** to add Descope authentication to the MLflow server. It provides server-side authentication with a browser-based login UI. ### Plugin Architecture ```txt ┌─────────────────────────────────────────┐ -│ MLflow Client │ -│ (mlflow.set_tracking_uri(...)) │ -├─────────────────────────────────────────┤ -│ DescopeAuthProvider │ ← Adds Bearer token to requests -│ (mlflow.request_auth_provider) │ -├─────────────────────────────────────────┤ -│ DescopeHeaderProvider │ ← Injects user context headers -│ (mlflow.request_header_provider) │ -├─────────────────────────────────────────┤ -│ DescopeContextProvider │ ← Auto-tags runs with user info -│ (mlflow.run_context_provider) │ -├─────────────────────────────────────────┤ -│ HTTP Request │ -│ Authorization: Bearer │ -│ X-Descope-User-ID: ... │ -│ X-Descope-Email: ... │ +│ Browser │ +│ (User visits MLflow UI) │ └─────────────────────────────────────────┘ ↓ ┌─────────────────────────────────────────┐ -│ MLflow Server │ ← Unmodified, standard MLflow -│ (receives authenticated requests) │ +│ MLflow Server │ +│ (with Descope plugin loaded) │ +├─────────────────────────────────────────┤ +│ before_request hook │ ← Validates session cookie +│ - Check DS/DSR cookies │ +│ - Redirect to /auth/login if invalid │ +│ - Set user context in Flask g │ +├─────────────────────────────────────────┤ +│ after_request hook │ ← Updates cookie if refreshed +│ - Check if token was refreshed │ +│ - Update DS cookie with new token │ +├─────────────────────────────────────────┤ +│ Auth Routes │ +│ - /auth/login (Descope Web Component)│ +│ - /auth/logout (Clear cookies) │ +│ - /auth/user (Current user info) │ +│ - /health (Health check) │ └─────────────────────────────────────────┘ ``` -### Entry Points +### Entry Point -The plugin registers three MLflow entry points in `pyproject.toml`: +The plugin registers one MLflow entry point in `pyproject.toml`: ```toml -[project.entry-points."mlflow.request_auth_provider"] -descope = "mlflow_descope_auth.auth_provider:DescopeAuthProvider" - -[project.entry-points."mlflow.request_header_provider"] -descope = "mlflow_descope_auth.header_provider:DescopeHeaderProvider" - -[project.entry-points."mlflow.run_context_provider"] -descope = "mlflow_descope_auth.context_provider:DescopeContextProvider" +[project.entry-points."mlflow.app"] +descope = "mlflow_descope_auth.server:create_app" ``` ### Authentication Flow -1. **Plugin Activation** +1. **User visits MLflow UI** - ```bash - export MLFLOW_TRACKING_AUTH=descope + ```txt + Browser → GET / + → before_request hook triggered + → No valid session cookie? + → Redirect to /auth/login ``` - MLflow discovers and loads the `descope` auth provider. - -2. **Token Injection** +2. **Login with Descope** ```txt - Client calls mlflow.log_metric(...) - → DescopeAuthProvider.get_auth() called - → Returns callable that adds Authorization header - → Request sent with: Authorization: Bearer + Browser → /auth/login + → Descope Web Component loads + → User authenticates via Descope flow + → On success: JavaScript sets cookies (DS, DSR) + → Redirect to MLflow UI ``` -3. **Header Injection** +3. **Authenticated Access** ```txt - DescopeHeaderProvider.request_headers() called - → Decodes JWT token (without validation) - → Extracts user info (sub, email, roles, etc.) - → Returns headers: X-Descope-User-ID, X-Descope-Email, etc. + Browser → GET / (with DS cookie) + → before_request hook validates session + → Sets user info in Flask g + → MLflow UI loads normally ``` -4. **Run Tagging** +4. **Token Refresh** ```txt - mlflow.start_run() called - → DescopeContextProvider.in_context() → True (if token present) - → DescopeContextProvider.tags() called - → Returns tags: descope.user_id, descope.email, etc. - → Tags automatically added to run + Request comes in with expired session token + → validate_and_refresh_session() refreshes token + → after_request hook updates DS cookie + → User continues seamlessly ``` ### Key Components -#### 1. `auth_provider.py` - Request Authentication - -```python -class DescopeAuthProvider(RequestAuthProvider): - def get_name(self) -> str: - return "descope" - - def get_auth(self) -> Callable: - # Returns function that adds Bearer token to requests - token = os.environ.get("DESCOPE_SESSION_TOKEN") - return lambda: ("Bearer", token) -``` - -#### 2. `header_provider.py` - Request Headers +#### 1. `server.py` - Flask App Factory ```python -class DescopeHeaderProvider(RequestHeaderProvider): - def in_context(self) -> bool: - return bool(os.environ.get("DESCOPE_SESSION_TOKEN")) +def create_app(app: Flask = None) -> Flask: + """MLflow app factory entry point.""" + if app is None: + from mlflow.server import app as mlflow_app + app = mlflow_app - def request_headers(self) -> Dict[str, str]: - # Decode JWT and return user info as headers - return { - "X-Descope-User-ID": user_id, - "X-Descope-Email": email, - # ... - } + register_auth_routes(app) + app.before_request(_before_request) + app.after_request(_after_request) + return app ``` -#### 3. `context_provider.py` - Run Tagging +#### 2. `auth_routes.py` - Authentication Endpoints -```python -class DescopeContextProvider(RunContextProvider): - def in_context(self) -> bool: - return bool(os.environ.get("DESCOPE_SESSION_TOKEN")) - - def tags(self) -> Dict[str, str]: - # Return tags to add to every run - return { - "descope.user_id": user_id, - "descope.email": email, - # ... - } -``` +- `/auth/login` - Login page with Descope Web Component +- `/auth/logout` - Clears cookies, redirects to login +- `/auth/user` - Returns current user info as JSON +- `/health` - Health check endpoint -#### 4. `client.py` - Descope SDK Wrapper +#### 3. `client.py` - Descope SDK Wrapper - Session validation with Descope API - Token refresh handling -- User info extraction from validated tokens +- User claims extraction from validated tokens -#### 5. `config.py` - Configuration +#### 4. `config.py` - Configuration - Environment variable management +- Cookie settings - Default values and validation ### Why This Architecture? -1. **Non-Invasive** - - Server runs unmodified - - All logic is client-side - - No middleware, no wrapping +1. **Server-Side Security** + - Tokens stored in HttpOnly cookies (not accessible to JavaScript) + - Server validates every request + - Automatic token refresh 2. **Standard MLflow Integration** - - Uses official plugin entry points - - Works with any MLflow server - - Future-proof against MLflow updates + - Uses official `mlflow.app` entry point + - Works with stock MLflow server + - No custom server deployment needed 3. **Minimal Dependencies** - Only `descope` SDK required - - No FastAPI, no ASGI/WSGI complexity + - Uses Flask (already part of MLflow) + - No additional frameworks 4. **Simple Configuration** - Just environment variables - No config files needed - Works in any environment (local, Docker, K8s) -### Comparison with Other Approaches +### Cookie Details + +| Cookie | Purpose | HttpOnly | Secure | +|--------|---------|----------|--------| +| `DS` | Session token | Yes | Configurable | +| `DSR` | Refresh token | Yes | Configurable | -| Aspect | Simple Plugin (This) | Server Wrapper Approach | -| ------------------- | ------------------------- | --------------------------- | -| **Where it runs** | Client-side | Server-side | -| **Server changes** | None | Wraps entire server | -| **Complexity** | ~300 LOC | ~2000+ LOC | -| **Dependencies** | descope SDK only | FastAPI, ASGI, middleware | -| **MLflow version** | Any | May break on updates | -| **Deployment** | pip install | Custom server setup | +Set `DESCOPE_COOKIE_SECURE=true` in production (HTTPS). ### Security Considerations -- Tokens are passed via environment variables (not committed to code) -- JWT decoding in header/context providers is for extracting claims only -- Actual token validation should happen server-side -- For server-side validation, use MLflow's built-in auth or a reverse proxy +- Cookies are HttpOnly (not accessible to JavaScript XSS attacks) +- Session tokens are validated on every request +- Refresh tokens enable seamless token renewal +- No tokens stored in browser localStorage/sessionStorage diff --git a/README.md b/README.md index 102e081..7f3bb1d 100644 --- a/README.md +++ b/README.md @@ -7,19 +7,19 @@ A simple, standards-compliant authentication plugin for [MLflow](https://mlflow. ## Features -✨ **Simple Plugin Architecture** - Uses MLflow's standard plugin system -🔐 **Descope Authentication** - Secure token-based authentication -🏷️ **Automatic Tagging** - User context automatically added to MLflow runs -📊 **Request Headers** - User info propagated via HTTP headers -⚡ **Zero Configuration** - Works with environment variables +🔐 **Descope Authentication** - Secure token-based authentication with auto-refresh +🌐 **Browser Login UI** - Built-in login page with Descope Web Component +🍪 **Cookie-Based Sessions** - Secure, HttpOnly cookies with automatic refresh ## How It Works -This plugin integrates with MLflow via three standard plugin types: +This plugin integrates with MLflow via the `mlflow.app` entry point, providing server-side authentication with a browser-based login UI. -1. **Request Auth Provider** - Adds Descope authentication to MLflow API requests -2. **Request Header Provider** - Injects user context into request headers -3. **Run Context Provider** - Automatically tags runs with user information +``` +Browser → /auth/login → Descope Web Component → Cookie Set → MLflow UI + ↑ + before_request hook validates cookies +``` ## Installation @@ -35,200 +35,88 @@ cd mlflow-descope-auth pip install -e . ``` -## Configuration - -### 1. Set Up Descope - -1. Sign up at [descope.com](https://www.descope.com/) -2. Create a new project -3. Copy your Project ID (starts with `P2`) -4. Create authentication flow in Descope Console - -### 2. Get Authentication Tokens +## Quick Start -Authenticate with Descope and obtain session tokens. You can use: +1. **Set environment variables**: + ```bash + export DESCOPE_PROJECT_ID="P2XXXXX" + ``` -- [Descope Web SDK](https://docs.descope.com/build/guides/client_sdks/web/) -- [Descope Python SDK](https://docs.descope.com/build/guides/client_sdks/python/) -- Descope API directly +2. **Start MLflow with Descope authentication**: + ```bash + mlflow server --app-name descope --host 0.0.0.0 --port 5000 + ``` -Example using Python SDK: - -```python -from descope import DescopeClient - -descope_client = DescopeClient(project_id="P2XXXXX") - -# Authenticate user (example with magic link) -response = descope_client.magiclink.sign_in_or_up( - method="email", - login_id="user@example.com" -) - -# Extract session token -session_token = response["sessionToken"]["jwt"] -print(f"export DESCOPE_SESSION_TOKEN='{session_token}'") -``` +3. **Access MLflow UI**: Open `http://localhost:5000` in your browser. You'll be redirected to the login page. -### 3. Set Environment Variables +4. **Logout**: Visit `/auth/logout` to clear your session. -```bash -# Required -export DESCOPE_PROJECT_ID="P2XXXXX" -export DESCOPE_SESSION_TOKEN="" - -# Optional (with defaults) -export DESCOPE_ADMIN_ROLES="admin,mlflow-admin" -export DESCOPE_DEFAULT_PERMISSION="READ" -export DESCOPE_USERNAME_CLAIM="sub" # or "email" -``` - -### 4. Enable the Plugin - -```bash -# Set MLflow tracking authentication -export MLFLOW_TRACKING_AUTH=descope - -# Start MLflow server -mlflow server --host 0.0.0.0 --port 5000 -``` - -## Usage +## Configuration Reference -Once configured, the plugin works automatically: +| Variable | Required | Default | Description | +| ---------------------------- | -------- | -------------------- | ----------------------------------------------- | +| `DESCOPE_PROJECT_ID` | ✅ Yes | - | Your Descope Project ID | +| `DESCOPE_FLOW_ID` | ❌ No | `sign-up-or-in` | Descope flow ID for login | +| `DESCOPE_ADMIN_ROLES` | ❌ No | `admin,mlflow-admin` | Comma-separated list of admin roles | +| `DESCOPE_DEFAULT_PERMISSION` | ❌ No | `READ` | Default permission level (READ/EDIT/MANAGE) | +| `DESCOPE_USERNAME_CLAIM` | ❌ No | `sub` | JWT claim to use as username (`sub` or `email`) | +| `DESCOPE_COOKIE_SECURE` | ❌ No | `false` | Enable secure cookies (set `true` for HTTPS) | -```python -import mlflow +## Plugin Entry Point -# Plugin automatically adds authentication to requests -mlflow.set_tracking_uri("http://localhost:5000") +This plugin registers one MLflow entry point: -# Start a run - user context automatically added -with mlflow.start_run(): - mlflow.log_param("alpha", 0.5) - mlflow.log_metric("rmse", 0.8) +```toml +[project.entry-points."mlflow.app"] +descope = "mlflow_descope_auth.server:create_app" ``` -### Automatic Run Tagging - -The plugin automatically adds these tags to all runs: +## Architecture -- `descope.user_id` - User's Descope ID -- `descope.username` - Username -- `descope.email` - User's email -- `descope.name` - User's display name -- `descope.roles` - Comma-separated list of roles -- `descope.permissions` - Comma-separated list of permissions -- `descope.tenants` - Comma-separated list of tenants +### Components -### Request Headers +- **`server.py`** - Flask app factory for `mlflow.app` entry point +- **`auth_routes.py`** - Login, logout, and user info endpoints +- **`client.py`** - Descope SDK wrapper for token validation +- **`config.py`** - Configuration management +- **`store.py`** - User store adapter (optional, for advanced use cases) -The plugin adds these headers to MLflow API requests: +### Authentication Flow -- `X-Descope-User-ID` -- `X-Descope-Username` -- `X-Descope-Email` -- `X-Descope-Roles` -- `X-Descope-Permissions` -- `X-Descope-Tenants` +1. User visits MLflow UI +2. `before_request` hook checks for valid session cookie (`DS`) +3. If no valid session → redirect to `/auth/login` +4. User authenticates via Descope Web Component +5. On success, session cookies (`DS`, `DSR`) are set +6. User redirected back to MLflow UI +7. `after_request` hook refreshes cookies if token was refreshed ## Development -### Setup with mise +### Setup ```bash -# Install mise -curl https://mise.run | sh - # Clone and setup git clone https://github.com/descope/mlflow-descope-auth.git cd mlflow-descope-auth -# Install dependencies -mise run install +# Install with uv +uv sync -# Verify plugin is registered -mise run verify-plugin -``` - -### Available Tasks +# Run tests +uv run pytest tests/ -v -```bash -# Development -mise run dev # Start MLflow server with plugin enabled -mise run demo # Run demo tracking session (requires tokens) - -# Testing -mise run test # Run tests with coverage -mise run test-quick # Run tests without coverage -mise run verify-plugin # Verify plugin entry points registered - -# Code Quality -mise run lint # Check code style -mise run format # Format code -mise run check # Run all checks (lint + format + type) -mise run fix # Auto-fix all issues - -# Maintenance -mise run clean # Remove generated files -mise run pre-commit-install # Install git hooks -mise run pre-commit # Run hooks on all files -mise run ci # Run full CI pipeline +# Lint and format +uv run ruff check mlflow_descope_auth tests --fix +uv run ruff format mlflow_descope_auth tests ``` -## Plugin Entry Points - -This plugin registers three MLflow entry points: +### Verify Plugin Registration -```toml -[project.entry-points."mlflow.request_auth_provider"] -descope = "mlflow_descope_auth.auth_provider:DescopeAuthProvider" - -[project.entry-points."mlflow.request_header_provider"] -descope = "mlflow_descope_auth.header_provider:DescopeHeaderProvider" - -[project.entry-points."mlflow.run_context_provider"] -descope = "mlflow_descope_auth.context_provider:DescopeContextProvider" -``` - -## Configuration Reference - -| Variable | Required | Default | Description | -| ---------------------------- | -------- | -------------------- | ----------------------------------------------- | -| `DESCOPE_PROJECT_ID` | ✅ Yes | - | Your Descope Project ID | -| `DESCOPE_SESSION_TOKEN` | ✅ Yes | - | Current session JWT token | -| `DESCOPE_ADMIN_ROLES` | ❌ No | `admin,mlflow-admin` | Comma-separated list of admin roles | -| `DESCOPE_DEFAULT_PERMISSION` | ❌ No | `READ` | Default permission level (READ/EDIT/MANAGE) | -| `DESCOPE_USERNAME_CLAIM` | ❌ No | `sub` | JWT claim to use as username (`sub` or `email`) | -| `MLFLOW_TRACKING_AUTH` | ✅ Yes | - | Set to `descope` to enable plugin | - -## Architecture - -### Simple Plugin Design - -This plugin follows MLflow's standard plugin architecture: - -``` -MLflow Client - ↓ -[Auth Provider] ← Adds authentication to requests - ↓ -[Header Provider] ← Injects user context headers - ↓ -[Context Provider] ← Tags runs with user info - ↓ -MLflow Server +```bash +python -c "from importlib.metadata import entry_points; print([ep.name for ep in entry_points(group='mlflow.app')])" ``` -### Components - -- **`auth_provider.py`** - Implements `RequestAuthProvider` for authentication -- **`header_provider.py`** - Implements `RequestHeaderProvider` for headers -- **`context_provider.py`** - Implements `RunContextProvider` for run tagging -- **`client.py`** - Descope SDK wrapper for token validation -- **`config.py`** - Configuration management -- **`store.py`** - User store adapter (optional, for advanced use cases) - ## Troubleshooting ### Plugin Not Loaded @@ -238,51 +126,14 @@ MLflow Server pip list | grep mlflow-descope-auth # Check entry points -python -c "import pkg_resources; print([ep for ep in pkg_resources.iter_entry_points('mlflow.request_auth_provider')])" -``` - -### Authentication Fails - -```bash -# Check environment variables -env | grep DESCOPE - -# Verify tokens are valid -python -c " -from mlflow_descope_auth import get_descope_client -import os -client = get_descope_client() -result = client.validate_session(os.environ['DESCOPE_SESSION_TOKEN']) -print('✓ Token valid') -" +python -c "from importlib.metadata import entry_points; print([ep.name for ep in entry_points(group='mlflow.app')])" ``` -### Enable Debug Logging +### Cookie Issues -```bash -export MLFLOW_TRACKING_INSECURE_TLS=true # For development only -export PYTHONWARNINGS=default - -python -c " -import logging -logging.basicConfig(level=logging.DEBUG) -import mlflow -mlflow.set_tracking_uri('http://localhost:5000') -mlflow.search_runs() -" -``` - -## Comparison: Simple Plugin vs Full App - -| Feature | Simple Plugin (This) | Full App Approach | -| -------------------------- | -------------------- | ---------------------- | -| **Architecture** | Extends MLflow | Wraps MLflow | -| **Complexity** | Low | High | -| **Dependencies** | Minimal | FastAPI, Uvicorn, etc. | -| **MLflow Compatibility** | Standard | Custom | -| **Maintenance** | Easy | Complex | -| **Integration** | Plugin system | App replacement | -| **Setup** | Environment vars | Config files + UI | +- Ensure `DESCOPE_COOKIE_SECURE=true` when using HTTPS +- Check browser dev tools for cookie presence (`DS`, `DSR`) +- Verify cookies are not being blocked by browser settings ## Contributing @@ -294,17 +145,10 @@ Apache License 2.0 - see [LICENSE](LICENSE) file for details. ## Support -- **Documentation**: [GitHub Wiki](https://github.com/descope/mlflow-descope-auth/wiki) - **Issues**: [GitHub Issues](https://github.com/descope/mlflow-descope-auth/issues) - **Descope Docs**: [docs.descope.com](https://docs.descope.com/) - **MLflow Docs**: [mlflow.org/docs](https://mlflow.org/docs/latest/index.html) -## Related Projects - -- [MLflow](https://mlflow.org/) - Open source platform for the ML lifecycle -- [Descope](https://www.descope.com/) - Authentication and user management -- [MLflow Plugin Guide](https://mlflow.org/docs/latest/plugins.html) - Official plugin documentation - --- Made with ❤️ by the MLflow Descope Auth Contributors diff --git a/archive/app.py b/archive/app.py deleted file mode 100644 index 33d5175..0000000 --- a/archive/app.py +++ /dev/null @@ -1,94 +0,0 @@ -"""FastAPI application factory for MLflow with Descope authentication.""" - -import logging - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -# Import MLflow Flask app -from mlflow.server import app as mlflow_flask_app - -from .auth_routes import router as auth_router -from .client import get_descope_client -from .config import get_config -from .middleware import AuthenticationMiddleware - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -def create_app() -> FastAPI: - """Create and configure the FastAPI application. - - This function: - 1. Initializes the Descope client - 2. Creates a FastAPI application - 3. Adds authentication middleware - 4. Registers authentication routes - 5. Configures CORS if needed - - Returns: - FastAPI: The configured FastAPI application. - """ - # Load configuration - config = get_config() - - logger.info("Initializing MLflow Descope Auth plugin") - logger.info(f"Descope Project ID: {config.DESCOPE_PROJECT_ID}") - logger.info(f"Flow ID: {config.DESCOPE_FLOW_ID}") - - # Initialize Descope client - descope_client = get_descope_client() - - # Create FastAPI app - app = FastAPI( - title="MLflow with Descope Authentication", - description="MLflow Tracking Server with Descope Flow-based authentication", - version="0.1.0", - docs_url="/docs" if config.DESCOPE_PROJECT_ID else None, # Disable docs in prod - ) - - # Add CORS middleware (configure as needed) - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Configure this based on your needs - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Add authentication middleware - app.add_middleware(AuthenticationMiddleware, descope_client=descope_client) - - # Include authentication routes - app.include_router(auth_router) - - # Mount MLflow Flask app at root with auth-aware middleware - from .wsgi_middleware import AuthAwareWSGIMiddleware - - app.mount("/", AuthAwareWSGIMiddleware(mlflow_flask_app)) - logger.info("Mounted MLflow Flask app at / with auth-aware WSGI middleware") - - logger.info("MLflow Descope Auth plugin initialized successfully") - - return app - - -# Create the app instance for MLflow plugin entry point -app = create_app() - - -# This function is called by MLflow when using: mlflow server --app-name descope-auth -def get_app(): - """Get the FastAPI application instance. - - This is the entry point for MLflow plugin system. - - Returns: - FastAPI: The configured application. - """ - return app diff --git a/archive/auth_routes.py b/archive/auth_routes.py deleted file mode 100644 index 0a3b2c8..0000000 --- a/archive/auth_routes.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Authentication routes for Descope Flow-based authentication.""" - -import logging -from pathlib import Path - -from fastapi import APIRouter, Request -from fastapi.responses import HTMLResponse, RedirectResponse -from fastapi.templating import Jinja2Templates - -from .config import get_config - -logger = logging.getLogger(__name__) - -# Initialize router -router = APIRouter(prefix="/auth", tags=["authentication"]) - -# Initialize Jinja2 templates -templates_dir = Path(__file__).parent / "templates" -templates = Jinja2Templates(directory=str(templates_dir)) - - -@router.get("/login", response_class=HTMLResponse) -async def login(request: Request): - """Render login page with Descope Flow web component. - - This endpoint serves an HTML page that embeds the Descope web component. - The component loads the authentication flow configured in Descope Console. - - Args: - request: The incoming HTTP request. - - Returns: - HTMLResponse: Login page with embedded Descope web component. - """ - config = get_config() - - logger.info(f"Rendering login page with flow: {config.DESCOPE_FLOW_ID}") - - return templates.TemplateResponse( - "login.html", - { - "request": request, - "project_id": config.DESCOPE_PROJECT_ID, - "flow_id": config.DESCOPE_FLOW_ID, - "base_url": config.DESCOPE_BASE_URL, - "redirect_url": config.DESCOPE_REDIRECT_URL, - "web_component_url": config.web_component_url, - "session_cookie_name": config.SESSION_COOKIE_NAME, - "refresh_cookie_name": config.REFRESH_COOKIE_NAME, - }, - ) - - -@router.get("/logout") -async def logout(request: Request): - """Logout endpoint - clears authentication cookies. - - This endpoint clears the session and refresh token cookies, effectively - logging the user out. It then redirects to the login page. - - Args: - request: The incoming HTTP request. - - Returns: - RedirectResponse: Redirects to login page after clearing cookies. - """ - config = get_config() - - logger.info("User logging out") - - # Create redirect response - response = RedirectResponse(url="/auth/login", status_code=302) - - # Clear authentication cookies - response.delete_cookie(config.SESSION_COOKIE_NAME, path="/") - response.delete_cookie(config.REFRESH_COOKIE_NAME, path="/") - - return response - - -@router.get("/user") -async def get_current_user(request: Request): - """Get current authenticated user information. - - This endpoint returns information about the currently authenticated user - based on the request state set by the authentication middleware. - - Args: - request: The incoming HTTP request. - - Returns: - dict: User information including username, roles, and permissions. - """ - # Check if user is authenticated (set by middleware) - if not hasattr(request.state, "username"): - return {"error": "Not authenticated"}, 401 - - return { - "username": request.state.username, - "email": getattr(request.state, "email", None), - "name": getattr(request.state, "name", None), - "roles": getattr(request.state, "roles", []), - "permissions": getattr(request.state, "permissions", []), - "is_admin": getattr(request.state, "is_admin", False), - } - - -@router.get("/health") -async def health_check(): - """Health check endpoint for monitoring. - - Returns: - dict: Health status of the authentication service. - """ - config = get_config() - - return { - "status": "healthy", - "service": "mlflow-descope-auth", - "project_id": config.DESCOPE_PROJECT_ID, - "flow_id": config.DESCOPE_FLOW_ID, - } - - -@router.get("/config") -async def get_auth_config(): - """Get public authentication configuration. - - Returns non-sensitive configuration information that can be used - by frontend clients. - - Returns: - dict: Public configuration including project ID and flow ID. - """ - config = get_config() - - return { - "project_id": config.DESCOPE_PROJECT_ID, - "flow_id": config.DESCOPE_FLOW_ID, - "web_component_version": config.DESCOPE_WEB_COMPONENT_VERSION, - "redirect_url": config.DESCOPE_REDIRECT_URL, - } diff --git a/archive/middleware.py b/archive/middleware.py deleted file mode 100644 index 8b487bd..0000000 --- a/archive/middleware.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Authentication middleware for session validation.""" - -import logging -from typing import Callable - -from descope import AuthException -from fastapi import Request, Response -from fastapi.responses import RedirectResponse -from starlette.middleware.base import BaseHTTPMiddleware - -from .client import get_descope_client -from .config import get_config - -logger = logging.getLogger(__name__) - - -class AuthenticationMiddleware(BaseHTTPMiddleware): - """Middleware to validate Descope session on each request. - - This middleware: - 1. Checks if the request path requires authentication - 2. Validates session tokens using Descope SDK - 3. Automatically refreshes expired tokens - 4. Attaches user information to request state - 5. Redirects to login if authentication fails - """ - - def __init__(self, app, descope_client=None): - """Initialize the authentication middleware. - - Args: - app: The FastAPI application. - descope_client: Optional Descope client (for testing). - """ - super().__init__(app) - self.descope_client = descope_client or get_descope_client() - self.config = get_config() - - # Public routes that don't require authentication - self.public_routes = { - "/auth/login", - "/auth/logout", - "/auth/health", - "/auth/config", - "/health", - "/docs", - "/openapi.json", - "/redoc", - } - - def _is_public_route(self, path: str) -> bool: - """Check if a route is public (doesn't require authentication). - - Args: - path: The request path. - - Returns: - bool: True if the route is public, False otherwise. - """ - # Check exact matches - if path in self.public_routes: - return True - - # Check if path starts with public prefixes - public_prefixes = ["/static/", "/_static/"] - return any(path.startswith(prefix) for prefix in public_prefixes) - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - """Process each request through authentication middleware. - - Args: - request: The incoming HTTP request. - call_next: The next middleware or route handler. - - Returns: - Response: The HTTP response. - """ - path = request.url.path - - # Skip authentication for public routes - if self._is_public_route(path): - return await call_next(request) - - # Get tokens from cookies - session_token = request.cookies.get(self.config.SESSION_COOKIE_NAME) - refresh_token = request.cookies.get(self.config.REFRESH_COOKIE_NAME) - - # If no session token, redirect to login - if not session_token: - logger.debug(f"No session token found for path: {path}") - return RedirectResponse(url="/auth/login", status_code=302) - - # Validate session with Descope - try: - jwt_response = self.descope_client.validate_session(session_token, refresh_token) - - # Extract user claims - claims = self.descope_client.extract_user_claims(jwt_response) - - # Attach user information to request state - request.state.username = claims["username"] - request.state.email = claims["email"] - request.state.name = claims["name"] - request.state.roles = claims["roles"] - request.state.permissions = claims["permissions"] - request.state.user_id = claims["user_id"] - request.state.tenants = claims["tenants"] - - # Check if user is admin - request.state.is_admin = self.config.is_admin_role(claims["roles"]) - - logger.debug( - f"Authenticated user: {claims['username']} " - f"(admin: {request.state.is_admin}) for path: {path}" - ) - - # Process the request - response = await call_next(request) - - # Check if token was refreshed and update cookie - if jwt_response.get("cookieData"): - session_jwt = jwt_response["sessionToken"]["jwt"] - response.set_cookie( - key=self.config.SESSION_COOKIE_NAME, - value=session_jwt, - max_age=3600, # 1 hour - path="/", - secure=True, - httponly=True, - samesite="strict", - ) - logger.debug(f"Updated session cookie for user: {claims['username']}") - - return response - - except AuthException as e: - # Session validation failed, redirect to login - logger.warning(f"Session validation failed: {e}") - return RedirectResponse(url="/auth/login", status_code=302) - - except Exception as e: - # Unexpected error - logger.error(f"Unexpected error in authentication middleware: {e}", exc_info=True) - return RedirectResponse(url="/auth/login?error=internal_error", status_code=302) - - -def require_permission(permission: str): - """Decorator to require specific permission for a route. - - Args: - permission: The required permission name. - - Returns: - Callable: The decorator function. - """ - - def decorator(func): - async def wrapper(request: Request, *args, **kwargs): - if not hasattr(request.state, "permissions"): - return {"error": "Not authenticated"}, 401 - - if permission not in request.state.permissions: - return {"error": f"Missing required permission: {permission}"}, 403 - - return await func(request, *args, **kwargs) - - return wrapper - - return decorator - - -def require_role(role: str): - """Decorator to require specific role for a route. - - Args: - role: The required role name. - - Returns: - Callable: The decorator function. - """ - - def decorator(func): - async def wrapper(request: Request, *args, **kwargs): - if not hasattr(request.state, "roles"): - return {"error": "Not authenticated"}, 401 - - if role not in request.state.roles: - return {"error": f"Missing required role: {role}"}, 403 - - return await func(request, *args, **kwargs) - - return wrapper - - return decorator - - -def require_admin(func): - """Decorator to require admin role for a route. - - Args: - func: The route handler function. - - Returns: - Callable: The wrapped function. - """ - - async def wrapper(request: Request, *args, **kwargs): - if not hasattr(request.state, "is_admin") or not request.state.is_admin: - return {"error": "Admin access required"}, 403 - - return await func(request, *args, **kwargs) - - return wrapper diff --git a/archive/wsgi_middleware.py b/archive/wsgi_middleware.py deleted file mode 100644 index dc331f4..0000000 --- a/archive/wsgi_middleware.py +++ /dev/null @@ -1,80 +0,0 @@ -"""WSGI middleware to pass FastAPI auth info to MLflow Flask app.""" - -import asyncio -import logging - -from asgiref.wsgi import WsgiToAsgi -from starlette.types import Receive, Scope, Send - -logger = logging.getLogger(__name__) - - -class AuthInjectingWSGIApp: - """WSGI app wrapper that injects FastAPI auth info into Flask environ. - - This bridges FastAPI authentication with MLflow's Flask app by injecting - user information into the WSGI environ dict. - """ - - def __init__(self, flask_app, scope: Scope): - self.flask_app = flask_app - self.scope = scope - - def __call__(self, environ, start_response): - """Inject auth info from ASGI scope into WSGI environ.""" - # Extract auth info from request state (set by AuthenticationMiddleware) - state = self.scope.get("state", {}) - - username = state.get("username") - is_admin = state.get("is_admin", False) - email = state.get("email") - roles = state.get("roles", []) - permissions = state.get("permissions", []) - - if username: - logger.debug(f"Injecting auth into WSGI environ: user={username}, admin={is_admin}") - # Inject auth info for MLflow to access - environ["mlflow_descope_auth.username"] = username - environ["mlflow_descope_auth.is_admin"] = is_admin - environ["mlflow_descope_auth.email"] = email or username - environ["mlflow_descope_auth.roles"] = ",".join(roles) - environ["mlflow_descope_auth.permissions"] = ",".join(permissions) - - # Also set standard REMOTE_USER for compatibility - environ["REMOTE_USER"] = username - - return self.flask_app(environ, start_response) - - -class AuthAwareWSGIMiddleware: - """ASGI middleware that passes FastAPI auth to MLflow Flask app. - - This middleware: - 1. Extracts authentication info from ASGI scope - 2. Wraps the Flask app to inject auth into WSGI environ - 3. Converts ASGI to WSGI using asgiref - """ - - def __init__(self, flask_app): - self.flask_app = flask_app - logger.info("Initialized AuthAwareWSGIMiddleware") - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] == "http": - # Create auth-injecting wrapper for this request - auth_injecting_app = AuthInjectingWSGIApp(self.flask_app, scope) - - # Use asgiref to convert WSGI to ASGI - wsgi_adapter = WsgiToAsgi(auth_injecting_app) - await wsgi_adapter(scope, receive, send) - else: - # For non-HTTP requests (websocket/lifespan) - if callable(self.flask_app): - result = self.flask_app(scope, receive, send) - if asyncio.iscoroutine(result) or asyncio.isfuture(result): - await result - return - - # Fallback to WSGI adapter - adapter = WsgiToAsgi(self.flask_app) - await adapter(scope, receive, send) diff --git a/docs/QUICKSTART.md b/docs/QUICKSTART.md index 568b4b3..09e0a39 100644 --- a/docs/QUICKSTART.md +++ b/docs/QUICKSTART.md @@ -15,87 +15,36 @@ Get MLflow running with Descope authentication in 5 minutes! 3. Give it a name (e.g., "MLflow Auth") 4. Copy your **Project ID** (starts with `P2`) -## Step 2: Get Authentication Tokens - -Authenticate with Descope to obtain session tokens. You can use the Descope Python SDK: - -```python -from descope import DescopeClient - -descope_client = DescopeClient(project_id="P2XXXXX") - -# Authenticate user (example with magic link) -response = descope_client.magiclink.sign_in_or_up( - method="email", - login_id="user@example.com" -) - -# Extract session token -session_token = response["sessionToken"]["jwt"] -print(f"export DESCOPE_SESSION_TOKEN='{session_token}'") -``` - -Or use any other Descope authentication method (OAuth, SAML, etc.) via: -- [Descope Web SDK](https://docs.descope.com/build/guides/client_sdks/web/) -- [Descope Python SDK](https://docs.descope.com/build/guides/client_sdks/python/) -- Descope API directly - -> **Note**: Token refresh is handled automatically by the server, not via environment variables. - -## Step 3: Install MLflow Descope Auth +## Step 2: Install MLflow Descope Auth ```bash pip install mlflow-descope-auth ``` -## Step 4: Configure Environment - -Set the required environment variables: +## Step 3: Start MLflow with Descope Authentication ```bash -# Required export DESCOPE_PROJECT_ID="P2XXXXX" -export DESCOPE_SESSION_TOKEN="" - -# Enable the plugin -export MLFLOW_TRACKING_AUTH=descope - -# Optional (with defaults) -export DESCOPE_ADMIN_ROLES="admin,mlflow-admin" -export DESCOPE_DEFAULT_PERMISSION="READ" -export DESCOPE_USERNAME_CLAIM="sub" # or "email" +mlflow server --app-name descope --host 0.0.0.0 --port 5000 ``` -## Step 5: Use MLflow +## Step 4: Access MLflow -Once configured, the plugin works automatically with any MLflow client: +1. Open your browser: `http://localhost:5000` +2. You'll be redirected to the Descope login page +3. Sign in with your Descope account +4. After authentication, you'll be redirected to the MLflow UI -```python -import mlflow - -# Set tracking URI to your MLflow server -mlflow.set_tracking_uri("http://localhost:5000") - -# Start a run - authentication and user context are automatic! -with mlflow.start_run(): - mlflow.log_param("alpha", 0.5) - mlflow.log_metric("rmse", 0.8) -``` - -The plugin automatically: -- Adds authentication headers to all requests -- Injects user context headers (X-Descope-User-ID, X-Descope-Email, etc.) -- Tags runs with user information (descope.user_id, descope.email, etc.) +That's it! You now have a secure MLflow server with Descope authentication. ## What's Next? ### Verify Plugin is Loaded ```bash -# Check that the plugin is registered python -c " from importlib.metadata import entry_points -eps = entry_points(group='mlflow.request_auth_provider') +eps = entry_points(group='mlflow.app') print([ep.name for ep in eps]) " # Should include 'descope' @@ -106,7 +55,6 @@ print([ep.name for ep in eps]) By default, all authenticated users have READ access. To grant admin privileges: 1. Set admin roles in your environment: - ```bash export DESCOPE_ADMIN_ROLES="admin,mlflow-admin" ``` @@ -115,105 +63,40 @@ By default, all authenticated users have READ access. To grant admin privileges: 3. Create a role called `admin` or `mlflow-admin` 4. Assign this role to specific users -### Configure Permissions - -```bash -# Default permission for all users (READ, EDIT, or MANAGE) -export DESCOPE_DEFAULT_PERMISSION="READ" - -# Roles that get admin access (MANAGE permission) -export DESCOPE_ADMIN_ROLES="admin,mlflow-admin,superuser" -``` - -### Automatic Run Tagging - -The plugin automatically adds these tags to all runs: - -- `descope.user_id` - User's Descope ID -- `descope.username` - Username -- `descope.email` - User's email -- `descope.name` - User's display name -- `descope.roles` - Comma-separated list of roles -- `descope.permissions` - Comma-separated list of permissions -- `descope.tenants` - Comma-separated list of tenants +### Logout -### Request Headers - -The plugin adds these headers to MLflow API requests: - -- `X-Descope-User-ID` -- `X-Descope-Username` -- `X-Descope-Email` -- `X-Descope-Roles` -- `X-Descope-Permissions` -- `X-Descope-Tenants` +Visit `/auth/logout` to clear your session and redirect to login. ## Troubleshooting -### Plugin not loaded - -**Problem**: Authentication doesn't work - -**Solution**: +### Server won't start ```bash # Verify plugin is installed pip list | grep mlflow-descope-auth -# Check entry points -python -c " -from importlib.metadata import entry_points -eps = entry_points(group='mlflow.request_auth_provider') -print([ep.name for ep in eps]) -" - -# Ensure MLFLOW_TRACKING_AUTH is set -echo $MLFLOW_TRACKING_AUTH # Should be "descope" +# Check DESCOPE_PROJECT_ID is set +echo $DESCOPE_PROJECT_ID ``` -### Authentication fails - -**Problem**: Requests fail with authentication errors - -**Solution**: +### Login page doesn't appear +Make sure you're using `--app-name descope`: ```bash -# Check environment variables -env | grep DESCOPE - -# Verify tokens are valid -python -c " -from mlflow_descope_auth import get_descope_client -import os -client = get_descope_client() -result = client.validate_session(os.environ['DESCOPE_SESSION_TOKEN']) -print('✓ Token valid') -" +mlflow server --app-name descope --port 5000 ``` -### "DESCOPE_PROJECT_ID is required" error - -**Problem**: Plugin fails to initialize - -**Solution**: +### Authentication fails ```bash -# Set the required environment variable -export DESCOPE_PROJECT_ID=P2XXXXX -export DESCOPE_SESSION_TOKEN="" +# Check environment variables +env | grep DESCOPE ``` ## Need Help? - **Documentation**: See [README.md](../README.md) for full documentation -- **Architecture**: See [ARCHITECTURE.md](../ARCHITECTURE.md) for technical details - **Issues**: [GitHub Issues](https://github.com/descope/mlflow-descope-auth/issues) - **Descope Docs**: [docs.descope.com](https://docs.descope.com/) -## Next Steps - -- Read the [Configuration Reference](../README.md#configuration-reference) -- Learn about the [Plugin Architecture](../ARCHITECTURE.md) -- Check out the [Docker Examples](../examples/) - -Happy MLflow-ing with Descope! 🚀 +Happy MLflow-ing with Descope! diff --git a/examples/basic_usage.py b/examples/basic_usage.py index fd8721c..828dfb3 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -1,26 +1,22 @@ -"""Basic usage example for MLflow Descope Auth plugin.""" +"""Basic usage example for MLflow Descope Auth plugin. + +The recommended way to run MLflow with Descope authentication is: + + export DESCOPE_PROJECT_ID="P2XXXXX" + mlflow server --app-name descope --host 0.0.0.0 --port 5000 + +This example shows how to programmatically create and run the app. +""" import os -# Set environment variables (alternatively use .env file) os.environ["DESCOPE_PROJECT_ID"] = "P2XXXXX" # Replace with your project ID -os.environ["DESCOPE_FLOW_ID"] = "sign-up-or-in" -os.environ["DESCOPE_ADMIN_ROLES"] = "admin,mlflow-admin" -# Import and run the app from mlflow_descope_auth import create_app app = create_app() if __name__ == "__main__": - import uvicorn - print("Starting MLflow with Descope authentication...") print("Visit http://localhost:5000 to login") - - uvicorn.run( - app, - host="0.0.0.0", - port=5000, - log_level="info", - ) + app.run(host="0.0.0.0", port=5000) diff --git a/examples/custom_config.py b/examples/custom_config.py index da46bd6..256d412 100644 --- a/examples/custom_config.py +++ b/examples/custom_config.py @@ -1,34 +1,35 @@ -"""Example of custom configuration for MLflow Descope Auth.""" +"""Example of custom configuration for MLflow Descope Auth. + +The recommended way to run MLflow with Descope authentication is: + + export DESCOPE_PROJECT_ID="P2XXXXX" + export DESCOPE_ADMIN_ROLES="admin,superuser" + mlflow server --app-name descope --host 0.0.0.0 --port 5000 + +This example shows how to programmatically configure and run the app. +""" + +import os + +os.environ["DESCOPE_PROJECT_ID"] = "P2XXXXX" # Replace with your project ID from mlflow_descope_auth import Config, create_app, set_config -# Create custom configuration config = Config( - DESCOPE_PROJECT_ID="P2XXXXX", # Replace with your project ID + DESCOPE_PROJECT_ID="P2XXXXX", DESCOPE_FLOW_ID="sign-up-or-in", - DESCOPE_REDIRECT_URL="/experiments", # Custom redirect - ADMIN_ROLES=["admin", "superuser", "mlflow-admin"], # Custom admin roles - DEFAULT_PERMISSION="EDIT", # More permissive default (READ, EDIT, or MANAGE) - USERNAME_CLAIM="email", # Use email as username instead of 'sub' + DESCOPE_REDIRECT_URL="/experiments", + ADMIN_ROLES=["admin", "superuser", "mlflow-admin"], + DEFAULT_PERMISSION="EDIT", + USERNAME_CLAIM="email", ) -# Set the custom configuration set_config(config) -# Create the app with custom config app = create_app() if __name__ == "__main__": - import uvicorn - print("Starting MLflow with custom Descope configuration...") print(f"Admin roles: {config.ADMIN_ROLES}") print(f"Default permission: {config.DEFAULT_PERMISSION}") - print(f"Redirect URL: {config.DESCOPE_REDIRECT_URL}") - - uvicorn.run( - app, - host="0.0.0.0", - port=5000, - log_level="info", - ) + app.run(host="0.0.0.0", port=5000) diff --git a/examples/docker-compose.postgres.yml b/examples/docker-compose.postgres.yml index a6e1fc6..8d311b0 100644 --- a/examples/docker-compose.postgres.yml +++ b/examples/docker-compose.postgres.yml @@ -25,7 +25,7 @@ services: command: > sh -c " pip install --no-cache-dir mlflow mlflow-descope-auth psycopg2-binary && - mlflow server --app-name descope-auth --host 0.0.0.0 --port 5000 + mlflow server --app-name descope --host 0.0.0.0 --port 5000 " environment: # Descope configuration diff --git a/examples/docker-compose.yml b/examples/docker-compose.yml index 4d6960c..2615194 100644 --- a/examples/docker-compose.yml +++ b/examples/docker-compose.yml @@ -7,7 +7,7 @@ services: command: > sh -c " pip install --no-cache-dir mlflow mlflow-descope-auth && - mlflow server --app-name descope-auth --host 0.0.0.0 --port 5000 + mlflow server --app-name descope --host 0.0.0.0 --port 5000 " environment: # Required: Set your Descope Project ID diff --git a/mise.toml b/mise.toml index d8cbf98..abcffe0 100644 --- a/mise.toml +++ b/mise.toml @@ -60,14 +60,14 @@ description = "Run pre-commit on all files" run = "uv run pre-commit run --all-files" [tasks.dev] -description = "Start MLflow server with Descope plugin enabled" -env = { MLFLOW_TRACKING_AUTH = "descope" } +description = "Start MLflow server with Descope authentication" run = """ -if [ -z "$DESCOPE_SESSION_TOKEN" ]; then - echo "Warning: DESCOPE_SESSION_TOKEN not set" - echo "Plugin will be loaded but authentication will fail without tokens" +if [ -z "$DESCOPE_PROJECT_ID" ]; then + echo "Error: DESCOPE_PROJECT_ID is required" + echo " export DESCOPE_PROJECT_ID='P2XXXXX'" + exit 1 fi -uv run mlflow server --host 0.0.0.0 --port 5000 +uv run mlflow server --app-name descope --host 0.0.0.0 --port 5000 """ [tasks.verify-plugin] @@ -80,46 +80,10 @@ if sys.version_info >= (3, 10): else: from importlib_metadata import entry_points -auth_eps = [ep.name for ep in entry_points(group="mlflow.request_auth_provider")] -header_eps = [ep.name for ep in entry_points(group="mlflow.request_header_provider")] -context_eps = [ep.name for ep in entry_points(group="mlflow.run_context_provider")] - +app_eps = [ep.name for ep in entry_points(group="mlflow.app")] print("MLflow Plugin Entry Points:") -auth_status = "descope" if "descope" in auth_eps else "NOT FOUND" -header_status = "descope" if "descope" in header_eps else "NOT FOUND" -context_status = "descope" if "descope" in context_eps else "NOT FOUND" -print(f" Auth Provider: {auth_status}") -print(f" Header Provider: {header_status}") -print(f" Context Provider: {context_status}") -' -""" - -[tasks.demo] -description = "Run a demo MLflow tracking session with plugin" -env = { MLFLOW_TRACKING_AUTH = "descope" } -run = """ -if [ -z "$DESCOPE_SESSION_TOKEN" ]; then - echo "Error: DESCOPE_SESSION_TOKEN required for demo" - echo "" - echo "Set these environment variables first:" - echo " export DESCOPE_PROJECT_ID='P2XXXXX'" - echo " export DESCOPE_SESSION_TOKEN='your-token'" - echo " export DESCOPE_REFRESH_TOKEN='your-refresh-token'" - exit 1 -fi - -echo "Running demo with Descope authentication..." -uv run python -c ' -import mlflow -import os - -print(f"MLflow Tracking URI: {mlflow.get_tracking_uri()}") -print(f"Auth Provider: descope (via MLFLOW_TRACKING_AUTH)") - -with mlflow.start_run(run_name="descope-plugin-demo"): - mlflow.log_param("plugin", "mlflow-descope-auth") - mlflow.log_metric("demo_metric", 1.0) - print("Run logged successfully with Descope authentication!") +app_status = "descope" if "descope" in app_eps else "NOT FOUND" +print(f" App Factory: {app_status}") ' """ diff --git a/mlflow_descope_auth/__init__.py b/mlflow_descope_auth/__init__.py index 1fddbe6..3f8d26f 100644 --- a/mlflow_descope_auth/__init__.py +++ b/mlflow_descope_auth/__init__.py @@ -1,35 +1,29 @@ """MLflow Descope Authentication Plugin. -This plugin provides simple, standards-compliant authentication for MLflow using Descope. -It integrates via MLflow's plugin system with: -- Request authentication provider -- Request header provider -- Run context provider +This plugin provides server-side authentication for MLflow using Descope. -Authentication tokens are managed via environment variables: -- DESCOPE_SESSION_TOKEN: Current session token -- DESCOPE_REFRESH_TOKEN: Refresh token for automatic renewal +Usage: + Start MLflow with: mlflow server --app-name descope + Browser login at /auth/login with automatic token refresh via cookies. """ __version__ = "0.1.0" -from .auth_provider import DescopeAuth, DescopeAuthProvider +from .auth_routes import register_auth_routes from .client import DescopeClientWrapper, get_descope_client -from .config import Config, get_config -from .context_provider import DescopeContextProvider -from .header_provider import DescopeHeaderProvider +from .config import Config, get_config, set_config +from .server import create_app from .store import DescopeUserStore, get_user_store __all__ = [ - "DescopeAuth", - "DescopeAuthProvider", - "DescopeHeaderProvider", - "DescopeContextProvider", "DescopeClientWrapper", "get_descope_client", "Config", "get_config", + "set_config", "DescopeUserStore", "get_user_store", + "create_app", + "register_auth_routes", "__version__", ] diff --git a/mlflow_descope_auth/auth_provider.py b/mlflow_descope_auth/auth_provider.py deleted file mode 100644 index dac33a9..0000000 --- a/mlflow_descope_auth/auth_provider.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Descope authentication provider for MLflow.""" - -import logging -import os - -from descope import AuthException -from mlflow.tracking.request_auth.abstract_request_auth_provider import ( - RequestAuthProvider, -) - -from .client import get_descope_client -from .config import get_config - -logger = logging.getLogger(__name__) - - -class DescopeAuth: - """Authentication handler for Descope session tokens.""" - - def __init__(self): - self.client = get_descope_client() - self.config = get_config() - - def __call__(self, request): - session_token = os.environ.get("DESCOPE_SESSION_TOKEN") - refresh_token = os.environ.get("DESCOPE_REFRESH_TOKEN") - - if not session_token: - logger.warning("No DESCOPE_SESSION_TOKEN found in environment") - return request - - try: - jwt_response = self.client.validate_session(session_token, refresh_token) - - if jwt_response.get("sessionToken"): - new_token = jwt_response["sessionToken"]["jwt"] - request.headers["Authorization"] = f"Bearer {new_token}" - - claims = self.client.extract_user_claims(jwt_response) - request.headers["X-MLflow-User"] = claims["username"] - request.headers["X-MLflow-User-Email"] = claims["email"] - - if self.config.is_admin_role(claims["roles"]): - request.headers["X-MLflow-Admin"] = "true" - - return request - - except AuthException as e: - logger.error(f"Descope authentication failed: {e}") - return request - - -class DescopeAuthProvider(RequestAuthProvider): - """MLflow auth provider for Descope authentication.""" - - def get_name(self): - return "descope" - - def get_auth(self): - return DescopeAuth() diff --git a/mlflow_descope_auth/auth_routes.py b/mlflow_descope_auth/auth_routes.py new file mode 100644 index 0000000..5583b75 --- /dev/null +++ b/mlflow_descope_auth/auth_routes.py @@ -0,0 +1,185 @@ +"""Authentication routes for MLflow Descope integration. + +This module provides Flask routes for login, logout, and user info endpoints. +The login page uses the Descope Web Component for authentication. +""" + +import logging +from flask import Flask, g, jsonify, make_response, redirect + +from .config import get_config + +logger = logging.getLogger(__name__) + +# Inline HTML template for login page with Descope Web Component +LOGIN_TEMPLATE = """ + + + + + MLflow Login - Descope + + + + +
+

MLflow

+

Sign in to continue

+
+ +
+ + + +""" + + +def register_auth_routes(app: Flask) -> None: + """Register authentication routes on the Flask app. + + This adds the following routes: + - /auth/login - Login page with Descope Web Component + - /auth/logout - Logout endpoint (clears cookies) + - /auth/user - Get current user info + - /health - Health check endpoint + + Args: + app: The Flask application. + """ + config = get_config() + + @app.route("/auth/login") + def auth_login(): + """Render login page with Descope Web Component.""" + html = LOGIN_TEMPLATE.format( + web_component_url=config.web_component_url, + project_id=config.DESCOPE_PROJECT_ID, + flow_id=config.DESCOPE_FLOW_ID, + session_cookie=config.SESSION_COOKIE_NAME, + refresh_cookie=config.REFRESH_COOKIE_NAME, + redirect_url=config.DESCOPE_REDIRECT_URL, + ) + return html, 200, {"Content-Type": "text/html"} + + @app.route("/auth/logout") + def auth_logout(): + """Clear authentication cookies and redirect to login.""" + response = make_response(redirect("/auth/login", code=302)) + + # Delete authentication cookies + response.delete_cookie(config.SESSION_COOKIE_NAME, path="/") + response.delete_cookie(config.REFRESH_COOKIE_NAME, path="/") + + logger.info("User logged out") + return response + + @app.route("/auth/user") + def auth_user(): + """Get current authenticated user information.""" + if not hasattr(g, "username"): + return jsonify({"error": "Not authenticated"}), 401 + + return jsonify( + { + "user_id": getattr(g, "user_id", None), + "username": g.username, + "email": getattr(g, "email", None), + "name": getattr(g, "name", None), + "roles": getattr(g, "roles", []), + "permissions": getattr(g, "permissions", []), + "tenants": getattr(g, "tenants", []), + "is_admin": getattr(g, "is_admin", False), + } + ) + + @app.route("/health") + def health_check(): + """Health check endpoint.""" + return jsonify( + { + "status": "healthy", + "service": "mlflow-descope-auth", + "project_id": config.DESCOPE_PROJECT_ID, + } + ) + + logger.debug("Registered auth routes: /auth/login, /auth/logout, /auth/user, /health") diff --git a/mlflow_descope_auth/client.py b/mlflow_descope_auth/client.py index 803e831..e50bde8 100644 --- a/mlflow_descope_auth/client.py +++ b/mlflow_descope_auth/client.py @@ -13,20 +13,17 @@ class DescopeClientWrapper: """Wrapper around Descope SDK for MLflow authentication.""" - def __init__(self, project_id: Optional[str] = None, management_key: Optional[str] = None): + def __init__(self, project_id: Optional[str] = None): """Initialize Descope client. Args: project_id: Descope project ID. If None, loads from config. - management_key: Descope management key. If None, loads from config. """ config = get_config() self.project_id = project_id or config.DESCOPE_PROJECT_ID - self.management_key = management_key or config.DESCOPE_MANAGEMENT_KEY self.client = DescopeClient( project_id=self.project_id, - management_key=self.management_key, ) logger.info(f"Initialized Descope client for project: {self.project_id}") diff --git a/mlflow_descope_auth/config.py b/mlflow_descope_auth/config.py index 9621308..25300a6 100644 --- a/mlflow_descope_auth/config.py +++ b/mlflow_descope_auth/config.py @@ -18,7 +18,6 @@ class Config: DESCOPE_PROJECT_ID: str # Optional - Descope settings - DESCOPE_MANAGEMENT_KEY: Optional[str] = None DESCOPE_FLOW_ID: str = "sign-up-or-in" DESCOPE_REDIRECT_URL: str = "/" DESCOPE_WEB_COMPONENT_VERSION: str = "3.54.0" @@ -39,6 +38,9 @@ class Config: SESSION_COOKIE_NAME: str = "DS" REFRESH_COOKIE_NAME: str = "DSR" + # Cookie security settings + COOKIE_SECURE: bool = False # Set True in production (HTTPS only) + @classmethod def from_env(cls) -> "Config": """Load configuration from environment variables. @@ -60,9 +62,11 @@ def from_env(cls) -> "Config": admin_roles_str = os.getenv("DESCOPE_ADMIN_ROLES", "admin,mlflow-admin") admin_roles = [role.strip() for role in admin_roles_str.split(",")] + # Parse cookie secure flag + cookie_secure = os.getenv("DESCOPE_COOKIE_SECURE", "false").lower() == "true" + return cls( DESCOPE_PROJECT_ID=project_id, - DESCOPE_MANAGEMENT_KEY=os.getenv("DESCOPE_MANAGEMENT_KEY"), DESCOPE_FLOW_ID=os.getenv("DESCOPE_FLOW_ID", "sign-up-or-in"), DESCOPE_REDIRECT_URL=os.getenv("DESCOPE_REDIRECT_URL", "/"), DESCOPE_WEB_COMPONENT_VERSION=os.getenv("DESCOPE_WEB_COMPONENT_VERSION", "3.54.0"), @@ -72,6 +76,7 @@ def from_env(cls) -> "Config": ADMIN_ROLES=admin_roles, DEFAULT_PERMISSION=os.getenv("DESCOPE_DEFAULT_PERMISSION", "READ"), USERNAME_CLAIM=os.getenv("DESCOPE_USERNAME_CLAIM", "sub"), + COOKIE_SECURE=cookie_secure, ) def is_admin_role(self, roles: List[str]) -> bool: diff --git a/mlflow_descope_auth/context_provider.py b/mlflow_descope_auth/context_provider.py deleted file mode 100644 index 7047278..0000000 --- a/mlflow_descope_auth/context_provider.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Descope run context provider for MLflow.""" - -import logging -import os - -from mlflow.tracking.context.abstract_context import RunContextProvider - -from .client import get_descope_client - -logger = logging.getLogger(__name__) - - -class DescopeContextProvider(RunContextProvider): - """Automatically add Descope user context as tags to MLflow runs.""" - - def __init__(self): - self._client = None - - @property - def client(self): - if self._client is None: - self._client = get_descope_client() - return self._client - - def in_context(self): - return os.environ.get("DESCOPE_SESSION_TOKEN") is not None - - def tags(self): - session_token = os.environ.get("DESCOPE_SESSION_TOKEN") - refresh_token = os.environ.get("DESCOPE_REFRESH_TOKEN") - - if not session_token: - return {} - - try: - jwt_response = self.client.validate_session(session_token, refresh_token) - claims = self.client.extract_user_claims(jwt_response) - - tags = { - "descope.user_id": claims["user_id"], - "descope.username": claims["username"], - "descope.email": claims["email"], - } - - if claims["name"]: - tags["descope.name"] = claims["name"] - - if claims["roles"]: - tags["descope.roles"] = ",".join(claims["roles"]) - - if claims["permissions"]: - tags["descope.permissions"] = ",".join(claims["permissions"]) - - if claims["tenants"]: - tags["descope.tenants"] = ",".join(claims["tenants"]) - - logger.debug(f"Added Descope context for user: {claims['username']}") - return tags - - except Exception as e: - logger.error(f"Failed to add Descope context: {e}") - return {} diff --git a/mlflow_descope_auth/header_provider.py b/mlflow_descope_auth/header_provider.py deleted file mode 100644 index ecb84f4..0000000 --- a/mlflow_descope_auth/header_provider.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Descope request header provider for MLflow.""" - -import logging -import os - -from mlflow.tracking.request_header.abstract_request_header_provider import ( - RequestHeaderProvider, -) - -from .client import get_descope_client -from .config import get_config - -logger = logging.getLogger(__name__) - - -class DescopeHeaderProvider(RequestHeaderProvider): - """Add Descope user context headers to MLflow requests.""" - - def __init__(self): - self._client = None - self._config = None - - @property - def client(self): - if self._client is None: - self._client = get_descope_client() - return self._client - - @property - def config(self): - if self._config is None: - self._config = get_config() - return self._config - - def in_context(self): - return os.environ.get("DESCOPE_SESSION_TOKEN") is not None - - def request_headers(self): - session_token = os.environ.get("DESCOPE_SESSION_TOKEN") - refresh_token = os.environ.get("DESCOPE_REFRESH_TOKEN") - - if not session_token: - return {} - - try: - jwt_response = self.client.validate_session(session_token, refresh_token) - claims = self.client.extract_user_claims(jwt_response) - - headers = { - "X-Descope-User-ID": claims["user_id"], - "X-Descope-Username": claims["username"], - "X-Descope-Email": claims["email"], - "X-Descope-Project-ID": self.config.DESCOPE_PROJECT_ID, - } - - if claims["name"]: - headers["X-Descope-Name"] = claims["name"] - - if claims["roles"]: - headers["X-Descope-Roles"] = ",".join(claims["roles"]) - - if claims["permissions"]: - headers["X-Descope-Permissions"] = ",".join(claims["permissions"]) - - if claims["tenants"]: - headers["X-Descope-Tenants"] = ",".join(claims["tenants"]) - - return headers - - except Exception as e: - logger.error(f"Failed to add Descope headers: {e}") - return {} diff --git a/mlflow_descope_auth/server.py b/mlflow_descope_auth/server.py new file mode 100644 index 0000000..c28d5af --- /dev/null +++ b/mlflow_descope_auth/server.py @@ -0,0 +1,181 @@ +"""Server-side authentication for MLflow using Descope. + +This module provides the Flask app factory for the `mlflow.app` entry point, +enabling cookie-based authentication with the Descope Web Component. + +Usage: + mlflow server --app-name descope +""" + +import logging +from typing import Optional + +from descope import AuthException +from flask import Flask, g, redirect, request + +from .auth_routes import register_auth_routes +from .client import get_descope_client +from .config import get_config + +logger = logging.getLogger(__name__) + +# Public routes that don't require authentication +PUBLIC_ROUTES = { + "/auth/login", + "/auth/logout", + "/auth/callback", + "/health", + "/version", +} + +# Public route prefixes +PUBLIC_PREFIXES = [ + "/static/", + "/_static/", +] + + +def _is_public_route(path: str) -> bool: + """Check if a route is public (doesn't require authentication). + + Args: + path: The request path. + + Returns: + bool: True if the route is public, False otherwise. + """ + if path in PUBLIC_ROUTES: + return True + return any(path.startswith(prefix) for prefix in PUBLIC_PREFIXES) + + +def _before_request(): + """Validate Descope session on each request. + + This function is called before every request to the MLflow server. + It validates the session token from cookies and sets user info in Flask g. + + Returns: + None if authenticated, or redirect Response if not. + """ + # Skip auth for public routes + if _is_public_route(request.path): + return None + + config = get_config() + client = get_descope_client() + + # Read tokens from cookies + session_token = request.cookies.get(config.SESSION_COOKIE_NAME) + refresh_token = request.cookies.get(config.REFRESH_COOKIE_NAME) + + # No session token? Redirect to login + if not session_token: + logger.debug(f"No session token for path: {request.path}") + return redirect("/auth/login", code=302) + + # Validate session with Descope + try: + jwt_response = client.validate_session(session_token, refresh_token) + claims = client.extract_user_claims(jwt_response) + + # Store user info in Flask g for this request + g.user_id = claims["user_id"] + g.username = claims["username"] + g.email = claims["email"] + g.name = claims["name"] + g.roles = claims["roles"] + g.permissions = claims["permissions"] + g.tenants = claims["tenants"] + g.is_admin = config.is_admin_role(claims["roles"]) + + # Store JWT response for potential cookie update in after_request + g._descope_jwt_response = jwt_response + g._descope_session_token = session_token + + logger.debug(f"Authenticated user: {claims['username']} (admin: {g.is_admin})") + return None + + except AuthException as e: + logger.warning(f"Session validation failed: {e}") + return redirect("/auth/login", code=302) + + except Exception as e: + logger.error(f"Unexpected error in auth: {e}", exc_info=True) + return redirect("/auth/login?error=internal_error", code=302) + + +def _after_request(response): + """Update session cookie if token was refreshed. + + This function is called after every request. If the session token + was refreshed during validation, the new token is set in cookies. + + Args: + response: The Flask Response object. + + Returns: + The Response with updated cookies if needed. + """ + # Check if we have a new session token from refresh + if hasattr(g, "_descope_jwt_response"): + jwt_response = g._descope_jwt_response + original_token = getattr(g, "_descope_session_token", None) + + # Check if token was refreshed (new token different from original) + new_session = jwt_response.get("sessionToken", {}) + new_jwt = new_session.get("jwt") if isinstance(new_session, dict) else None + + if new_jwt and new_jwt != original_token: + config = get_config() + + # Update session cookie with refreshed token + response.set_cookie( + key=config.SESSION_COOKIE_NAME, + value=new_jwt, + max_age=3600, # 1 hour + path="/", + secure=config.COOKIE_SECURE, + httponly=True, + samesite="Lax", + ) + logger.debug(f"Updated session cookie for user: {g.username}") + + return response + + +def create_app(app: Optional[Flask] = None) -> Flask: + """Create Descope-authenticated MLflow Flask app. + + This is the factory function for the `mlflow.app` entry point. + It adds Descope authentication to the MLflow Flask server. + + Args: + app: The MLflow Flask app. If None, imports from mlflow.server. + + Returns: + Flask: The app with Descope authentication enabled. + + Usage: + mlflow server --app-name descope + """ + # Import MLflow app if not provided + if app is None: + from mlflow.server import app as mlflow_app + + app = mlflow_app + + config = get_config() + logger.info(f"Initializing MLflow Descope Auth for project: {config.DESCOPE_PROJECT_ID}") + + # Register authentication routes (/auth/login, /auth/logout, etc.) + register_auth_routes(app) + + # Add before_request hook for authentication + app.before_request(_before_request) + + # Add after_request hook for cookie refresh + app.after_request(_after_request) + + logger.info("MLflow Descope Auth initialized successfully") + return app diff --git a/pyproject.toml b/pyproject.toml index ba2504e..9faa64f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,14 +55,8 @@ Documentation = "https://github.com/descope/mlflow-descope-auth#readme" Repository = "https://github.com/descope/mlflow-descope-auth" Issues = "https://github.com/descope/mlflow-descope-auth/issues" -[project.entry-points."mlflow.request_auth_provider"] -descope = "mlflow_descope_auth.auth_provider:DescopeAuthProvider" - -[project.entry-points."mlflow.request_header_provider"] -descope = "mlflow_descope_auth.header_provider:DescopeHeaderProvider" - -[project.entry-points."mlflow.run_context_provider"] -descope = "mlflow_descope_auth.context_provider:DescopeContextProvider" +[project.entry-points."mlflow.app"] +descope = "mlflow_descope_auth.server:create_app" [tool.setuptools.packages.find] where = ["."] diff --git a/tests/conftest.py b/tests/conftest.py index fafef26..39105ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,6 @@ def mock_config(): """Create a mock configuration for testing.""" return Config( DESCOPE_PROJECT_ID="test_project_id", - DESCOPE_MANAGEMENT_KEY="test_management_key", DESCOPE_FLOW_ID="test-flow", DESCOPE_REDIRECT_URL="/", ADMIN_ROLES=["admin", "test-admin"], diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py new file mode 100644 index 0000000..fcc3417 --- /dev/null +++ b/tests/test_auth_routes.py @@ -0,0 +1,100 @@ +"""Tests for authentication routes.""" + +import pytest +from flask import Flask +from unittest.mock import Mock, patch + + +@pytest.fixture +def mock_config(): + config = Mock() + config.DESCOPE_PROJECT_ID = "test_project_id" + config.DESCOPE_FLOW_ID = "sign-up-or-in" + config.SESSION_COOKIE_NAME = "DS" + config.REFRESH_COOKIE_NAME = "DSR" + config.DESCOPE_REDIRECT_URL = "/" + config.web_component_url = "https://unpkg.com/@descope/web-component@3.54.0/dist/index.js" + return config + + +@pytest.fixture +def app(mock_config): + app = Flask(__name__) + app.config["TESTING"] = True + + with patch("mlflow_descope_auth.auth_routes.get_config", return_value=mock_config): + from mlflow_descope_auth.auth_routes import register_auth_routes + + register_auth_routes(app) + + return app + + +@pytest.fixture +def client(app): + return app.test_client() + + +class TestLoginRoute: + def test_login_returns_html(self, client): + response = client.get("/auth/login") + assert response.status_code == 200 + assert response.content_type == "text/html" + + def test_login_contains_descope_component(self, client): + response = client.get("/auth/login") + html = response.data.decode("utf-8") + assert "descope-wc" in html + assert "project-id" in html + assert "flow-id" in html + + def test_login_contains_project_id(self, client): + response = client.get("/auth/login") + html = response.data.decode("utf-8") + assert "test_project_id" in html + + def test_login_contains_web_component_script(self, client): + response = client.get("/auth/login") + html = response.data.decode("utf-8") + assert "@descope/web-component" in html + + +class TestLogoutRoute: + def test_logout_redirects_to_login(self, client): + response = client.get("/auth/logout") + assert response.status_code == 302 + assert "/auth/login" in response.location + + def test_logout_clears_session_cookie(self, client): + response = client.get("/auth/logout") + cookies = response.headers.getlist("Set-Cookie") + ds_cookies = [c for c in cookies if c.startswith("DS=")] + assert len(ds_cookies) > 0 + assert any("Max-Age=0" in c or "expires=" in c.lower() for c in ds_cookies) + + def test_logout_clears_refresh_cookie(self, client): + response = client.get("/auth/logout") + cookies = response.headers.getlist("Set-Cookie") + dsr_cookies = [c for c in cookies if c.startswith("DSR=")] + assert len(dsr_cookies) > 0 + + +class TestUserRoute: + def test_user_unauthenticated_returns_401(self, client): + response = client.get("/auth/user") + assert response.status_code == 401 + data = response.get_json() + assert data["error"] == "Not authenticated" + + +class TestHealthRoute: + def test_health_returns_ok(self, client): + response = client.get("/health") + assert response.status_code == 200 + + def test_health_returns_json(self, client): + response = client.get("/health") + data = response.get_json() + assert data["status"] == "healthy" + assert data["service"] == "mlflow-descope-auth" + assert "project_id" in data diff --git a/tests/test_plugin_integration.py b/tests/test_plugin_integration.py index eb44267..2147837 100644 --- a/tests/test_plugin_integration.py +++ b/tests/test_plugin_integration.py @@ -1,5 +1,6 @@ """Integration tests for MLflow plugin entry points.""" +import inspect import sys # Use importlib.metadata for all Python versions (available in 3.8+ via importlib_metadata backport) @@ -10,76 +11,22 @@ class TestPluginEntryPoints: - def test_auth_provider_entry_point(self): - eps = entry_points(group="mlflow.request_auth_provider") + def test_app_entry_point(self): + eps = entry_points(group="mlflow.app") names = [ep.name for ep in eps] - assert "descope" in names, "descope auth provider not registered" + assert "descope" in names, "descope app not registered" - def test_header_provider_entry_point(self): - eps = entry_points(group="mlflow.request_header_provider") - names = [ep.name for ep in eps] - - assert "descope" in names, "descope header provider not registered" - - def test_context_provider_entry_point(self): - eps = entry_points(group="mlflow.run_context_provider") - names = [ep.name for ep in eps] - - assert "descope" in names, "descope context provider not registered" - - def test_auth_provider_can_be_loaded(self): - eps = {ep.name: ep for ep in entry_points(group="mlflow.request_auth_provider")} - - assert "descope" in eps - provider_class = eps["descope"].load() - assert provider_class is not None - assert hasattr(provider_class, "get_name") - assert hasattr(provider_class, "get_auth") - - def test_header_provider_can_be_loaded(self): - eps = {ep.name: ep for ep in entry_points(group="mlflow.request_header_provider")} - - assert "descope" in eps - provider_class = eps["descope"].load() - assert provider_class is not None - assert hasattr(provider_class, "in_context") - assert hasattr(provider_class, "request_headers") - - def test_context_provider_can_be_loaded(self): - eps = {ep.name: ep for ep in entry_points(group="mlflow.run_context_provider")} + def test_app_entry_point_can_be_loaded(self): + eps = {ep.name: ep for ep in entry_points(group="mlflow.app")} assert "descope" in eps - provider_class = eps["descope"].load() - assert provider_class is not None - assert hasattr(provider_class, "in_context") - assert hasattr(provider_class, "tags") - - def test_auth_provider_instantiation(self): - from mlflow_descope_auth.auth_provider import DescopeAuthProvider - - provider = DescopeAuthProvider() - assert provider.get_name() == "descope" - assert callable(provider.get_auth()) - - def test_header_provider_not_in_context_without_token(self): - import os - - os.environ.pop("DESCOPE_SESSION_TOKEN", None) - - from mlflow_descope_auth.header_provider import DescopeHeaderProvider - - provider = DescopeHeaderProvider() - assert not provider.in_context() - assert provider.request_headers() == {} - - def test_context_provider_not_in_context_without_token(self): - import os - - os.environ.pop("DESCOPE_SESSION_TOKEN", None) + factory = eps["descope"].load() + assert factory is not None + assert callable(factory) - from mlflow_descope_auth.context_provider import DescopeContextProvider + def test_create_app_signature(self): + from mlflow_descope_auth.server import create_app - provider = DescopeContextProvider() - assert not provider.in_context() - assert provider.tags() == {} + sig = inspect.signature(create_app) + assert "app" in sig.parameters diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py new file mode 100644 index 0000000..981c0b1 --- /dev/null +++ b/tests/test_server_auth.py @@ -0,0 +1,246 @@ +"""Tests for server-side authentication.""" + +import pytest +from flask import Flask, g +from unittest.mock import Mock, patch, MagicMock +from descope import AuthException + + +class TestIsPublicRoute: + def test_auth_login_is_public(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/auth/login") is True + + def test_auth_logout_is_public(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/auth/logout") is True + + def test_health_is_public(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/health") is True + + def test_static_prefix_is_public(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/static/css/style.css") is True + assert _is_public_route("/_static/js/app.js") is True + + def test_api_routes_are_protected(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/api/2.0/mlflow/experiments/list") is False + assert _is_public_route("/ajax-api/2.0/mlflow/runs/search") is False + + def test_root_is_protected(self): + from mlflow_descope_auth.server import _is_public_route + + assert _is_public_route("/") is False + + +class TestCreateApp: + def test_create_app_returns_flask_app(self): + mock_app = MagicMock(spec=Flask) + mock_app.before_request = MagicMock() + mock_app.after_request = MagicMock() + mock_app.route = MagicMock(return_value=lambda f: f) + + with patch("mlflow_descope_auth.server.get_config") as mock_config: + mock_config.return_value = Mock( + DESCOPE_PROJECT_ID="test_project", + SESSION_COOKIE_NAME="DS", + REFRESH_COOKIE_NAME="DSR", + DESCOPE_FLOW_ID="sign-up-or-in", + DESCOPE_REDIRECT_URL="/", + web_component_url="https://example.com/wc.js", + ) + with patch("mlflow_descope_auth.auth_routes.get_config", mock_config): + from mlflow_descope_auth.server import create_app + + result = create_app(mock_app) + + assert result is mock_app + mock_app.before_request.assert_called_once() + mock_app.after_request.assert_called_once() + + def test_create_app_imports_mlflow_app_when_none(self): + with patch("mlflow_descope_auth.server.get_config") as mock_config: + mock_config.return_value = Mock( + DESCOPE_PROJECT_ID="test_project", + SESSION_COOKIE_NAME="DS", + REFRESH_COOKIE_NAME="DSR", + ) + with patch("mlflow_descope_auth.auth_routes.get_config", mock_config): + with patch("mlflow.server.app") as mock_mlflow_app: + mock_mlflow_app.before_request = MagicMock() + mock_mlflow_app.after_request = MagicMock() + mock_mlflow_app.route = MagicMock(return_value=lambda f: f) + + from mlflow_descope_auth.server import create_app + + result = create_app(None) + + assert result is mock_mlflow_app + + +class TestBeforeRequest: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_config(self): + config = Mock() + config.SESSION_COOKIE_NAME = "DS" + config.REFRESH_COOKIE_NAME = "DSR" + config.COOKIE_SECURE = False + config.is_admin_role = Mock(return_value=False) + return config + + @pytest.fixture + def mock_client(self): + client = Mock() + client.validate_session = Mock( + return_value={ + "sessionToken": {"jwt": "new_token"}, + } + ) + client.extract_user_claims = Mock( + return_value={ + "user_id": "user123", + "username": "testuser", + "email": "test@example.com", + "name": "Test User", + "roles": ["user"], + "permissions": ["read"], + "tenants": {}, + } + ) + return client + + def test_public_route_skips_auth(self, app, mock_config): + with app.test_request_context("/auth/login"): + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + from mlflow_descope_auth.server import _before_request + + result = _before_request() + assert result is None + + def test_no_session_redirects_to_login(self, app, mock_config): + with app.test_request_context("/api/experiments", headers={}): + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + from mlflow_descope_auth.server import _before_request + + result = _before_request() + assert result is not None + assert result.status_code == 302 + assert "/auth/login" in result.location + + def test_valid_session_sets_user_context(self, app, mock_config, mock_client): + with app.test_request_context( + "/api/experiments", + headers={"Cookie": "DS=valid_token; DSR=refresh_token"}, + ): + from flask import request + + request.cookies = {"DS": "valid_token", "DSR": "refresh_token"} + + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + with patch( + "mlflow_descope_auth.server.get_descope_client", return_value=mock_client + ): + from mlflow_descope_auth.server import _before_request + + result = _before_request() + + assert result is None + assert g.username == "testuser" + assert g.email == "test@example.com" + assert g.user_id == "user123" + + def test_invalid_session_redirects_to_login(self, app, mock_config, mock_client): + mock_client.validate_session.side_effect = AuthException(401, "invalid", "Invalid token") + + with app.test_request_context("/api/experiments"): + from flask import request + + request.cookies = {"DS": "invalid_token"} + + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + with patch( + "mlflow_descope_auth.server.get_descope_client", return_value=mock_client + ): + from mlflow_descope_auth.server import _before_request + + result = _before_request() + + assert result is not None + assert result.status_code == 302 + assert "/auth/login" in result.location + + +class TestAfterRequest: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_config(self): + config = Mock() + config.SESSION_COOKIE_NAME = "DS" + config.COOKIE_SECURE = False + return config + + def test_no_jwt_response_returns_unchanged(self, app, mock_config): + with app.test_request_context("/"): + from flask import make_response + + from mlflow_descope_auth.server import _after_request + + response = make_response("OK") + + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + result = _after_request(response) + + assert result is response + + def test_refreshed_token_sets_cookie(self, app, mock_config): + with app.test_request_context("/"): + from flask import make_response + + g._descope_jwt_response = {"sessionToken": {"jwt": "new_refreshed_token"}} + g._descope_session_token = "old_token" + g.username = "testuser" + + from mlflow_descope_auth.server import _after_request + + response = make_response("OK") + + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + result = _after_request(response) + + cookies = result.headers.getlist("Set-Cookie") + assert any("DS=new_refreshed_token" in cookie for cookie in cookies) + + def test_same_token_no_cookie_update(self, app, mock_config): + with app.test_request_context("/"): + from flask import make_response + + g._descope_jwt_response = {"sessionToken": {"jwt": "same_token"}} + g._descope_session_token = "same_token" + + from mlflow_descope_auth.server import _after_request + + response = make_response("OK") + + with patch("mlflow_descope_auth.server.get_config", return_value=mock_config): + result = _after_request(response) + + cookies = result.headers.getlist("Set-Cookie") + assert not any("DS=" in cookie for cookie in cookies)