|
28 | 28 | Param, |
29 | 29 | RunTag, |
30 | 30 | ViewType, |
| 31 | + Workspace, |
31 | 32 | ) |
32 | 33 | from mlflow.entities.logged_model import LoggedModel |
33 | 34 | from mlflow.entities.logged_model_input import LoggedModelInput |
|
57 | 58 | from mlflow.protos import databricks_pb2 |
58 | 59 | from mlflow.protos.databricks_pb2 import ( |
59 | 60 | BAD_REQUEST, |
| 61 | + FEATURE_DISABLED, |
60 | 62 | INVALID_PARAMETER_VALUE, |
| 63 | + INVALID_STATE, |
61 | 64 | RESOURCE_DOES_NOT_EXIST, |
62 | 65 | ) |
63 | 66 | from mlflow.protos.mlflow_artifacts_pb2 import ( |
|
183 | 186 | WebhookService, |
184 | 187 | ) |
185 | 188 | from mlflow.server.validation import _validate_content_type |
| 189 | +from mlflow.server.workspace_helpers import _get_workspace_store, _workspaces_enabled_flag |
186 | 190 | from mlflow.store.artifact.artifact_repo import MultipartUploadMixin |
187 | 191 | from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository |
188 | 192 | from mlflow.store.db.db_types import DATABASE_ENGINES |
|
191 | 195 | from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore |
192 | 196 | from mlflow.store.tracking.abstract_store import AbstractStore as AbstractTrackingStore |
193 | 197 | from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore |
| 198 | +from mlflow.store.workspace.abstract_store import WorkspaceNameValidator |
194 | 199 | from mlflow.tracing.utils.artifact_utils import ( |
195 | 200 | TRACE_DATA_FILE_NAME, |
196 | 201 | get_artifact_uri_for_trace, |
@@ -779,6 +784,122 @@ def wrapper(*args, **kwargs): |
779 | 784 | return wrapper |
780 | 785 |
|
781 | 786 |
|
| 787 | +def _disable_if_workspaces_disabled(func): |
| 788 | + @wraps(func) |
| 789 | + def wrapper(*args, **kwargs): |
| 790 | + if not _workspaces_enabled_flag(): |
| 791 | + return Response( |
| 792 | + ( |
| 793 | + f"Endpoint: {request.url_rule} disabled because the server is running " |
| 794 | + "without multi-tenancy support. To enable workspace functionality, run " |
| 795 | + "`mlflow server` with `--enable-workspaces`" |
| 796 | + ), |
| 797 | + databricks_pb2.FEATURE_DISABLED, |
| 798 | + ) |
| 799 | + return func(*args, **kwargs) |
| 800 | + |
| 801 | + return wrapper |
| 802 | + |
| 803 | + |
| 804 | +def _workspace_to_response_payload(workspace: Workspace) -> dict[str, str | None]: |
| 805 | + return {"name": workspace.name, "description": workspace.description} |
| 806 | + |
| 807 | + |
| 808 | +def _workspace_response(workspace: Workspace, status: int = 200) -> Response: |
| 809 | + response = jsonify(_workspace_to_response_payload(workspace)) |
| 810 | + response.status_code = status |
| 811 | + return response |
| 812 | + |
| 813 | + |
| 814 | + |
| 815 | +def _workspace_not_supported(message: str) -> MlflowException: |
| 816 | + return MlflowException(message, FEATURE_DISABLED) |
| 817 | + |
| 818 | + |
| 819 | +@catch_mlflow_exception |
| 820 | +@_disable_if_workspaces_disabled |
| 821 | +def _list_workspaces_handler(): |
| 822 | + workspaces = _get_workspace_store().list_workspaces(request) |
| 823 | + return jsonify({"workspaces": [_workspace_to_response_payload(ws) for ws in workspaces]}) |
| 824 | + |
| 825 | + |
| 826 | +@catch_mlflow_exception |
| 827 | +@_disable_if_workspaces_disabled |
| 828 | +def _create_workspace_handler(): |
| 829 | + payload = _get_request_json(request) or {} |
| 830 | + name = payload.get("name") |
| 831 | + if not name: |
| 832 | + raise MlflowException.invalid_parameter_value("Workspace name must be provided") |
| 833 | + |
| 834 | + WorkspaceNameValidator.validate(name) |
| 835 | + |
| 836 | + description = payload.get("description") |
| 837 | + store = _get_workspace_store() |
| 838 | + try: |
| 839 | + workspace = store.create_workspace(Workspace(name=name, description=description), request) |
| 840 | + except NotImplementedError: |
| 841 | + raise _workspace_not_supported("Workspace creation is not supported by this provider") |
| 842 | + |
| 843 | + return _workspace_response(workspace, status=201) |
| 844 | + |
| 845 | + |
| 846 | +@catch_mlflow_exception |
| 847 | +@_disable_if_workspaces_disabled |
| 848 | +def _get_workspace_handler(workspace_name: str): |
| 849 | + workspace = _get_workspace_store().get_workspace(workspace_name, request) |
| 850 | + return _workspace_response(workspace) |
| 851 | + |
| 852 | + |
| 853 | +@catch_mlflow_exception |
| 854 | +@_disable_if_workspaces_disabled |
| 855 | +def _update_workspace_handler(workspace_name: str): |
| 856 | + payload = _get_request_json(request) or {} |
| 857 | + |
| 858 | + if not payload.keys(): |
| 859 | + raise MlflowException.invalid_parameter_value("Workspace update must have at least one key") |
| 860 | + |
| 861 | + invalid_keys = payload.keys() - {"description"} |
| 862 | + if invalid_keys: |
| 863 | + raise MlflowException.invalid_parameter_value( |
| 864 | + f"Workspace update had the following invalid keys: {', '.join(invalid_keys)}" |
| 865 | + ) |
| 866 | + |
| 867 | + store = _get_workspace_store() |
| 868 | + try: |
| 869 | + workspace = store.update_workspace( |
| 870 | + Workspace(name=workspace_name, description=payload["description"]), request |
| 871 | + ) |
| 872 | + except NotImplementedError: |
| 873 | + raise _workspace_not_supported("Workspace updates are not supported by this provider") |
| 874 | + |
| 875 | + return _workspace_response(workspace) |
| 876 | + |
| 877 | + |
| 878 | +@catch_mlflow_exception |
| 879 | +@_disable_if_workspaces_disabled |
| 880 | +def _delete_workspace_handler(workspace_name: str): |
| 881 | + store = _get_workspace_store() |
| 882 | + try: |
| 883 | + store.delete_workspace(workspace_name, request) |
| 884 | + except NotImplementedError: |
| 885 | + raise _workspace_not_supported("Workspace deletion is not supported by this provider") |
| 886 | + return Response(status=204) |
| 887 | + |
| 888 | + |
| 889 | +def _workspace_endpoints(): |
| 890 | + endpoints = [] |
| 891 | + for path in _get_paths("/mlflow/workspaces"): |
| 892 | + endpoints.append((path, _list_workspaces_handler, ["GET"])) |
| 893 | + endpoints.append((path, _create_workspace_handler, ["POST"])) |
| 894 | + |
| 895 | + for path in _get_paths("/mlflow/workspaces/<workspace_name>"): |
| 896 | + endpoints.append((path, _get_workspace_handler, ["GET"])) |
| 897 | + endpoints.append((path, _update_workspace_handler, ["PATCH"])) |
| 898 | + endpoints.append((path, _delete_workspace_handler, ["DELETE"])) |
| 899 | + |
| 900 | + return endpoints |
| 901 | + |
| 902 | + |
782 | 903 | @catch_mlflow_exception |
783 | 904 | def get_artifact_handler(): |
784 | 905 | run_id = request.args.get("run_id") or request.args.get("run_uuid") |
@@ -3830,6 +3951,7 @@ def get_endpoints(get_handler=get_handler): |
3830 | 3951 | + get_service_endpoints(MlflowArtifactsService, get_handler) |
3831 | 3952 | + get_service_endpoints(WebhookService, get_handler) |
3832 | 3953 | + [(_add_static_prefix("/graphql"), _graphql, ["GET", "POST"])] |
| 3954 | + + _workspace_endpoints() |
3833 | 3955 | ) |
3834 | 3956 |
|
3835 | 3957 |
|
|
0 commit comments