Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api_reference/api_inventory.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mlflow.client.MlflowClient.create_prompt_version
mlflow.client.MlflowClient.create_registered_model
mlflow.client.MlflowClient.create_run
mlflow.client.MlflowClient.create_webhook
mlflow.client.MlflowClient.create_workspace
mlflow.client.MlflowClient.delete_dataset
mlflow.client.MlflowClient.delete_dataset_tag
mlflow.client.MlflowClient.delete_experiment
Expand All @@ -61,6 +62,7 @@ mlflow.client.MlflowClient.delete_tag
mlflow.client.MlflowClient.delete_trace_tag
mlflow.client.MlflowClient.delete_traces
mlflow.client.MlflowClient.delete_webhook
mlflow.client.MlflowClient.delete_workspace
mlflow.client.MlflowClient.detach_prompt_from_run
mlflow.client.MlflowClient.download_artifacts
mlflow.client.MlflowClient.end_span
Expand All @@ -84,6 +86,8 @@ mlflow.client.MlflowClient.get_registered_model
mlflow.client.MlflowClient.get_run
mlflow.client.MlflowClient.get_trace
mlflow.client.MlflowClient.get_webhook
mlflow.client.MlflowClient.get_workspace
mlflow.client.MlflowClient.get_workspace_uri
mlflow.client.MlflowClient.link_prompt_version_to_model
mlflow.client.MlflowClient.link_prompt_version_to_run
mlflow.client.MlflowClient.link_prompt_versions_to_trace
Expand All @@ -92,6 +96,7 @@ mlflow.client.MlflowClient.list_artifacts
mlflow.client.MlflowClient.list_logged_model_artifacts
mlflow.client.MlflowClient.list_logged_prompts
mlflow.client.MlflowClient.list_webhooks
mlflow.client.MlflowClient.list_workspaces
mlflow.client.MlflowClient.load_prompt
mlflow.client.MlflowClient.load_table
mlflow.client.MlflowClient.log_artifact
Expand Down Expand Up @@ -146,6 +151,7 @@ mlflow.client.MlflowClient.update_model_version
mlflow.client.MlflowClient.update_registered_model
mlflow.client.MlflowClient.update_run
mlflow.client.MlflowClient.update_webhook
mlflow.client.MlflowClient.update_workspace
mlflow.config.disable_system_metrics_logging
mlflow.config.enable_async_logging
mlflow.config.enable_system_metrics_logging
Expand Down Expand Up @@ -529,6 +535,7 @@ mlflow.entities.WebhookTestResult
mlflow.entities.WebhookTestResult.from_proto
mlflow.entities.WebhookTestResult.to_proto
mlflow.entities.Workspace
mlflow.entities.Workspace.to_dict
mlflow.entities.assessment.Assessment
mlflow.entities.assessment.Expectation
mlflow.entities.assessment.Feedback
Expand Down
38 changes: 37 additions & 1 deletion mlflow/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from mlflow import ai_commands, projects, version
from mlflow.entities import ViewType
from mlflow.entities.lifecycle_stage import LifecycleStage
from mlflow.environment_variables import MLFLOW_EXPERIMENT_ID, MLFLOW_EXPERIMENT_NAME
from mlflow.environment_variables import (
MLFLOW_ENABLE_WORKSPACES,
MLFLOW_EXPERIMENT_ID,
MLFLOW_EXPERIMENT_NAME,
MLFLOW_WORKSPACE_URI,
)
from mlflow.exceptions import InvalidUrlException, MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
Expand Down Expand Up @@ -458,6 +463,24 @@ def _validate_static_prefix(ctx, param, value):
"Unsupported on Windows."
),
)
@click.option(
"--workspace-store-uri",
envvar=MLFLOW_WORKSPACE_URI.name,
metavar="URI",
default=None,
help=(
"Workspace provider backend URI used for workspace CRUD APIs and request routing. "
"When unspecified, defaults to the backend store URI."
),
)
@click.option(
"--enable-workspaces",
envvar=MLFLOW_ENABLE_WORKSPACES.name,
is_flag=True,
default=False,
show_default=True,
help="Enable backwards compatible workspace-aware multi-tenancy mode.",
)
def server(
ctx,
backend_store_uri,
Expand All @@ -480,6 +503,8 @@ def server(
app_name,
dev,
uvicorn_opts,
workspace_store_uri,
enable_workspaces,
):
"""
Run the MLflow tracking server with built-in security middleware.
Expand Down Expand Up @@ -525,6 +550,17 @@ def server(
disable_security_middleware=disable_security_middleware,
)

if enable_workspaces:
os.environ[MLFLOW_ENABLE_WORKSPACES.name] = "true"
if workspace_store_uri:
os.environ[MLFLOW_WORKSPACE_URI.name] = workspace_store_uri
elif workspace_store_uri:
click.echo(
"Ignoring --workspace-store-uri because workspaces are not enabled. "
"Use --enable-workspaces to activate workspace mode.",
err=True,
)

if disable_security_middleware:
os.environ["MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE"] = "true"
else:
Expand Down
3 changes: 3 additions & 0 deletions mlflow/entities/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ class Workspace:

name: str
description: str | None = None

def to_dict(self) -> dict[str, str | None]:
return {"name": self.name, "description": self.description}
8 changes: 8 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def get(self):
#: (default: ``None``)
MLFLOW_REGISTRY_URI = _EnvironmentVariable("MLFLOW_REGISTRY_URI", str, None)

#: Specifies the workspace provider backend URI.
#: Defaults to the tracking URI when unset.
MLFLOW_WORKSPACE_URI = _EnvironmentVariable("MLFLOW_WORKSPACE_URI", str, None)

#: Specifies the ``dfs_tmpdir`` parameter to use for ``mlflow.spark.save_model``,
#: ``mlflow.spark.log_model`` and ``mlflow.spark.load_model``. See
#: https://www.mlflow.org/docs/latest/python_api/mlflow.spark.html#mlflow.spark.save_model
Expand Down Expand Up @@ -435,6 +439,10 @@ def get(self):
"MLFLOW_ENABLE_DBFS_FUSE_ARTIFACT_REPO", True
)

#: Enables workspace-aware multi-tenancy features on the MLflow server.
#: (default: ``False``)
MLFLOW_ENABLE_WORKSPACES = _BooleanEnvironmentVariable("MLFLOW_ENABLE_WORKSPACES", False)

#: Specifies whether or not to use UC Volume FUSE mount to store artifacts on Databricks
#: (default: ``True``)
MLFLOW_ENABLE_UC_VOLUME_FUSE_ARTIFACT_REPO = _BooleanEnvironmentVariable(
Expand Down
113 changes: 113 additions & 0 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Param,
RunTag,
ViewType,
Workspace,
)
from mlflow.entities.logged_model import LoggedModel
from mlflow.entities.logged_model_input import LoggedModelInput
Expand All @@ -45,6 +46,7 @@
from mlflow.environment_variables import (
MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX,
MLFLOW_DEPLOYMENTS_TARGET,
MLFLOW_ENABLE_WORKSPACES,
)
from mlflow.exceptions import (
MlflowException,
Expand All @@ -57,6 +59,7 @@
from mlflow.protos import databricks_pb2
from mlflow.protos.databricks_pb2 import (
BAD_REQUEST,
FEATURE_DISABLED,
INVALID_PARAMETER_VALUE,
RESOURCE_DOES_NOT_EXIST,
)
Expand Down Expand Up @@ -183,6 +186,7 @@
WebhookService,
)
from mlflow.server.validation import _validate_content_type
from mlflow.server.workspace_helpers import _get_workspace_store
from mlflow.store.artifact.artifact_repo import MultipartUploadMixin
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.db.db_types import DATABASE_ENGINES
Expand All @@ -191,6 +195,7 @@
from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore
from mlflow.store.tracking.abstract_store import AbstractStore as AbstractTrackingStore
from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
from mlflow.store.workspace.abstract_store import WorkspaceNameValidator
from mlflow.tracing.utils.artifact_utils import (
TRACE_DATA_FILE_NAME,
get_artifact_uri_for_trace,
Expand Down Expand Up @@ -779,6 +784,113 @@ def wrapper(*args, **kwargs):
return wrapper


def _disable_if_workspaces_disabled(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not MLFLOW_ENABLE_WORKSPACES.get():
return Response(
(
f"Endpoint: {request.url_rule} disabled because the server is running "
"without multi-tenancy support. To enable workspace functionality, run "
"`mlflow server` with `--enable-workspaces`"
),
503,
)
return func(*args, **kwargs)

return wrapper


def _workspace_not_supported(message: str) -> MlflowException:
return MlflowException(message, FEATURE_DISABLED)


@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _list_workspaces_handler():
workspaces = _get_workspace_store().list_workspaces()
return jsonify({"workspaces": [ws.to_dict() for ws in workspaces]})


@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _create_workspace_handler():
payload = _get_request_json(request) or {}
name = payload.get("name")
if not name:
raise MlflowException.invalid_parameter_value("Workspace name must be provided")

WorkspaceNameValidator.validate(name)

description = payload.get("description")
store = _get_workspace_store()
try:
workspace = store.create_workspace(Workspace(name=name, description=description))
except NotImplementedError:
raise _workspace_not_supported("Workspace creation is not supported by this provider")

response = jsonify(workspace.to_dict())
response.status_code = 201
return response


@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _get_workspace_handler(workspace_name: str):
workspace = _get_workspace_store().get_workspace(workspace_name)
return jsonify(workspace.to_dict())


@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _update_workspace_handler(workspace_name: str):
payload = _get_request_json(request) or {}

if not payload.keys():
raise MlflowException.invalid_parameter_value("Workspace update must have at least one key")

invalid_keys = payload.keys() - {"description"}
if invalid_keys:
raise MlflowException.invalid_parameter_value(
f"Workspace update had the following invalid keys: {', '.join(invalid_keys)}"
)

store = _get_workspace_store()
try:
workspace = store.update_workspace(
Workspace(name=workspace_name, description=payload["description"])
)
except NotImplementedError:
raise _workspace_not_supported("Workspace updates are not supported by this provider")

return jsonify(workspace.to_dict())


@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _delete_workspace_handler(workspace_name: str):
store = _get_workspace_store()
try:
store.delete_workspace(workspace_name)
except NotImplementedError:
raise _workspace_not_supported("Workspace deletion is not supported by this provider")
return Response(status=204)


def _workspace_endpoints():
endpoints = []
for path in _get_paths("/mlflow/workspaces"):
endpoints.append((path, _list_workspaces_handler, ["GET"]))
endpoints.append((path, _create_workspace_handler, ["POST"]))

for path in _get_paths("/mlflow/workspaces/<workspace_name>"):
endpoints.append((path, _get_workspace_handler, ["GET"]))
endpoints.append((path, _update_workspace_handler, ["PATCH"]))
endpoints.append((path, _delete_workspace_handler, ["DELETE"]))

return endpoints


@catch_mlflow_exception
def get_artifact_handler():
run_id = request.args.get("run_id") or request.args.get("run_uuid")
Expand Down Expand Up @@ -3830,6 +3942,7 @@ def get_endpoints(get_handler=get_handler):
+ get_service_endpoints(MlflowArtifactsService, get_handler)
+ get_service_endpoints(WebhookService, get_handler)
+ [(_add_static_prefix("/graphql"), _graphql, ["GET", "POST"])]
+ _workspace_endpoints()
)


Expand Down
51 changes: 51 additions & 0 deletions mlflow/server/workspace_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import logging
import os

from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES, MLFLOW_WORKSPACE_URI
from mlflow.exceptions import MlflowException
from mlflow.protos import databricks_pb2
from mlflow.tracking._workspace import utils as workspace_utils
from mlflow.tracking._workspace.registry import get_workspace_store

_logger = logging.getLogger(__name__)

_workspace_store = None


def _get_workspace_store(workspace_uri: str | None = None, tracking_uri: str | None = None):
"""
Resolve and cache the workspace store configured for this server process.

The store is constructed on first invocation using the provided arguments (or their
environment-derived defaults) and memoized for all subsequent calls, regardless of any new
``workspace_uri`` / ``tracking_uri`` values supplied later.
"""
if not MLFLOW_ENABLE_WORKSPACES.get():
raise MlflowException(
"Workspace APIs are not available: multi-tenancy is not enabled on this server",
databricks_pb2.FEATURE_DISABLED,
)

global _workspace_store
if _workspace_store is not None:
return _workspace_store

from mlflow.server import BACKEND_STORE_URI_ENV_VAR

resolved_tracking_uri = tracking_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR)
resolved_workspace_uri = workspace_utils.resolve_workspace_uri(
workspace_uri, tracking_uri=resolved_tracking_uri
)
if resolved_workspace_uri is None:
raise MlflowException.invalid_parameter_value(
"Workspace URI could not be resolved. Provide --workspace-store-uri or set "
f"{MLFLOW_WORKSPACE_URI.name}."
)

_workspace_store = get_workspace_store(workspace_uri=resolved_workspace_uri)
return _workspace_store


__all__ = ["_get_workspace_store"]
2 changes: 2 additions & 0 deletions mlflow/store/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
SqlTraceMetadata,
SqlTraceTag,
)
from mlflow.store.workspace.dbmodels.models import SqlWorkspace

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -105,6 +106,7 @@ def _all_tables_exist(engine):
SqlScorer.__tablename__,
SqlScorerVersion.__tablename__,
SqlJob.__tablename__,
SqlWorkspace.__tablename__,
}


Expand Down
11 changes: 11 additions & 0 deletions mlflow/store/workspace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Public workspace store facade and re-exports."""

from mlflow.entities.workspace import Workspace
from mlflow.store.workspace.abstract_store import AbstractStore
from mlflow.store.workspace.rest_store import RestWorkspaceStore

__all__ = [
"Workspace",
"AbstractStore",
"RestWorkspaceStore",
]
Loading
Loading