Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/api_reference/api_inventory.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ mlflow.config.set_system_metrics_sampling_interval
mlflow.config.set_tracking_uri
mlflow.create_experiment
mlflow.create_external_model
mlflow.create_workspace
mlflow.crewai.autolog
mlflow.data.dataset.Dataset
mlflow.data.dataset.Dataset.to_dict
Expand Down Expand Up @@ -227,6 +228,7 @@ mlflow.delete_logged_model_tag
mlflow.delete_run
mlflow.delete_tag
mlflow.delete_trace_tag
mlflow.delete_workspace
mlflow.deployments.BaseDeploymentClient
mlflow.deployments.BaseDeploymentClient.create_deployment
mlflow.deployments.BaseDeploymentClient.create_endpoint
Expand Down Expand Up @@ -891,6 +893,7 @@ mlflow.get_registry_uri
mlflow.get_run
mlflow.get_trace
mlflow.get_tracking_uri
mlflow.get_workspace
mlflow.groq.autolog
mlflow.h2o.get_default_conda_env
mlflow.h2o.get_default_pip_requirements
Expand Down Expand Up @@ -934,6 +937,7 @@ mlflow.lightgbm.get_default_pip_requirements
mlflow.lightgbm.load_model
mlflow.lightgbm.log_model
mlflow.lightgbm.save_model
mlflow.list_workspaces
mlflow.litellm.autolog
mlflow.llama_index.autolog
mlflow.llama_index.load_model
Expand Down Expand Up @@ -1234,6 +1238,7 @@ mlflow.set_tag
mlflow.set_tags
mlflow.set_trace_tag
mlflow.set_tracking_uri
mlflow.set_workspace
mlflow.shap.get_default_conda_env
mlflow.shap.get_default_pip_requirements
mlflow.shap.get_underlying_model_flavor
Expand Down Expand Up @@ -1460,6 +1465,7 @@ mlflow.types.schema.Schema
mlflow.types.schema.TensorSpec
mlflow.update_assessment
mlflow.update_current_trace
mlflow.update_workspace
mlflow.utils.async_logging.run_operations.RunOperations
mlflow.utils.async_logging.run_operations.RunOperations.wait
mlflow.utils.async_logging.run_operations.get_combined_run_operations
Expand Down
14 changes: 14 additions & 0 deletions mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@
set_model_version_tag,
set_prompt_alias,
)
from mlflow.tracking._workspace.fluent import (
create_workspace,
delete_workspace,
get_workspace,
list_workspaces,
set_workspace,
update_workspace,
)
from mlflow.tracking.fluent import (
ActiveModel,
ActiveRun,
Expand Down Expand Up @@ -330,7 +338,9 @@
"clear_active_model",
"create_experiment",
"create_external_model",
"create_workspace",
"delete_experiment",
"delete_workspace",
"delete_run",
"delete_tag",
"disable_system_metrics_logging",
Expand All @@ -346,6 +356,7 @@
"get_experiment",
"get_experiment_by_name",
"get_logged_model",
"get_workspace",
"get_parent_run",
"get_registry_uri",
"get_run",
Expand Down Expand Up @@ -376,6 +387,7 @@
"search_logged_models",
"search_model_versions",
"search_registered_models",
"list_workspaces",
"search_runs",
"search_prompts",
"set_active_model",
Expand All @@ -389,6 +401,7 @@
"set_system_metrics_sampling_interval",
"set_tag",
"set_tags",
"set_workspace",
"start_run",
"validate_evaluation_results",
"Image",
Expand All @@ -402,6 +415,7 @@
"delete_prompt_alias",
"set_logged_model_tags",
"delete_logged_model_tag",
"update_workspace",
]


