Skip to content

Commit 1814662

Browse files
committed
Add workspace CRUD support
1 parent 1f0b624 commit 1814662

20 files changed

Lines changed: 1183 additions & 7 deletions

mlflow/environment_variables.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def get(self):
106106
#: (default: ``None``)
107107
MLFLOW_REGISTRY_URI = _EnvironmentVariable("MLFLOW_REGISTRY_URI", str, None)
108108

109+
#: Specifies the workspace provider backend URI.
110+
#: Defaults to the tracking URI when unset.
111+
MLFLOW_WORKSPACE_URI = _EnvironmentVariable("MLFLOW_WORKSPACE_URI", str, None)
112+
109113
#: Specifies the ``dfs_tmpdir`` parameter to use for ``mlflow.spark.save_model``,
110114
#: ``mlflow.spark.log_model`` and ``mlflow.spark.load_model``. See
111115
#: https://www.mlflow.org/docs/latest/python_api/mlflow.spark.html#mlflow.spark.save_model
@@ -435,6 +439,10 @@ def get(self):
435439
"MLFLOW_ENABLE_DBFS_FUSE_ARTIFACT_REPO", True
436440
)
437441

442+
#: Enables workspace-aware multi-tenancy features on the MLflow server.
443+
#: (default: ``False``)
444+
MLFLOW_ENABLE_WORKSPACES = _BooleanEnvironmentVariable("MLFLOW_ENABLE_WORKSPACES", False)
445+
438446
#: Specifies whether or not to use UC Volume FUSE mount to store artifacts on Databricks
439447
#: (default: ``True``)
440448
MLFLOW_ENABLE_UC_VOLUME_FUSE_ARTIFACT_REPO = _BooleanEnvironmentVariable(

mlflow/server/handlers.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Param,
2929
RunTag,
3030
ViewType,
31+
Workspace,
3132
)
3233
from mlflow.entities.logged_model import LoggedModel
3334
from mlflow.entities.logged_model_input import LoggedModelInput
@@ -57,7 +58,9 @@
5758
from mlflow.protos import databricks_pb2
5859
from mlflow.protos.databricks_pb2 import (
5960
BAD_REQUEST,
61+
FEATURE_DISABLED,
6062
INVALID_PARAMETER_VALUE,
63+
INVALID_STATE,
6164
RESOURCE_DOES_NOT_EXIST,
6265
)
6366
from mlflow.protos.mlflow_artifacts_pb2 import (
@@ -183,6 +186,7 @@
183186
WebhookService,
184187
)
185188
from mlflow.server.validation import _validate_content_type
189+
from mlflow.server.workspace_helpers import _get_workspace_store, _workspaces_enabled_flag
186190
from mlflow.store.artifact.artifact_repo import MultipartUploadMixin
187191
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
188192
from mlflow.store.db.db_types import DATABASE_ENGINES
@@ -191,6 +195,7 @@
191195
from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore
192196
from mlflow.store.tracking.abstract_store import AbstractStore as AbstractTrackingStore
193197
from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
198+
from mlflow.store.workspace.abstract_store import WorkspaceNameValidator
194199
from mlflow.tracing.utils.artifact_utils import (
195200
TRACE_DATA_FILE_NAME,
196201
get_artifact_uri_for_trace,
@@ -779,6 +784,122 @@ def wrapper(*args, **kwargs):
779784
return wrapper
780785

781786

787+
def _disable_if_workspaces_disabled(func):
788+
@wraps(func)
789+
def wrapper(*args, **kwargs):
790+
if not _workspaces_enabled_flag():
791+
return Response(
792+
(
793+
f"Endpoint: {request.url_rule} disabled because the server is running "
794+
"without multi-tenancy support. To enable workspace functionality, run "
795+
"`mlflow server` with `--enable-workspaces`"
796+
),
797+
databricks_pb2.FEATURE_DISABLED,
798+
)
799+
return func(*args, **kwargs)
800+
801+
return wrapper
802+
803+
804+
def _workspace_to_response_payload(workspace: Workspace) -> dict[str, str | None]:
805+
return {"name": workspace.name, "description": workspace.description}
806+
807+
808+
def _workspace_response(workspace: Workspace, status: int = 200) -> Response:
809+
response = jsonify(_workspace_to_response_payload(workspace))
810+
response.status_code = status
811+
return response
812+
813+
814+
815+
def _workspace_not_supported(message: str) -> MlflowException:
816+
return MlflowException(message, FEATURE_DISABLED)
817+
818+
819+
@catch_mlflow_exception
820+
@_disable_if_workspaces_disabled
821+
def _list_workspaces_handler():
822+
workspaces = _get_workspace_store().list_workspaces(request)
823+
return jsonify({"workspaces": [_workspace_to_response_payload(ws) for ws in workspaces]})
824+
825+
826+
@catch_mlflow_exception
827+
@_disable_if_workspaces_disabled
828+
def _create_workspace_handler():
829+
payload = _get_request_json(request) or {}
830+
name = payload.get("name")
831+
if not name:
832+
raise MlflowException.invalid_parameter_value("Workspace name must be provided")
833+
834+
WorkspaceNameValidator.validate(name)
835+
836+
description = payload.get("description")
837+
store = _get_workspace_store()
838+
try:
839+
workspace = store.create_workspace(Workspace(name=name, description=description), request)
840+
except NotImplementedError:
841+
raise _workspace_not_supported("Workspace creation is not supported by this provider")
842+
843+
return _workspace_response(workspace, status=201)
844+
845+
846+
@catch_mlflow_exception
847+
@_disable_if_workspaces_disabled
848+
def _get_workspace_handler(workspace_name: str):
849+
workspace = _get_workspace_store().get_workspace(workspace_name, request)
850+
return _workspace_response(workspace)
851+
852+
853+
@catch_mlflow_exception
854+
@_disable_if_workspaces_disabled
855+
def _update_workspace_handler(workspace_name: str):
856+
payload = _get_request_json(request) or {}
857+
858+
if not payload.keys():
859+
raise MlflowException.invalid_parameter_value("Workspace update must have at least one key")
860+
861+
invalid_keys = payload.keys() - {"description"}
862+
if invalid_keys:
863+
raise MlflowException.invalid_parameter_value(
864+
f"Workspace update had the following invalid keys: {', '.join(invalid_keys)}"
865+
)
866+
867+
store = _get_workspace_store()
868+
try:
869+
workspace = store.update_workspace(
870+
Workspace(name=workspace_name, description=payload["description"]), request
871+
)
872+
except NotImplementedError:
873+
raise _workspace_not_supported("Workspace updates are not supported by this provider")
874+
875+
return _workspace_response(workspace)
876+
877+
878+
@catch_mlflow_exception
879+
@_disable_if_workspaces_disabled
880+
def _delete_workspace_handler(workspace_name: str):
881+
store = _get_workspace_store()
882+
try:
883+
store.delete_workspace(workspace_name, request)
884+
except NotImplementedError:
885+
raise _workspace_not_supported("Workspace deletion is not supported by this provider")
886+
return Response(status=204)
887+
888+
889+
def _workspace_endpoints():
890+
endpoints = []
891+
for path in _get_paths("/mlflow/workspaces"):
892+
endpoints.append((path, _list_workspaces_handler, ["GET"]))
893+
endpoints.append((path, _create_workspace_handler, ["POST"]))
894+
895+
for path in _get_paths("/mlflow/workspaces/<workspace_name>"):
896+
endpoints.append((path, _get_workspace_handler, ["GET"]))
897+
endpoints.append((path, _update_workspace_handler, ["PATCH"]))
898+
endpoints.append((path, _delete_workspace_handler, ["DELETE"]))
899+
900+
return endpoints
901+
902+
782903
@catch_mlflow_exception
783904
def get_artifact_handler():
784905
run_id = request.args.get("run_id") or request.args.get("run_uuid")
@@ -3830,6 +3951,7 @@ def get_endpoints(get_handler=get_handler):
38303951
+ get_service_endpoints(MlflowArtifactsService, get_handler)
38313952
+ get_service_endpoints(WebhookService, get_handler)
38323953
+ [(_add_static_prefix("/graphql"), _graphql, ["GET", "POST"])]
3954+
+ _workspace_endpoints()
38333955
)
38343956

38353957

mlflow/server/workspace_helpers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import logging
2+
import os
3+
4+
from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES, MLFLOW_WORKSPACE_URI
5+
from mlflow.exceptions import MlflowException
6+
from mlflow.protos import databricks_pb2
7+
from mlflow.tracking._workspace import utils as workspace_utils
8+
from mlflow.tracking._workspace.registry import get_workspace_store
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
_workspace_store = None
13+
14+
15+
def _workspaces_enabled_flag() -> bool:
16+
return bool(MLFLOW_ENABLE_WORKSPACES.get())
17+
18+
19+
def _get_workspace_store(workspace_uri: str | None = None, tracking_uri: str | None = None):
20+
if not _workspaces_enabled_flag():
21+
raise MlflowException(
22+
"Workspace APIs are not available: multi-tenancy is not enabled on this server",
23+
databricks_pb2.FEATURE_DISABLED,
24+
)
25+
26+
global _workspace_store
27+
if _workspace_store is not None:
28+
return _workspace_store
29+
30+
from mlflow.server import BACKEND_STORE_URI_ENV_VAR
31+
32+
resolved_tracking_uri = tracking_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR)
33+
resolved_workspace_uri = workspace_utils.resolve_workspace_uri(
34+
workspace_uri, tracking_uri=resolved_tracking_uri
35+
)
36+
if resolved_workspace_uri is None:
37+
raise MlflowException.invalid_parameter_value(
38+
"Workspace URI could not be resolved. Provide --workspace-store-uri or set "
39+
f"{MLFLOW_WORKSPACE_URI.name}."
40+
)
41+
42+
_workspace_store = get_workspace_store(workspace_uri=resolved_workspace_uri)
43+
return _workspace_store
44+
45+
46+
__all__ = ["_get_workspace_store", "_workspaces_enabled_flag"]
47+