Expand Down
41 changes: 24 additions & 17 deletions mlflow/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,8 @@ def _validate_static_prefix(ctx, param, value):
),
)
@click.option(
"--enable-workspaces",
"--enable-workspaces/--disable-workspaces",
envvar=MLFLOW_ENABLE_WORKSPACES.name,
is_flag=True,
default=False,
show_default=True,
help="Enable backwards compatible workspace-aware multi-tenancy mode.",
Expand Down Expand Up @@ -552,16 +551,17 @@ def server(
disable_security_middleware=disable_security_middleware,
)

if enable_workspaces:
os.environ[MLFLOW_ENABLE_WORKSPACES.name] = "true"
if workspace_store_uri:
os.environ[MLFLOW_WORKSPACE_STORE_URI.name] = workspace_store_uri
# Keep environment flag in sync with the resolved boolean so server-side gating
# (which reads MLFLOW_ENABLE_WORKSPACES.get()) has a single source of truth.
os.environ[MLFLOW_ENABLE_WORKSPACES.name] = "true" if enable_workspaces else "false"
if enable_workspaces and workspace_store_uri:
os.environ[MLFLOW_WORKSPACE_STORE_URI.name] = workspace_store_uri
elif workspace_store_uri:
click.echo(
"Ignoring --workspace-store-uri because workspaces are not enabled. "
"Use --enable-workspaces to activate workspace mode.",
err=True,
warning_msg = (
"--workspace-store-uri was provided but workspaces are not enabled. "
"Workspace APIs will remain disabled unless you pass --enable-workspaces."
)
_logger.warning(warning_msg)

if disable_security_middleware:
os.environ["MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE"] = "true"
Expand All @@ -586,6 +586,7 @@ def server(
if x_frame_options:
os.environ["MLFLOW_SERVER_X_FRAME_OPTIONS"] = x_frame_options

# Ensure that both backend_store_uri and default_artifact_uri are set correctly.
if not backend_store_uri:
backend_store_uri = _get_default_tracking_uri()
click.echo(f"Backend store URI not provided. Using {backend_store_uri}")
Expand All @@ -597,14 +598,20 @@ def server(
default_artifact_root = resolve_default_artifact_root(
serve_artifacts, default_artifact_root, backend_store_uri
)
artifacts_only_config_validation(artifacts_only, backend_store_uri)
artifacts_only_config_validation(artifacts_only, backend_store_uri, enable_workspaces)

try:
initialize_backend_stores(backend_store_uri, registry_store_uri, default_artifact_root)
except Exception as e:
_logger.error("Error initializing backend store")
_logger.exception(e)
sys.exit(1)
if not artifacts_only:
try:
initialize_backend_stores(
backend_store_uri,
registry_store_uri,
default_artifact_root,
workspace_store_uri=workspace_store_uri,
)
except Exception as e:
_logger.error("Error initializing backend store")
_logger.exception(e)
sys.exit(1)

if disable_security_middleware:
click.echo(
Expand Down
7 changes: 7 additions & 0 deletions mlflow/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
get_trace_artifact_handler,
upload_artifact_handler,
)
from mlflow.server.workspace_helpers import (
workspace_before_request_handler,
workspace_teardown_request_handler,
)
from mlflow.utils.os import is_windows
from mlflow.utils.plugins import get_entry_points
from mlflow.utils.process import _exec_cmd
Expand Down Expand Up @@ -70,6 +74,9 @@

security.init_security_middleware(app)

app.before_request(workspace_before_request_handler)
app.teardown_request(workspace_teardown_request_handler)

for http_path, handler, methods in handlers.get_endpoints():
app.add_url_rule(http_path, handler.__name__, handler, methods=methods)

Expand Down
55 changes: 54 additions & 1 deletion mlflow/server/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,67 @@
to FastAPI endpoints.
"""

from fastapi import FastAPI
import json

from fastapi import FastAPI, Request
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.responses import JSONResponse
from flask import Flask

from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.server import app as flask_app
from mlflow.server.fastapi_security import init_fastapi_security
from mlflow.server.job_api import job_api_router
from mlflow.server.otel_api import otel_router
from mlflow.server.workspace_helpers import WORKSPACE_HEADER_NAME, resolve_workspace_from_header
from mlflow.tracing.utils.otlp import OTLP_TRACES_PATH
from mlflow.utils.workspace_context import set_current_workspace
from mlflow.version import VERSION

# FastAPI routes that do not go through the Flask WSGI bridge (currently jobs + OTLP).
FASTAPI_NATIVE_PREFIXES = (job_api_router.prefix, OTLP_TRACES_PATH)


def add_fastapi_workspace_middleware(fastapi_app: FastAPI) -> None:
if getattr(fastapi_app.state, "workspace_middleware_added", False):
return

@fastapi_app.middleware("http")
async def workspace_context_middleware(request: Request, call_next):
if not MLFLOW_ENABLE_WORKSPACES.get():
return await call_next(request)

path = request.url.path
if not any(path.startswith(prefix) for prefix in FASTAPI_NATIVE_PREFIXES):
# Skip if it's a Flask route and let the Flask before request handler handle it.
return await call_next(request)

try:
workspace = resolve_workspace_from_header(request.headers.get(WORKSPACE_HEADER_NAME))
if workspace is None:
raise MlflowException(
f"Active workspace is required. Set the '{WORKSPACE_HEADER_NAME}' request "
"header, call mlflow.set_workspace(), or set the MLFLOW_WORKSPACE environment "
"variable before making requests.",
error_code=databricks_pb2.INVALID_PARAMETER_VALUE,
)
except MlflowException as e:
return JSONResponse(
status_code=e.get_http_status_code(),
content=json.loads(e.serialize_as_json()),
)

token = set_current_workspace(workspace.name)
try:
response = await call_next(request)
finally:
set_current_workspace(token)
return response

fastapi_app.state.workspace_middleware_added = True


def create_fastapi_app(flask_app: Flask = flask_app):
"""
Expand All @@ -39,6 +90,8 @@ def create_fastapi_app(flask_app: Flask = flask_app):
# Initialize security middleware BEFORE adding routes
init_fastapi_security(fastapi_app)

add_fastapi_workspace_middleware(fastapi_app)

# Include OpenTelemetry API router BEFORE mounting Flask app
# This ensures FastAPI routes take precedence over the catch-all Flask mount
fastapi_app.include_router(otel_router)
Expand Down
Loading
Loading