mlflow/store/db/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
SqlRegisteredModelAlias,
4747
SqlRegisteredModelTag,
4848
)
49+
from mlflow.store.workspace.dbmodels.models import SqlWorkspace
4950
from mlflow.store.tracking.dbmodels.initial_models import Base as InitialBase
5051
from mlflow.store.tracking.dbmodels.models import (
5152
SqlDataset,
@@ -105,6 +106,7 @@ def _all_tables_exist(engine):
105106
SqlScorer.__tablename__,
106107
SqlScorerVersion.__tablename__,
107108
SqlJob.__tablename__,
109+
SqlWorkspace.__tablename__,
108110
}
109111

110112

mlflow/store/workspace/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Public workspace store facade and re-exports."""
2+
3+
from mlflow.entities.workspace import Workspace
4+
from mlflow.store.workspace.abstract_store import AbstractStore
5+
from mlflow.store.workspace.rest_store import RestWorkspaceStore
6+
7+
__all__ = [
8+
"Workspace",
9+
"AbstractStore",
10+
"RestWorkspaceStore",
11+
]
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from abc import ABC, abstractmethod
5+
from typing import Iterable
6+
7+
from mlflow.entities import Workspace
8+
from mlflow.exceptions import MlflowException
9+
from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
10+
11+
12+
class AbstractStore(ABC):
13+
"""Interface for resolving and managing workspaces in the tracking server."""
14+
15+
@abstractmethod
16+
def list_workspaces(self, request) -> Iterable[Workspace]:
17+
"""
18+
Return the workspaces visible to the current request context.
19+
20+
Implementations may inspect the request (e.g., for authN/Z context) to
21+
determine which workspaces to expose.
22+
"""
23+
24+
@abstractmethod
25+
def get_workspace(self, workspace_name: str, request) -> Workspace:
26+
"""
27+
Validate access to ``workspace_name`` and return its metadata.
28+
29+
Implementations should raise ``MlflowException`` with
30+
``RESOURCE_DOES_NOT_EXIST`` if the workspace cannot be found and
31+
``PERMISSION_DENIED`` if the caller lacks access.
32+
"""
33+
34+
def create_workspace(self, workspace: Workspace, request) -> Workspace:
35+
"""Provision a new workspace.
36+
37+
Raises ``NotImplementedError`` when the active provider is read-only.
38+
Implementations should raise ``MlflowException`` with
39+
``RESOURCE_ALREADY_EXISTS`` when the workspace already exists or
40+
``INVALID_PARAMETER_VALUE`` when validation fails.
41+
"""
42+
43+
raise NotImplementedError
44+
45+
def update_workspace(self, workspace: Workspace, request) -> Workspace:
46+
"""Update metadata for an existing workspace."""
47+
48+
raise NotImplementedError
49+
50+
def delete_workspace(self, workspace_name: str, request) -> None:
51+
"""Delete an existing workspace."""
52+
53+
raise NotImplementedError
54+
55+
def get_default_workspace(self, request) -> Workspace:
56+
"""
57+
Return the workspace to select when none is explicitly supplied.
58+
59+
Implementations that require an explicit workspace should raise an
60+
``MlflowException`` with ``INVALID_PARAMETER_VALUE``.
61+
"""
62+
63+
raise NotImplementedError
64+
65+
def resolve_artifact_root(
66+
self, default_artifact_root: str, workspace_name: str | None = None
67+
) -> tuple[str, bool]:
68+
"""
69+
Allow a provider to customize artifact storage roots per workspace.
70+
71+
Returns:
72+
A tuple ``(root, append_workspace_prefix)`` where ``root`` is the base artifact
73+
location to use for the workspace, and ``append_workspace_prefix`` controls whether
74+
MLflow should append the ``/workspaces/<workspace_name>`` suffix automatically.
75+
"""
76+
77+
return default_artifact_root, True
78+
79+
80+
class WorkspaceNameValidator:
81+
_PATTERN = r"^[a-z0-9][-a-z0-9]*[a-z0-9]$"
82+
_MIN_LENGTH = 2
83+
_MAX_LENGTH = 63
84+
_RESERVED = {DEFAULT_WORKSPACE_NAME, "workspaces", "api", "ajax-api", "static-files"}
85+
86+
@classmethod
87+
def pattern(cls) -> str:
88+
return cls._PATTERN
89+
90+
@classmethod
91+
def is_valid(cls, name: str) -> bool:
92+
if not isinstance(name, str):
93+
return False
94+
95+
if not (cls._MIN_LENGTH <= len(name) <= cls._MAX_LENGTH):
96+
return False
97+
98+
if not re.match(cls._PATTERN, name):
99+
return False
100+
101+
if name in cls._RESERVED:
102+
return False
103+
104+
return True
105+
106+
@classmethod
107+
def validate(cls, name: str) -> None:
108+
if not isinstance(name, str):
109+
raise MlflowException.invalid_parameter_value(
110+
f"Workspace name must be a string, got {type(name).__name__!s}."
111+
)
112+
113+
if not (cls._MIN_LENGTH <= len(name) <= cls._MAX_LENGTH):
114+
raise MlflowException.invalid_parameter_value(
115+
f"Workspace name '{name}' must be between {cls._MIN_LENGTH} and "
116+
f"{cls._MAX_LENGTH} characters."
117+
)
118+
119+
if not re.match(cls._PATTERN, name):
120+
raise MlflowException.invalid_parameter_value(
121+
f"Workspace name '{name}' must match the pattern {cls.pattern()} "
122+
"(lowercase alphanumeric with optional internal hyphens)."
123+
)
124+
125+
if name in cls._RESERVED:
126+
raise MlflowException.invalid_parameter_value(
127+
f"Workspace name '{name}' is reserved and cannot be used."
128+
)
129+

0 commit comments

Comments
 (0)