From 1f0b6245073a4586cfd23bdf436a4489ac1a0b9f Mon Sep 17 00:00:00 2001 From: mprahl Date: Tue, 18 Nov 2025 11:41:02 -0500 Subject: [PATCH] Add workspace database schema This adds the required workspace columns and the workspace catalog table with the default workspace precreated. All workspace columns default to "default" for now and we may choose to remove the defaults once the tracking store and model registry store are made workspace aware to catch application logic issues not properly setting the workspace. Some model registry store changes were needed to account for the new composite foreign key. Signed-off-by: mprahl --- docs/api_reference/api_inventory.txt | 2 + mlflow/entities/__init__.py | 2 + mlflow/entities/workspace.py | 13 + ...ad7c1_add_workspace_columns_and_catalog.py | 732 +++++++++++ .../store/model_registry/dbmodels/models.py | 110 +- .../store/model_registry/sqlalchemy_store.py | 29 +- mlflow/store/tracking/dbmodels/models.py | 28 +- mlflow/store/workspace/__init__.py | 0 mlflow/store/workspace/dbmodels/__init__.py | 3 + mlflow/store/workspace/dbmodels/models.py | 22 + mlflow/utils/workspace_utils.py | 5 + tests/db/check_migration.py | 51 +- tests/db/schemas/mssql.sql | 36 +- tests/db/schemas/mysql.sql | 36 +- tests/db/schemas/postgresql.sql | 35 +- tests/db/schemas/sqlite.sql | 36 +- tests/db/test_schema.py | 77 +- tests/db/test_workspace_migration.py | 1073 +++++++++++++++++ tests/resources/db/latest_schema.sql | 36 +- .../tracking/test_sqlalchemy_store_schema.py | 4 + 20 files changed, 2241 insertions(+), 89 deletions(-) create mode 100644 mlflow/entities/workspace.py create mode 100644 mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py create mode 100644 mlflow/store/workspace/__init__.py create mode 100644 mlflow/store/workspace/dbmodels/__init__.py create mode 100644 mlflow/store/workspace/dbmodels/models.py create mode 100644 mlflow/utils/workspace_utils.py create mode 100644 tests/db/test_workspace_migration.py diff --git a/docs/api_reference/api_inventory.txt b/docs/api_reference/api_inventory.txt index 7f8ca89707d1a..0899539839a54 100644 --- a/docs/api_reference/api_inventory.txt +++ b/docs/api_reference/api_inventory.txt @@ -528,6 +528,7 @@ mlflow.entities.WebhookStatus.to_proto mlflow.entities.WebhookTestResult mlflow.entities.WebhookTestResult.from_proto mlflow.entities.WebhookTestResult.to_proto +mlflow.entities.Workspace mlflow.entities.assessment.Assessment mlflow.entities.assessment.Expectation mlflow.entities.assessment.Feedback @@ -627,6 +628,7 @@ mlflow.entities.webhook.Webhook mlflow.entities.webhook.WebhookEvent mlflow.entities.webhook.WebhookStatus mlflow.entities.webhook.WebhookTestResult +mlflow.entities.workspace.Workspace mlflow.evaluate mlflow.exceptions.get_error_code mlflow.finalize_logged_model diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index ea631158757df..5cc03b222265b 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -62,6 +62,7 @@ WebhookStatus, WebhookTestResult, ) +from mlflow.entities.workspace import Workspace __all__ = [ "Experiment", @@ -125,6 +126,7 @@ "WebhookEvent", "WebhookStatus", "WebhookTestResult", + "Workspace", ] diff --git a/mlflow/entities/workspace.py b/mlflow/entities/workspace.py new file mode 100644 index 0000000000000..cfd74daf09eaf --- /dev/null +++ b/mlflow/entities/workspace.py @@ -0,0 +1,13 @@ +"""Workspace entity shared between server and stores.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class Workspace: + """Minimal metadata describing a workspace.""" + + name: str + description: str | None = None diff --git a/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py b/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py new file mode 100644 index 0000000000000..9f9ae6218b546 --- /dev/null +++ b/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py @@ -0,0 +1,732 @@ +"""Add workspace columns and catalog table + +Create Date: 2025-11-18 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import inspect + +# revision identifiers, used by Alembic. +revision = "1b5f0d9ad7c1" +down_revision = "bf29a5ff90ea" +branch_labels = None +depends_on = None + +_NAMING_CONVENTION = { + "pk": "pk_%(table_name)s", + "fk": "fk_%(table_name)s_%(referred_table_name)s_%(column_0_name)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", +} + +_WORKSPACE_TABLES = [ + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + "webhooks", +] + +# Older SQLite migrations emitted unnamed foreign keys. When batch-altering tables we need the +# legacy names so we can drop the constraints deterministically; this mapping gives us the +# aliases for those historical definitions. +_SQLITE_LEGACY_FKS = { + ( + "model_versions", + "registered_models", + ("name",), + ): "fk_model_versions_registered_models_name", + ( + "registered_model_tags", + "registered_models", + ("name",), + ): "fk_registered_model_tags_registered_models_name", + ( + "model_version_tags", + "model_versions", + ("name", "version"), + ): "fk_model_version_tags_model_versions_name", + ( + "registered_model_aliases", + "registered_models", + ("name",), + ): "fk_registered_model_aliases_registered_models_name", +} + + +def _workspace_column(): + return sa.Column( + "workspace", + sa.String(length=63), + nullable=False, + server_default=sa.text("'default'"), + ) + + +def _fetch_mssql_unique_metadata( + conn, dialect_name: str, schema: str | None, table_name: str, for_indexes: bool = False +): + """Fetch unique constraint or index metadata for the given MSSQL table. + + SQLAlchemy's inspector doesn't implement ``get_unique_constraints`` for MSSQL, so we query the + catalog tables directly. When ``for_indexes`` is True we return unique indexes (needed when a + unique constraint is materialized as an index); otherwise we return metadata that looks like + ``get_unique_constraints``. + """ + if dialect_name != "mssql": + return [] + + if for_indexes: + query = sa.text( + """ + SELECT + i.name, + STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) AS column_names + FROM sys.indexes i + JOIN sys.tables t ON i.object_id = t.object_id + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.index_columns ic + ON i.object_id = ic.object_id + AND i.index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id + AND ic.column_id = c.column_id + WHERE i.is_unique = 1 + AND i.is_primary_key = 0 + AND ic.is_included_column = 0 + AND t.name = :table_name + AND (:schema IS NULL OR s.name = :schema) + GROUP BY i.name + """ + ) + else: + query = sa.text( + """ + SELECT + kc.name, + STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) AS column_names + FROM sys.key_constraints kc + JOIN sys.tables t ON kc.parent_object_id = t.object_id + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.index_columns ic + ON kc.parent_object_id = ic.object_id + AND kc.unique_index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id + AND ic.column_id = c.column_id + WHERE kc.type = 'UQ' + AND t.name = :table_name + AND (:schema IS NULL OR s.name = :schema) + GROUP BY kc.name + """ + ) + + result = conn.execute(query, {"table_name": table_name, "schema": schema}).fetchall() + return [ + { + "name": row[0], + "column_names": [col.strip() for col in row[1].split(",") if col] if row[1] else [], + } + for row in result + ] + + +def _with_batch(table_name): + return op.batch_alter_table(table_name, recreate="auto", naming_convention=_NAMING_CONVENTION) + + +# Once the tracking and model registry stores support workspaces, we can remove the +# server_default to ensure the stores are properly setting the workspace. +def upgrade(): + conn = op.get_bind() + inspector = inspect(conn) + dialect_name = conn.dialect.name + schema = op.get_context().version_table_schema or inspector.default_schema_name + + def _get_unique_constraints(table_name: str): + try: + return inspector.get_unique_constraints(table_name) + except NotImplementedError: + if dialect_name != "mssql": + raise + # SQL Server's inspector does not implement get_unique_constraints; fall back to + # querying the catalog tables directly via _fetch_mssql_unique_metadata. + return _fetch_mssql_unique_metadata( + conn, + dialect_name, + schema, + table_name, + for_indexes=False, + ) + + def _get_unique_indexes(table_name: str): + try: + return inspector.get_indexes(table_name) + except NotImplementedError: + if dialect_name != "mssql": + raise + + metadata = ( + _fetch_mssql_unique_metadata( + conn, + dialect_name, + schema, + table_name, + for_indexes=True, + ) + or [] + ) + + for entry in metadata: + entry["unique"] = True + return metadata + + def _detect_unique_on_name(table_name: str): + expected_name = _NAMING_CONVENTION["uq"] % { + "table_name": table_name, + "column_0_name": "name", + } + + for constraint in _get_unique_constraints(table_name) or []: + cols = constraint.get("column_names") or [] + name = constraint.get("name") + if cols == ["name"] or name == expected_name: + return name or expected_name, None + + for index in _get_unique_indexes(table_name) or []: + if index.get("unique") and index.get("column_names") == ["name"]: + return None, index["name"] + + return None, None + + def _collect_foreign_keys(table: str, referred_table: str): + names = [] + for fk in inspector.get_foreign_keys(table): + if fk.get("referred_table") == referred_table: + name = fk.get("name") + if not name and dialect_name == "sqlite": + key = (table, referred_table, tuple(fk.get("constrained_columns") or ())) + name = _SQLITE_LEGACY_FKS.get(key) + if not name: + continue + names.append(name) + return names + + def _create_workspace_indexes_and_catalog(): + op.create_index("idx_experiments_workspace", "experiments", ["workspace"]) + op.create_index("idx_registered_models_workspace", "registered_models", ["workspace"]) + op.create_index( + "idx_experiments_workspace_creation_time", + "experiments", + ["workspace", "creation_time"], + unique=False, + ) + op.create_index("idx_evaluation_datasets_workspace", "evaluation_datasets", ["workspace"]) + op.create_index("idx_webhooks_workspace", "webhooks", ["workspace"]) + + op.create_table( + "workspaces", + sa.Column("name", sa.String(length=63), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("name", name="workspaces_pk"), + ) + + metadata = sa.MetaData() + workspaces_table = sa.Table( + "workspaces", + metadata, + sa.Column("name", sa.String(length=63)), + sa.Column("description", sa.Text()), + ) + + conn.execute( + workspaces_table.insert().values( + name="default", + description="Default workspace for legacy resources", + ) + ) + + experiments_unique_constraint, experiments_unique_index = _detect_unique_on_name("experiments") + registered_models_unique_constraint, registered_models_unique_index = _detect_unique_on_name( + "registered_models" + ) + + fk_model_versions = _collect_foreign_keys("model_versions", "registered_models") + fk_registered_model_tags = _collect_foreign_keys("registered_model_tags", "registered_models") + fk_registered_model_aliases = _collect_foreign_keys( + "registered_model_aliases", "registered_models" + ) + fk_model_version_tags = _collect_foreign_keys("model_version_tags", "model_versions") + + if dialect_name == "sqlite": + # Let SQLite handle the foreign key drops inside the batches so each table is recreated + # only once. + with _with_batch("experiments") as batch_op: + if experiments_unique_constraint: + batch_op.drop_constraint(experiments_unique_constraint, type_="unique") + elif experiments_unique_index: + batch_op.drop_index(experiments_unique_index) + batch_op.add_column(_workspace_column()) + batch_op.create_unique_constraint( + "uq_experiments_workspace_name", + ["workspace", "name"], + ) + + with _with_batch("registered_models") as batch_op: + if registered_models_unique_constraint: + batch_op.drop_constraint(registered_models_unique_constraint, type_="unique") + elif registered_models_unique_index: + batch_op.drop_index(registered_models_unique_index) + batch_op.add_column(_workspace_column()) + batch_op.drop_constraint("registered_model_pk", type_="primary") + batch_op.create_primary_key("registered_model_pk", ["workspace", "name"]) + + with _with_batch("model_versions") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_model_versions: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("model_version_pk", type_="primary") + batch_op.create_primary_key("model_version_pk", ["workspace", "name", "version"]) + batch_op.create_foreign_key( + "fk_model_versions_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_tags") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_registered_model_tags: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("registered_model_tag_pk", type_="primary") + batch_op.create_primary_key("registered_model_tag_pk", ["workspace", "key", "name"]) + batch_op.create_foreign_key( + "fk_registered_model_tags_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + with _with_batch("model_version_tags") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_model_version_tags: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("model_version_tag_pk", type_="primary") + batch_op.create_primary_key( + "model_version_tag_pk", + ["workspace", "key", "name", "version"], + ) + batch_op.create_foreign_key( + "fk_model_version_tags_model_versions", + "model_versions", + ["workspace", "name", "version"], + ["workspace", "name", "version"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_aliases") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_registered_model_aliases: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("registered_model_alias_pk", type_="primary") + batch_op.create_primary_key( + "registered_model_alias_pk", + ["workspace", "name", "alias"], + ) + batch_op.create_foreign_key( + "fk_registered_model_aliases_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + with _with_batch("evaluation_datasets") as batch_op: + batch_op.add_column(_workspace_column()) + + with _with_batch("webhooks") as batch_op: + batch_op.add_column(_workspace_column()) + + _create_workspace_indexes_and_catalog() + + return + + # Non-SQLite dialects can issue direct ALTER TABLE statements, which avoids rebuilding the + # tables. This code duplication is worth the performance gain of not rebuilding the tables. + # We could potentially leverage Alembic's batch_alter_table with recreate="auto" to avoid the + # duplication, but parts of this migration caused the tables to be recreated anyways. + + def _drop_fk_constraints(table_name: str, fk_names: list[str]): + for fk_name in fk_names: + if fk_name: + op.drop_constraint(fk_name, table_name=table_name, type_="foreignkey") + + def _drop_unique_on_name(table_name: str, constraint: str | None, index: str | None): + if constraint: + op.drop_constraint(constraint, table_name=table_name, type_="unique") + elif index: + op.drop_index(index, table_name=table_name) + + _drop_unique_on_name("experiments", experiments_unique_constraint, experiments_unique_index) + op.add_column( + "experiments", + _workspace_column(), + ) + op.create_unique_constraint( + "uq_experiments_workspace_name", + "experiments", + ["workspace", "name"], + ) + + _drop_fk_constraints("model_versions", fk_model_versions) + _drop_fk_constraints("registered_model_tags", fk_registered_model_tags) + _drop_fk_constraints("registered_model_aliases", fk_registered_model_aliases) + _drop_fk_constraints("model_version_tags", fk_model_version_tags) + + _drop_unique_on_name( + "registered_models", + registered_models_unique_constraint, + registered_models_unique_index, + ) + op.add_column( + "registered_models", + _workspace_column(), + ) + op.drop_constraint("registered_model_pk", "registered_models", type_="primary") + op.create_primary_key("registered_model_pk", "registered_models", ["workspace", "name"]) + + op.add_column( + "model_versions", + _workspace_column(), + ) + op.drop_constraint("model_version_pk", "model_versions", type_="primary") + op.create_primary_key("model_version_pk", "model_versions", ["workspace", "name", "version"]) + op.create_foreign_key( + "fk_model_versions_registered_models", + "model_versions", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + op.add_column( + "registered_model_tags", + _workspace_column(), + ) + op.drop_constraint("registered_model_tag_pk", "registered_model_tags", type_="primary") + op.create_primary_key( + "registered_model_tag_pk", + "registered_model_tags", + ["workspace", "key", "name"], + ) + op.create_foreign_key( + "fk_registered_model_tags_registered_models", + "registered_model_tags", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + op.add_column( + "model_version_tags", + _workspace_column(), + ) + op.drop_constraint("model_version_tag_pk", "model_version_tags", type_="primary") + op.create_primary_key( + "model_version_tag_pk", + "model_version_tags", + ["workspace", "key", "name", "version"], + ) + op.create_foreign_key( + "fk_model_version_tags_model_versions", + "model_version_tags", + "model_versions", + ["workspace", "name", "version"], + ["workspace", "name", "version"], + onupdate="CASCADE", + ) + + op.add_column( + "registered_model_aliases", + _workspace_column(), + ) + op.drop_constraint("registered_model_alias_pk", "registered_model_aliases", type_="primary") + op.create_primary_key( + "registered_model_alias_pk", + "registered_model_aliases", + ["workspace", "name", "alias"], + ) + op.create_foreign_key( + "fk_registered_model_aliases_registered_models", + "registered_model_aliases", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + for table in ["evaluation_datasets", "webhooks"]: + op.add_column( + table, + _workspace_column(), + ) + + _create_workspace_indexes_and_catalog() + + +def downgrade(): + conn = op.get_bind() + dialect_name = conn.dialect.name + + def _assert_no_workspace_conflicts( + table_name: str, + columns: tuple[str, ...], + resource_description: str, + ): + table = sa.Table(table_name, sa.MetaData(), autoload_with=conn) + group_columns = [table.c[column] for column in columns] + stmt = sa.select(*group_columns).group_by(*group_columns).having(sa.func.count() > 1) + conflicts = conn.execute(stmt).fetchall() + if conflicts: + formatted_conflicts = ", ".join( + "; ".join(f"{column}={value!r}" for column, value in zip(columns, row)) + for row in conflicts[:5] + ) + if len(conflicts) > 5: + formatted_conflicts += ", ..." + raise RuntimeError( + "Downgrade aborted: merging workspaces would create duplicate " + f"{resource_description}. Resolve the following conflicts by deleting or renaming " + f"the affected resources and retry: {formatted_conflicts}" + ) + + def _move_resources_to_default_workspace(table_name: str): + table = sa.Table(table_name, sa.MetaData(), autoload_with=conn) + conn.execute( + table.update().where(table.c.workspace != "default").values(workspace="default") + ) + + conflict_specs = [ + ("experiments", ("name",), "experiments with the same name"), + ("registered_models", ("name",), "registered models with the same name"), + ( + "evaluation_datasets", + ("name",), + "evaluation datasets with the same name", + ), + ( + "model_versions", + ("name", "version"), + "model versions with the same model name and version", + ), + ( + "registered_model_tags", + ("name", "key"), + "registered model tags with the same model name and key", + ), + ( + "model_version_tags", + ("name", "version", "key"), + "model version tags with the same model name, version, and key", + ), + ( + "registered_model_aliases", + ("name", "alias"), + "registered model aliases with the same model name and alias", + ), + ] + + for table_name, columns, description in conflict_specs: + _assert_no_workspace_conflicts(table_name, columns, description) + + for table in _WORKSPACE_TABLES: + _move_resources_to_default_workspace(table) + + drop_index_kwargs = {} if dialect_name == "mysql" else {"if_exists": True} + op.drop_index( + "idx_experiments_workspace_creation_time", table_name="experiments", **drop_index_kwargs + ) + op.drop_index( + "idx_registered_models_workspace", table_name="registered_models", **drop_index_kwargs + ) + op.drop_index("idx_experiments_workspace", table_name="experiments", **drop_index_kwargs) + op.drop_index( + "idx_evaluation_datasets_workspace", table_name="evaluation_datasets", **drop_index_kwargs + ) + op.drop_index("idx_webhooks_workspace", table_name="webhooks", **drop_index_kwargs) + + if dialect_name == "sqlite": + with _with_batch("model_version_tags") as batch_op: + batch_op.drop_constraint("fk_model_version_tags_model_versions", type_="foreignkey") + batch_op.drop_constraint("model_version_tag_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("model_version_tag_pk", ["key", "name", "version"]) + batch_op.create_foreign_key( + "model_version_tags_mv_fkey", + "model_versions", + ["name", "version"], + ["name", "version"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_aliases") as batch_op: + batch_op.drop_constraint( + "fk_registered_model_aliases_registered_models", type_="foreignkey" + ) + batch_op.drop_constraint("registered_model_alias_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_alias_pk", ["name", "alias"]) + batch_op.create_foreign_key( + "registered_model_alias_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + with _with_batch("registered_model_tags") as batch_op: + batch_op.drop_constraint( + "fk_registered_model_tags_registered_models", type_="foreignkey" + ) + batch_op.drop_constraint("registered_model_tag_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_tag_pk", ["key", "name"]) + batch_op.create_foreign_key( + "registered_model_tags_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + + with _with_batch("model_versions") as batch_op: + batch_op.drop_constraint("fk_model_versions_registered_models", type_="foreignkey") + batch_op.drop_constraint("model_version_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("model_version_pk", ["name", "version"]) + batch_op.create_foreign_key( + "model_versions_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + + with _with_batch("registered_models") as batch_op: + batch_op.drop_constraint("registered_model_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_pk", ["name"]) + + with _with_batch("experiments") as batch_op: + batch_op.drop_constraint("uq_experiments_workspace_name", type_="unique") + batch_op.drop_column("workspace") + batch_op.create_unique_constraint("uq_experiments_name", ["name"]) + + with _with_batch("evaluation_datasets") as batch_op: + batch_op.drop_column("workspace") + + with _with_batch("webhooks") as batch_op: + batch_op.drop_column("workspace") + + op.drop_table("workspaces") + return + + op.drop_constraint( + "fk_model_version_tags_model_versions", "model_version_tags", type_="foreignkey" + ) + op.drop_constraint( + "fk_registered_model_aliases_registered_models", + "registered_model_aliases", + type_="foreignkey", + ) + op.drop_constraint( + "fk_registered_model_tags_registered_models", "registered_model_tags", type_="foreignkey" + ) + op.drop_constraint("fk_model_versions_registered_models", "model_versions", type_="foreignkey") + + op.drop_constraint("uq_experiments_workspace_name", table_name="experiments", type_="unique") + + op.drop_constraint("registered_model_alias_pk", "registered_model_aliases", type_="primary") + op.drop_constraint("model_version_tag_pk", "model_version_tags", type_="primary") + op.drop_constraint("registered_model_tag_pk", "registered_model_tags", type_="primary") + op.drop_constraint("model_version_pk", "model_versions", type_="primary") + op.drop_constraint("registered_model_pk", "registered_models", type_="primary") + + if dialect_name == "mssql": + # SQL Server binds defaults via named constraints. If we try to drop the column while a + # default is attached, the prior downgrade operations can leave behind those constraints, + # causing drop_column to fail. Clear the defaults explicitly first. + for table in _WORKSPACE_TABLES: + op.alter_column( + table_name=table, + column_name="workspace", + existing_type=sa.String(length=63), + existing_nullable=False, + server_default=None, + ) + + op.drop_column("model_version_tags", "workspace") + op.drop_column("registered_model_aliases", "workspace") + op.drop_column("registered_model_tags", "workspace") + op.drop_column("model_versions", "workspace") + op.drop_column("registered_models", "workspace") + op.drop_column("experiments", "workspace") + op.drop_column("evaluation_datasets", "workspace") + op.drop_column("webhooks", "workspace") + + op.create_primary_key("registered_model_pk", "registered_models", ["name"]) + op.create_primary_key("model_version_pk", "model_versions", ["name", "version"]) + op.create_primary_key("registered_model_tag_pk", "registered_model_tags", ["key", "name"]) + op.create_primary_key("model_version_tag_pk", "model_version_tags", ["key", "name", "version"]) + op.create_primary_key( + "registered_model_alias_pk", "registered_model_aliases", ["name", "alias"] + ) + + op.create_foreign_key( + "model_versions_name_fkey", + "model_versions", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + op.create_foreign_key( + "registered_model_tags_name_fkey", + "registered_model_tags", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + op.create_foreign_key( + "registered_model_alias_name_fkey", + "registered_model_aliases", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + op.create_foreign_key( + "model_version_tags_mv_fkey", + "model_version_tags", + "model_versions", + ["name", "version"], + ["name", "version"], + onupdate="CASCADE", + ) + + op.drop_table("workspaces") + + op.create_unique_constraint("uq_experiments_name", "experiments", ["name"]) diff --git a/mlflow/store/model_registry/dbmodels/models.py b/mlflow/store/model_registry/dbmodels/models.py index 2075c50162e99..36d963e65e47f 100644 --- a/mlflow/store/model_registry/dbmodels/models.py +++ b/mlflow/store/model_registry/dbmodels/models.py @@ -1,3 +1,4 @@ +import sqlalchemy as sa from cryptography.fernet import Fernet from sqlalchemy import ( BigInteger, @@ -32,12 +33,20 @@ from mlflow.environment_variables import MLFLOW_WEBHOOK_SECRET_ENCRYPTION_KEY from mlflow.store.db.base_sql_model import Base from mlflow.utils.time import get_current_time_millis +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME class SqlRegisteredModel(Base): __tablename__ = "registered_models" - name = Column(String(256), unique=True, nullable=False) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) creation_time = Column(BigInteger, default=get_current_time_millis) @@ -45,7 +54,7 @@ class SqlRegisteredModel(Base): description = Column(String(5000), nullable=True) - __table_args__ = (PrimaryKeyConstraint("name", name="registered_model_pk"),) + __table_args__ = (PrimaryKeyConstraint("workspace", "name", name="registered_model_pk"),) def __repr__(self): return ( @@ -76,7 +85,14 @@ def to_mlflow_entity(self): class SqlModelVersion(Base): __tablename__ = "model_versions" - name = Column(String(256), ForeignKey("registered_models.name", onupdate="cascade")) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) version = Column(Integer, nullable=False) @@ -107,7 +123,14 @@ class SqlModelVersion(Base): "SqlRegisteredModel", backref=backref("model_versions", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("name", "version", name="model_version_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ), + PrimaryKeyConstraint("workspace", "name", "version", name="model_version_pk"), + ) # entity mappers def to_mlflow_entity(self): @@ -132,7 +155,14 @@ def to_mlflow_entity(self): class SqlRegisteredModelTag(Base): __tablename__ = "registered_model_tags" - name = Column(String(256), ForeignKey("registered_models.name", onupdate="cascade")) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) key = Column(String(250), nullable=False) @@ -143,7 +173,14 @@ class SqlRegisteredModelTag(Base): "SqlRegisteredModel", backref=backref("registered_model_tags", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("key", "name", name="registered_model_tag_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ), + PrimaryKeyConstraint("workspace", "key", "name", name="registered_model_tag_pk"), + ) def __repr__(self): return f"" @@ -156,7 +193,14 @@ def to_mlflow_entity(self): class SqlModelVersionTag(Base): __tablename__ = "model_version_tags" - name = Column(String(256)) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) version = Column(Integer) @@ -164,20 +208,18 @@ class SqlModelVersionTag(Base): value = Column(Text, nullable=True) - # linked entities - model_version = relationship( - "SqlModelVersion", - foreign_keys=[name, version], - backref=backref("model_version_tags", cascade="all"), - ) - __table_args__ = ( - PrimaryKeyConstraint("key", "name", "version", name="model_version_tag_pk"), ForeignKeyConstraint( - ("name", "version"), - ("model_versions.name", "model_versions.version"), + ["workspace", "name", "version"], + ["model_versions.workspace", "model_versions.name", "model_versions.version"], onupdate="cascade", ), + PrimaryKeyConstraint("workspace", "key", "name", "version", name="model_version_tag_pk"), + ) + + # linked entities + model_version = relationship( + "SqlModelVersion", backref=backref("model_version_tags", cascade="all") ) def __repr__(self): @@ -190,15 +232,15 @@ def to_mlflow_entity(self): class SqlRegisteredModelAlias(Base): __tablename__ = "registered_model_aliases" - name = Column( - String(256), - ForeignKey( - "registered_models.name", - onupdate="cascade", - ondelete="cascade", - name="registered_model_alias_name_fkey", - ), + + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), ) + + name = Column(String(256), nullable=False) alias = Column(String(256), nullable=False) version = Column(Integer, nullable=False) @@ -207,7 +249,16 @@ class SqlRegisteredModelAlias(Base): "SqlRegisteredModel", backref=backref("registered_model_aliases", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("name", "alias", name="registered_model_alias_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ondelete="cascade", + name="registered_model_alias_registered_model_fkey", + ), + PrimaryKeyConstraint("workspace", "name", "alias", name="registered_model_alias_pk"), + ) def __repr__(self): return f"" @@ -247,6 +298,12 @@ def process_result_value(self, value, dialect): class SqlWebhook(Base): __tablename__ = "webhooks" + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) webhook_id = Column(String(256), nullable=False) name = Column(String(256), nullable=False) description = Column(String(1000), nullable=True) @@ -261,6 +318,7 @@ class SqlWebhook(Base): PrimaryKeyConstraint("webhook_id", name="webhook_pk"), Index("idx_webhooks_status", "status"), Index("idx_webhooks_name", "name"), + Index("idx_webhooks_workspace", "workspace"), ) def __repr__(self): diff --git a/mlflow/store/model_registry/sqlalchemy_store.py b/mlflow/store/model_registry/sqlalchemy_store.py index 499fd713c3a23..f63fc3a37128e 100644 --- a/mlflow/store/model_registry/sqlalchemy_store.py +++ b/mlflow/store/model_registry/sqlalchemy_store.py @@ -61,6 +61,7 @@ _validate_webhook_name, _validate_webhook_url, ) +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME _logger = logging.getLogger(__name__) @@ -194,12 +195,14 @@ def create_registered_model(self, name, tags=None, description=None, deployment_ creation_time=creation_time, last_updated_time=creation_time, description=description, + workspace=DEFAULT_WORKSPACE_NAME, ) tags_dict = {} for tag in tags or []: tags_dict[tag.key] = tag.value registered_model.registered_model_tags = [ - SqlRegisteredModelTag(key=key, value=value) for key, value in tags_dict.items() + SqlRegisteredModelTag(workspace=DEFAULT_WORKSPACE_NAME, key=key, value=value) + for key, value in tags_dict.items() ] session.add(registered_model) session.flush() @@ -696,7 +699,11 @@ def set_registered_model_tag(self, name, tag): with self.ManagedSessionMaker() as session: # check if registered model exists self._get_registered_model(session, name) - session.merge(SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)) + session.merge( + SqlRegisteredModelTag( + workspace=DEFAULT_WORKSPACE_NAME, name=name, key=tag.key, value=tag.value + ) + ) def delete_registered_model_tag(self, name, key): """ @@ -803,12 +810,14 @@ def next_version(sql_registered_model): run_id=run_id, run_link=run_link, description=description, + workspace=DEFAULT_WORKSPACE_NAME, ) tags_dict = {} for tag in tags or []: tags_dict[tag.key] = tag.value model_version.model_version_tags = [ - SqlModelVersionTag(key=key, value=value) for key, value in tags_dict.items() + SqlModelVersionTag(workspace=DEFAULT_WORKSPACE_NAME, key=key, value=value) + for key, value in tags_dict.items() ] session.add_all([sql_registered_model, model_version]) session.flush() @@ -1193,7 +1202,13 @@ def set_model_version_tag(self, name, version, tag): # check if model version exists self._get_sql_model_version(session, name, version) session.merge( - SqlModelVersionTag(name=name, version=version, key=tag.key, value=tag.value) + SqlModelVersionTag( + workspace=DEFAULT_WORKSPACE_NAME, + name=name, + version=version, + key=tag.key, + value=tag.value, + ) ) def delete_model_version_tag(self, name, version, key): @@ -1248,7 +1263,11 @@ def set_registered_model_alias(self, name, alias, version): with self.ManagedSessionMaker() as session: # check if model version exists self._get_sql_model_version(session, name, version) - session.merge(SqlRegisteredModelAlias(name=name, alias=alias, version=version)) + session.merge( + SqlRegisteredModelAlias( + workspace=DEFAULT_WORKSPACE_NAME, name=name, alias=alias, version=version + ) + ) def delete_registered_model_alias(self, name, alias): """ diff --git a/mlflow/store/tracking/dbmodels/models.py b/mlflow/store/tracking/dbmodels/models.py index 75d84b60cd1b0..0d38bdb02c01e 100644 --- a/mlflow/store/tracking/dbmodels/models.py +++ b/mlflow/store/tracking/dbmodels/models.py @@ -61,6 +61,7 @@ from mlflow.tracing.utils import generate_assessment_id from mlflow.utils.mlflow_tags import MLFLOW_USER, _get_run_name_from_tags from mlflow.utils.time import get_current_time_millis +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME SourceTypes = [ SourceType.to_string(SourceType.NOTEBOOK), @@ -94,10 +95,19 @@ class SqlExperiment(Base): """ Experiment ID: `Integer`. *Primary Key* for ``experiment`` table. """ - name = Column(String(256), unique=True, nullable=False) + name = Column(String(256), nullable=False) """ - Experiment name: `String` (limit 256 characters). Defined as *Unique* and *Non null* in - table schema. + Experiment name: `String` (limit 256 characters). Unique per workspace (see + ``uq_experiments_workspace_name``) and *Non null* in the table schema. + """ + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + """ + Workspace identifier for this experiment. Defaults to ``'default'`` for legacy rows. """ artifact_location = Column(String(256), nullable=True) """ @@ -124,6 +134,7 @@ class SqlExperiment(Base): name="experiments_lifecycle_stage", ), PrimaryKeyConstraint("experiment_id", name="experiment_pk"), + UniqueConstraint("workspace", "name", name="uq_experiments_workspace_name"), ) def __repr__(self): @@ -1301,6 +1312,16 @@ class SqlEvaluationDataset(Base): *Primary Key* for ``evaluation_datasets`` table. """ + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + """ + Workspace name that scopes this dataset. Defaults to ``'default'`` for legacy rows. + """ + name = Column(String(255), nullable=False) """ Dataset name: `String` (limit 255 characters). *Non null* in table schema. @@ -1355,6 +1376,7 @@ class SqlEvaluationDataset(Base): PrimaryKeyConstraint("dataset_id", name="evaluation_datasets_pk"), Index("index_evaluation_datasets_name", "name"), Index("index_evaluation_datasets_created_time", "created_time"), + Index("idx_evaluation_datasets_workspace", "workspace"), ) def to_mlflow_entity(self): diff --git a/mlflow/store/workspace/__init__.py b/mlflow/store/workspace/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/mlflow/store/workspace/dbmodels/__init__.py b/mlflow/store/workspace/dbmodels/__init__.py new file mode 100644 index 0000000000000..ee895f8ba3782 --- /dev/null +++ b/mlflow/store/workspace/dbmodels/__init__.py @@ -0,0 +1,3 @@ +from mlflow.store.workspace.dbmodels.models import SqlWorkspace + +__all__ = ["SqlWorkspace"] diff --git a/mlflow/store/workspace/dbmodels/models.py b/mlflow/store/workspace/dbmodels/models.py new file mode 100644 index 0000000000000..354d0fccb8b98 --- /dev/null +++ b/mlflow/store/workspace/dbmodels/models.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sqlalchemy as sa +from sqlalchemy import Column, String, Text + +from mlflow.entities.workspace import Workspace +from mlflow.store.db.base_sql_model import Base + + +class SqlWorkspace(Base): + __tablename__ = "workspaces" + + name = Column(String(63), nullable=False) + description = Column(Text, nullable=True) + + __table_args__ = (sa.PrimaryKeyConstraint("name", name="workspaces_pk"),) + + def __repr__(self) -> str: # pragma: no cover + return f"" + + def to_mlflow_entity(self) -> Workspace: + return Workspace(name=self.name, description=self.description) diff --git a/mlflow/utils/workspace_utils.py b/mlflow/utils/workspace_utils.py new file mode 100644 index 0000000000000..3d6a81d00cca6 --- /dev/null +++ b/mlflow/utils/workspace_utils.py @@ -0,0 +1,5 @@ +"""Utility helpers for workspace-aware database schema defaults.""" + +DEFAULT_WORKSPACE_NAME = "default" + +__all__ = ["DEFAULT_WORKSPACE_NAME"] diff --git a/tests/db/check_migration.py b/tests/db/check_migration.py index 271e8f6c473f5..36a378dcbef57 100644 --- a/tests/db/check_migration.py +++ b/tests/db/check_migration.py @@ -49,6 +49,16 @@ SqlModelVersionTag.__tablename__, ] SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" +WORKSPACE_TABLES = { + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + "webhooks", +} class Model(mlflow.pyfunc.PythonModel): @@ -72,7 +82,7 @@ def log_everything(): client.create_registered_model( registered_model_name, tags={"tag": "registered_model"}, description="description" ) - client.create_model_version( + model_version = client.create_model_version( registered_model_name, model_info.model_uri, run_id=run.info.run_id, @@ -80,6 +90,41 @@ def log_everything(): run_link="run_link", description="description", ) + client.set_registered_model_alias( + name=registered_model_name, + alias="prod", + version=model_version.version, + ) + # Create an additional experiment/model to ensure workspace backfills cover multiple resources. + mlflow.create_experiment(uuid.uuid4().hex) + client.create_registered_model(uuid.uuid4().hex) + client.create_webhook( + name=f"migration-webhook-{uuid.uuid4().hex}", + url="https://example.com/hook", + events=["model_version.created"], + description="workspace-migration-check", + ) + engine = sa.create_engine(os.environ["MLFLOW_TRACKING_URI"]) + metadata = sa.MetaData() + evaluation_datasets_table = sa.Table( + "evaluation_datasets", + metadata, + autoload_with=engine, + ) + with engine.begin() as conn: + conn.execute( + sa.insert(evaluation_datasets_table).values( + dataset_id=uuid.uuid4().hex, + name="workspace-migration-dataset", + schema="{}", + profile="{}", + digest=uuid.uuid4().hex, + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + ) + ) def connect_to_mlflow_db(): @@ -113,6 +158,10 @@ def post_migration(): df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn) df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl") pd.testing.assert_frame_equal(df_actual[df_expected.columns], df_expected) + for table in WORKSPACE_TABLES: + df = pd.read_sql(sa.text(f"SELECT DISTINCT workspace FROM {table}"), conn) + assert not df["workspace"].isna().any(), f"{table} contains NULL workspace values" + assert set(df["workspace"]) == {"default"}, f"{table} contains non-default workspaces" if __name__ == "__main__": diff --git a/tests/db/schemas/mssql.sql b/tests/db/schemas/mssql.sql index 966b5eca98a61..6e0746da4f3a9 100644 --- a/tests/db/schemas/mssql.sql +++ b/tests/db/schemas/mssql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255) COLLATE "SQL_Latin1_General_CP1_CI_AS", last_updated_by VARCHAR(255) COLLATE "SQL_Latin1_General_CP1_CI_AS", + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,7 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32) COLLATE "SQL_Latin1_General_CP1_CI_AS", creation_time BIGINT, last_update_time BIGINT, - CONSTRAINT experiment_pk PRIMARY KEY (experiment_id) + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name) ) @@ -79,7 +82,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000) COLLATE "SQL_Latin1_General_CP1_CI_AS", - CONSTRAINT registered_model_pk PRIMARY KEY (name) + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -93,10 +97,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, + description VARCHAR COLLATE "SQL_Latin1_General_CP1_CI_AS", + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, experiment_id INTEGER NOT NULL, @@ -180,8 +192,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", run_link VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", storage_location VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - CONSTRAINT "FK__model_vers__name__6B24EA82" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -189,8 +202,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -198,8 +212,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, value VARCHAR(5000) COLLATE "SQL_Latin1_General_CP1_CI_AS", name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - CONSTRAINT "FK__registered__name__6EF57B66" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -346,8 +361,9 @@ CREATE TABLE model_version_tags ( value VARCHAR COLLATE "SQL_Latin1_General_CP1_CI_AS", name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - CONSTRAINT "FK__model_version_ta__71D1E811" FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/mysql.sql b/tests/db/schemas/mysql.sql index 413f1f3f855d6..8d92510577a20 100644 --- a/tests/db/schemas/mysql.sql +++ b/tests/db/schemas/mysql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (dataset_id) ) @@ -37,8 +38,10 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (experiment_id), - CONSTRAINT experiments_lifecycle_stage CHECK ((`lifecycle_stage` in (_utf8mb4'active',_utf8mb4'deleted'))) + CONSTRAINT experiments_lifecycle_stage CHECK ((`lifecycle_stage` in (_utf8mb4'active',_utf8mb4'deleted'))), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name) ) @@ -80,7 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - PRIMARY KEY (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name) ) @@ -94,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -182,8 +194,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - PRIMARY KEY (name, version), - CONSTRAINT model_versions_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -191,8 +204,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -200,8 +214,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - PRIMARY KEY (key, name), - CONSTRAINT registered_model_tags_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -354,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - PRIMARY KEY (key, name, version), - CONSTRAINT model_version_tags_ibfk_1 FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/postgresql.sql b/tests/db/schemas/postgresql.sql index ffeb4527e26b5..9081fd88a546c 100644 --- a/tests/db/schemas/postgresql.sql +++ b/tests/db/schemas/postgresql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - CONSTRAINT experiments_name_key UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage::text = ANY (ARRAY['active'::character varying, 'deleted'::character varying]::text[])) ) @@ -81,7 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name) + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -95,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -184,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - CONSTRAINT model_versions_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -193,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -202,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - CONSTRAINT registered_model_tags_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -353,8 +367,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - CONSTRAINT model_version_tags_name_version_fkey FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/sqlite.sql b/tests/db/schemas/sqlite.sql index c7b87e3461731..721286a886e81 100644 --- a/tests/db/schemas/sqlite.sql +++ b/tests/db/schemas/sqlite.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage IN ('active', 'deleted')) ) @@ -81,8 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name), - UNIQUE (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -96,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -185,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -194,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -203,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -356,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/test_schema.py b/tests/db/test_schema.py index 15758546d6caa..1bcec568d7ada 100644 --- a/tests/db/test_schema.py +++ b/tests/db/test_schema.py @@ -1,11 +1,19 @@ import difflib +import logging import re from pathlib import Path from typing import NamedTuple import pytest -from sqlalchemy import create_engine -from sqlalchemy.schema import CreateTable, MetaData +from sqlalchemy import create_engine, inspect +from sqlalchemy.schema import CreateTable, MetaData, UniqueConstraint + +_logger = logging.getLogger(__name__) + +_DIALECT_REFLECTED_UNIQUE_CONSTRAINTS = { + "mysql": {"uq_experiments_workspace_name"}, + "mssql": {"uq_experiments_workspace_name"}, +} import mlflow from mlflow.environment_variables import MLFLOW_TRACKING_URI @@ -25,6 +33,7 @@ def dump_schema(db_uri): engine = create_engine(db_uri) created_tables_metadata = MetaData() created_tables_metadata.reflect(bind=engine) + _reattach_missing_unique_constraints(engine, created_tables_metadata) # Write out table schema as described in # https://docs.sqlalchemy.org/en/13/faq/metadata_schema.html#how-can-i-get-the-create-table-drop-table-output-as-a-string lines = [] @@ -34,6 +43,70 @@ def dump_schema(db_uri): return "\n".join(lines) +def _reattach_missing_unique_constraints(engine, metadata): + constraint_names = _DIALECT_REFLECTED_UNIQUE_CONSTRAINTS.get(engine.dialect.name) + if not constraint_names: + return + inspector = inspect(engine) + for table in metadata.sorted_tables: + existing_unique_columns = { + tuple(constraint.columns.keys()) + for constraint in table.constraints + if isinstance(constraint, UniqueConstraint) + } + # Not all dialects reflect `UniqueConstraint` objects the same way. MySQL reports + # them as indexes; MSSQL doesn't implement `get_unique_constraints` at all. We + # normalize the reflection results via `_get_unique_constraints` so the same code + # path can reattach missing `UniqueConstraint`s across dialects. + for unique in _get_unique_constraints(inspector, engine.dialect.name, table.name): + name = unique.get("name") + columns = tuple(unique.get("column_names") or ()) + duplicates_index = unique.get("duplicates_index") + if not columns or name not in constraint_names: + continue + if engine.dialect.name == "mysql" and not duplicates_index: + # MySQL exposes unique constraints as unique indexes. SQLAlchemy treats those as + # indexes during reflection, so the reflected metadata lacks the original + # `UniqueConstraint`. Only recreate constraints that are backed by an actual + # unique index reported via `duplicates_index`. + continue + if columns in existing_unique_columns: + continue + missing_columns = tuple(column for column in columns if column not in table.c) + if missing_columns: + _logger.warning( + "Skipping recreation of unique constraint '%s' on table '%s' due to " + "missing columns: %s", + name, + table.name, + ", ".join(missing_columns), + ) + continue + constraint = UniqueConstraint(*[table.c[column] for column in columns], name=name) + table.append_constraint(constraint) + existing_unique_columns.add(columns) + + +def _get_unique_constraints(inspector, dialect, table_name): + try: + unique_constraints = inspector.get_unique_constraints(table_name) + except NotImplementedError: + unique_constraints = None + if unique_constraints is None: + unique_constraints = [] + if not unique_constraints: + unique_constraints = [ + { + "name": index.get("name"), + "column_names": index.get("column_names"), + "duplicates_index": index.get("unique"), + } + for index in inspector.get_indexes(table_name) + if index.get("unique") + ] + return unique_constraints + + class _CreateTable(NamedTuple): table: str columns: str diff --git a/tests/db/test_workspace_migration.py b/tests/db/test_workspace_migration.py new file mode 100644 index 0000000000000..30272de2ae5c5 --- /dev/null +++ b/tests/db/test_workspace_migration.py @@ -0,0 +1,1073 @@ +import os +import re +from contextlib import contextmanager + +import pytest +import sqlalchemy as sa +from alembic import command + +from mlflow.store.db.utils import _get_alembic_config +from mlflow.store.tracking.dbmodels.initial_models import Base as InitialBase + +_LEGACY_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + +_LEGACY_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + +_LEGACY_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("alias"), + sa.column("version"), + sa.column("name"), +) + +_LEGACY_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), +) + +_WORKSPACE_TABLES = ( + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", +) + +_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + +_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + +_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("workspace"), + sa.column("name"), + sa.column("alias"), + sa.column("version"), +) + +_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), + sa.column("workspace"), +) + +REVISION = "1b5f0d9ad7c1" +PREVIOUS_REVISION = "bf29a5ff90ea" + +DB_URI = os.environ.get("MLFLOW_TRACKING_URI") +USE_EXTERNAL_DB = DB_URI is not None and not DB_URI.startswith("sqlite") + + +@pytest.fixture(scope="session", autouse=True) +def _upgrade_external_db_to_head_after_suite(): + """ + When running under Docker (i.e., with a shared external DB), make sure the DB ends up on the + latest revision once this module finishes. Individual tests intentionally downgrade to the + pre-workspace schema, so without this hook, the subsequent suites in the database workflow would + run against an outdated schema after new migrations land. + """ + yield + if USE_EXTERNAL_DB: + config = _get_alembic_config(DB_URI) + command.upgrade(config, "head") + + +@contextmanager +def _identity_insert(conn, table_name: str): + if conn.dialect.name != "mssql": + yield + return + + conn.execute(sa.text(f"SET IDENTITY_INSERT {table_name} ON")) + try: + yield + finally: + conn.execute(sa.text(f"SET IDENTITY_INSERT {table_name} OFF")) + + +def _insert_table_row(conn, table, **values): + conn.execute(sa.insert(table).values(**values)) + + +def _assert_workspace_column(inspector, table_name: str, expected_default: str): + columns = inspector.get_columns(table_name) + workspace = next((col for col in columns if col["name"] == "workspace"), None) + assert workspace is not None, f"{table_name} lacks workspace column" + assert not workspace.get("nullable", False) + default_value = _get_workspace_default(workspace) + assert default_value == expected_default + + +def _assert_workspace_columns(inspector, expected_default: str = "default"): + for table in _WORKSPACE_TABLES: + _assert_workspace_column(inspector, table, expected_default) + + +def _has_index(inspector, table: str, index_name: str, columns: list[str]): + indexes = inspector.get_indexes(table) + return any( + index["name"] == index_name and index.get("column_names") == columns for index in indexes + ) + + +def _prepare_database(tmp_path): + if USE_EXTERNAL_DB: + engine = sa.create_engine(DB_URI) + with engine.begin() as conn: + metadata = sa.MetaData() + metadata.reflect(bind=conn) + metadata.drop_all(bind=conn) + InitialBase.metadata.create_all(conn) + config = _get_alembic_config(DB_URI) + else: + db_path = tmp_path / "workspace_migration.sqlite" + url = f"sqlite:///{db_path}" + engine = sa.create_engine(url) + InitialBase.metadata.create_all(engine) + config = _get_alembic_config(url) + command.upgrade(config, PREVIOUS_REVISION) + return engine, config + + +def _seed_pre_workspace_entities(conn): + # This intentionally uses raw SQL matching the legacy schema (no workspace columns) + # so the migration under test is fully responsible for adding/backfilling the new + # fields. The helper insert functions below operate on the post-migration schema + # and therefore cannot be reused here. + with _identity_insert(conn, "experiments"): + conn.execute( + sa.text( + """ + INSERT INTO experiments ( + experiment_id, + name, + artifact_location, + lifecycle_stage, + creation_time, + last_update_time + ) + VALUES ( + :experiment_id, + :name, + :artifact_location, + :lifecycle_stage, + :creation_time, + :last_update_time + ) + """ + ), + { + "experiment_id": 1, + "name": "exp-default", + "artifact_location": "path", + "lifecycle_stage": "active", + "creation_time": 0, + "last_update_time": 0, + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO runs ( + run_uuid, + name, + source_type, + source_name, + entry_point_name, + user_id, + status, + start_time, + end_time, + source_version, + lifecycle_stage, + artifact_uri, + experiment_id + ) + VALUES ( + :run_uuid, + :name, + :source_type, + :source_name, + :entry_point_name, + :user_id, + :status, + :start_time, + :end_time, + :source_version, + :lifecycle_stage, + :artifact_uri, + :experiment_id + ) + """ + ), + { + "run_uuid": "run-default", + "name": "upgrade-validation-run", + "source_type": "LOCAL", + "source_name": "script.py", + "entry_point_name": "main", + "user_id": "user", + "status": "FINISHED", + "start_time": 0, + "end_time": 1, + "source_version": "abc123", + "lifecycle_stage": "active", + "artifact_uri": "path/artifacts", + "experiment_id": 1, + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO registered_models (name, creation_time, last_updated_time, description) + VALUES (:name, :creation_time, :last_updated_time, :description) + """ + ), + {"name": "rm-default", "creation_time": 0, "last_updated_time": 0, "description": "desc"}, + ) + conn.execute( + sa.text( + """ + INSERT INTO model_versions ( + name, + version, + creation_time, + last_updated_time, + user_id, + current_stage, + description, + source, + run_id, + status, + status_message, + run_link, + storage_location + ) + VALUES ( + :name, + :version, + :creation_time, + :last_updated_time, + :user_id, + :current_stage, + :description, + :source, + :run_id, + :status, + :status_message, + :run_link, + :storage_location + ) + """ + ), + { + "name": "rm-default", + "version": 1, + "creation_time": 0, + "last_updated_time": 0, + "user_id": "user", + "current_stage": "None", + "description": "desc", + "source": "source", + "run_id": "run-id", + "status": "READY", + "status_message": "message", + "run_link": "link", + "storage_location": "location", + }, + ) + _insert_table_row( + conn, + _LEGACY_REGISTERED_MODEL_TAGS, + key="tag", + value="value", + name="rm-default", + ) + _insert_table_row( + conn, + _LEGACY_MODEL_VERSION_TAGS, + key="tag", + value="value", + name="rm-default", + version=1, + ) + _insert_table_row( + conn, + _LEGACY_REGISTERED_MODEL_ALIASES, + alias="alias", + version=1, + name="rm-default", + ) + _insert_table_row( + conn, + _LEGACY_EVALUATION_DATASETS, + dataset_id="ds-default", + name="Dataset", + schema="schema", + profile="profile", + digest="digest", + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + ) + + +def _get_workspace_default(column_info): + default = column_info.get("default") or column_info.get("server_default") + if default is None: + return None + value = str(default).strip() + if value.startswith("(") and value.endswith(")"): + value = value[1:-1] + value = value.strip() + value = value.strip("'\"") + if "::" in value: + value = value.split("::", 1)[0] + return value.strip("'\"") + + +def _add_workspace(conn, name: str, description: str): + conn.execute( + sa.text("INSERT INTO workspaces (name, description) VALUES (:name, :description)"), + {"name": name, "description": description}, + ) + + +def _insert_experiment( + conn, + *, + experiment_id: int, + name: str, + workspace: str, + artifact_location: str = "path", + lifecycle_stage: str = "active", +): + with _identity_insert(conn, "experiments"): + conn.execute( + sa.text( + """ + INSERT INTO experiments ( + experiment_id, + name, + artifact_location, + lifecycle_stage, + creation_time, + last_update_time, + workspace + ) + VALUES ( + :experiment_id, + :name, + :artifact_location, + :lifecycle_stage, + :creation_time, + :last_update_time, + :workspace + ) + """ + ), + { + "experiment_id": experiment_id, + "name": name, + "artifact_location": artifact_location, + "lifecycle_stage": lifecycle_stage, + "creation_time": 0, + "last_update_time": 0, + "workspace": workspace, + }, + ) + + +def _insert_run( + conn, + *, + run_uuid: str, + experiment_id: int, + name: str = "run", + artifact_uri: str = "path/artifacts", +): + conn.execute( + sa.text( + """ + INSERT INTO runs ( + run_uuid, + name, + source_type, + source_name, + entry_point_name, + user_id, + status, + start_time, + end_time, + source_version, + lifecycle_stage, + artifact_uri, + experiment_id + ) + VALUES ( + :run_uuid, + :name, + :source_type, + :source_name, + :entry_point_name, + :user_id, + :status, + :start_time, + :end_time, + :source_version, + :lifecycle_stage, + :artifact_uri, + :experiment_id + ) + """ + ), + { + "run_uuid": run_uuid, + "name": name, + "source_type": "LOCAL", + "source_name": "script.py", + "entry_point_name": "main", + "user_id": "user", + "status": "FINISHED", + "start_time": 0, + "end_time": 1, + "source_version": "abc123", + "lifecycle_stage": "active", + "artifact_uri": artifact_uri, + "experiment_id": experiment_id, + }, + ) + + +def _insert_registered_model( + conn, + *, + name: str, + workspace: str, + description: str = "desc", + creation_time: int = 0, +): + conn.execute( + sa.text( + """ + INSERT INTO registered_models ( + name, + creation_time, + last_updated_time, + description, + workspace + ) + VALUES ( + :name, + :creation_time, + :last_updated_time, + :description, + :workspace + ) + """ + ), + { + "name": name, + "creation_time": creation_time, + "last_updated_time": creation_time, + "description": description, + "workspace": workspace, + }, + ) + + +def _insert_model_version( + conn, + *, + name: str, + version: int, + workspace: str, + run_id: str = "run-id", + storage_location: str = "location", +): + conn.execute( + sa.text( + """ + INSERT INTO model_versions ( + name, + version, + creation_time, + last_updated_time, + user_id, + current_stage, + description, + source, + run_id, + status, + status_message, + run_link, + storage_location, + workspace + ) + VALUES ( + :name, + :version, + :creation_time, + :last_updated_time, + :user_id, + :current_stage, + :description, + :source, + :run_id, + :status, + :status_message, + :run_link, + :storage_location, + :workspace + ) + """ + ), + { + "name": name, + "version": version, + "creation_time": 0, + "last_updated_time": 0, + "user_id": "user", + "current_stage": "None", + "description": "desc", + "source": "source", + "run_id": run_id, + "status": "READY", + "status_message": "message", + "run_link": "link", + "storage_location": storage_location, + "workspace": workspace, + }, + ) + + +_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + + +def _insert_registered_model_tag( + conn, + *, + workspace: str, + name: str, + key: str, + value: str = "value", +): + _insert_table_row( + conn, + _REGISTERED_MODEL_TAGS, + workspace=workspace, + key=key, + value=value, + name=name, + ) + + +_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + + +def _insert_model_version_tag( + conn, + *, + workspace: str, + name: str, + version: int, + key: str, + value: str = "value", +): + _insert_table_row( + conn, + _MODEL_VERSION_TAGS, + workspace=workspace, + key=key, + value=value, + name=name, + version=version, + ) + + +_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("workspace"), + sa.column("name"), + sa.column("alias"), + sa.column("version"), +) + + +def _insert_registered_model_alias( + conn, + *, + workspace: str, + name: str, + alias: str, + version: int = 1, +): + _insert_table_row( + conn, + _REGISTERED_MODEL_ALIASES, + workspace=workspace, + name=name, + alias=alias, + version=version, + ) + + +_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), + sa.column("workspace"), +) + + +def _insert_evaluation_dataset( + conn, + *, + dataset_id: str, + workspace: str, + name: str = "Dataset", + digest: str = "digest", +): + conn.execute( + sa.insert(_EVALUATION_DATASETS).values( + dataset_id=dataset_id, + name=name, + schema="schema", + profile="profile", + digest=digest, + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + workspace=workspace, + ) + ) + + +def _fetch_conflicts(conn, table_name: str, columns: tuple[str, ...]): + metadata = sa.MetaData() + table = sa.Table(table_name, metadata, autoload_with=conn) + group_columns = [table.c[column] for column in columns] + stmt = sa.select(*group_columns).group_by(*group_columns).having(sa.func.count() > 1) + return conn.execute(stmt).fetchall() + + +def test_workspace_migration_upgrade_adds_columns_and_backfills(tmp_path): + engine, config = _prepare_database(tmp_path) + try: + with engine.begin() as conn: + _seed_pre_workspace_entities(conn) + + command.upgrade(config, REVISION) + inspector = sa.inspect(engine) + + _assert_workspace_columns(inspector, "default") + + with engine.connect() as conn: + assert conn.execute( + sa.text( + "SELECT experiment_id, name, workspace FROM experiments ORDER BY experiment_id" + ) + ).fetchall() == [(1, "exp-default", "default")] + + assert conn.execute( + sa.text("SELECT run_uuid, experiment_id FROM runs ORDER BY run_uuid") + ).fetchall() == [("run-default", 1)] + + assert conn.execute( + sa.text("SELECT name, workspace FROM registered_models") + ).fetchall() == [("rm-default", "default")] + + assert conn.execute( + sa.text("SELECT name, version, workspace FROM model_versions") + ).fetchall() == [("rm-default", 1, "default")] + + assert conn.execute( + sa.select( + _REGISTERED_MODEL_TAGS.c.workspace, + _REGISTERED_MODEL_TAGS.c.name, + _REGISTERED_MODEL_TAGS.c.key, + ) + ).fetchall() == [("default", "rm-default", "tag")] + + assert conn.execute( + sa.select( + _MODEL_VERSION_TAGS.c.workspace, + _MODEL_VERSION_TAGS.c.name, + _MODEL_VERSION_TAGS.c.version, + _MODEL_VERSION_TAGS.c.key, + ) + ).fetchall() == [("default", "rm-default", 1, "tag")] + + assert conn.execute( + sa.text("SELECT workspace, name, alias FROM registered_model_aliases") + ).fetchall() == [("default", "rm-default", "alias")] + + assert conn.execute( + sa.text("SELECT dataset_id, workspace FROM evaluation_datasets") + ).fetchall() == [("ds-default", "default")] + + assert conn.execute( + sa.text("SELECT name, description FROM workspaces ORDER BY name") + ).fetchall() == [("default", "Default workspace for legacy resources")] + + pk_registered_models = inspector.get_pk_constraint("registered_models") + assert pk_registered_models["constrained_columns"] == ["workspace", "name"] + + pk_model_versions = inspector.get_pk_constraint("model_versions") + assert pk_model_versions["constrained_columns"] == [ + "workspace", + "name", + "version", + ] + + pk_registered_model_tags = inspector.get_pk_constraint("registered_model_tags") + assert pk_registered_model_tags["constrained_columns"] == [ + "workspace", + "key", + "name", + ] + + pk_model_version_tags = inspector.get_pk_constraint("model_version_tags") + assert pk_model_version_tags["constrained_columns"] == [ + "workspace", + "key", + "name", + "version", + ] + + pk_model_aliases = inspector.get_pk_constraint("registered_model_aliases") + assert pk_model_aliases["constrained_columns"] == [ + "workspace", + "name", + "alias", + ] + + try: + unique_experiments = inspector.get_unique_constraints("experiments") + except NotImplementedError: + if inspector.bind.dialect.name == "mssql": + unique_experiments = None + else: + raise + if unique_experiments is not None: + assert any( + {"workspace", "name"} == set(constraint.get("column_names", [])) + for constraint in unique_experiments + ) + + fk_model_versions = inspector.get_foreign_keys("model_versions") + assert any( + fk.get("constrained_columns") == ["workspace", "name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_model_versions + ) + + assert _has_index(inspector, "experiments", "idx_experiments_workspace", ["workspace"]) + assert _has_index( + inspector, + "experiments", + "idx_experiments_workspace_creation_time", + ["workspace", "creation_time"], + ) + assert _has_index( + inspector, "registered_models", "idx_registered_models_workspace", ["workspace"] + ) + assert _has_index( + inspector, "evaluation_datasets", "idx_evaluation_datasets_workspace", ["workspace"] + ) + finally: + engine.dispose() + + +def test_workspace_migration_downgrade_reverts_schema(tmp_path): + engine, config = _prepare_database(tmp_path) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + _insert_experiment(conn, experiment_id=1, name="exp-default", workspace="default") + _insert_run( + conn, + run_uuid="run-default", + experiment_id=1, + name="downgrade-validation-run", + ) + _insert_experiment(conn, experiment_id=2, name="exp-team-a", workspace="team-a") + + command.downgrade(config, PREVIOUS_REVISION) + inspector = sa.inspect(engine) + + tables = inspector.get_table_names() + assert "workspaces" not in tables + + for table in ( + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + ): + column_names = {col["name"] for col in inspector.get_columns(table)} + assert "workspace" not in column_names + + with engine.connect() as conn: + assert conn.execute( + sa.text("SELECT experiment_id, name FROM experiments ORDER BY experiment_id") + ).fetchall() == [(1, "exp-default"), (2, "exp-team-a")] + assert conn.execute( + sa.text("SELECT run_uuid, experiment_id FROM runs ORDER BY run_uuid") + ).fetchall() == [("run-default", 1)] + + pk_registered_models = inspector.get_pk_constraint("registered_models") + assert pk_registered_models["constrained_columns"] == ["name"] + + pk_model_versions = inspector.get_pk_constraint("model_versions") + assert pk_model_versions["constrained_columns"] == ["name", "version"] + + pk_registered_model_tags = inspector.get_pk_constraint("registered_model_tags") + assert pk_registered_model_tags["constrained_columns"] == ["key", "name"] + + pk_model_version_tags = inspector.get_pk_constraint("model_version_tags") + assert pk_model_version_tags["constrained_columns"] == ["key", "name", "version"] + + pk_registered_model_aliases = inspector.get_pk_constraint("registered_model_aliases") + assert pk_registered_model_aliases["constrained_columns"] == ["name", "alias"] + + try: + unique_experiments = inspector.get_unique_constraints("experiments") + except NotImplementedError: + if inspector.bind.dialect.name == "mssql": + unique_experiments = None + else: + raise + if unique_experiments is not None: + assert any( + set(constraint.get("column_names", [])) == {"name"} + for constraint in unique_experiments + ) + + fk_model_versions = inspector.get_foreign_keys("model_versions") + assert any( + fk.get("constrained_columns") == ["name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_model_versions + ) + + fk_registered_model_tags = inspector.get_foreign_keys("registered_model_tags") + assert any( + fk.get("constrained_columns") == ["name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_registered_model_tags + ) + + fk_model_version_tags = inspector.get_foreign_keys("model_version_tags") + assert any( + fk.get("constrained_columns") == ["name", "version"] + and fk.get("referred_table") == "model_versions" + for fk in fk_model_version_tags + ) + finally: + engine.dispose() + + +def _setup_experiment_conflict(conn): + _insert_experiment(conn, experiment_id=1, name="duplicate-exp", workspace="default") + _insert_run(conn, run_uuid="run-exp-default", experiment_id=1) + _insert_experiment(conn, experiment_id=2, name="duplicate-exp", workspace="team-a") + + +def _setup_registered_model_conflict(conn): + _insert_registered_model(conn, name="duplicate-model", workspace="default") + _insert_registered_model(conn, name="duplicate-model", workspace="team-a") + + +def _setup_model_version_conflict(conn): + _insert_registered_model(conn, name="mv-model", workspace="default") + _insert_registered_model(conn, name="mv-model", workspace="team-a") + _insert_model_version(conn, name="mv-model", version=1, workspace="default") + _insert_model_version(conn, name="mv-model", version=1, workspace="team-a") + + +def _setup_registered_model_tag_conflict(conn): + _insert_registered_model(conn, name="tag-model", workspace="default") + _insert_registered_model(conn, name="tag-model", workspace="team-a") + _insert_registered_model_tag(conn, workspace="default", name="tag-model", key="tag-key") + _insert_registered_model_tag(conn, workspace="team-a", name="tag-model", key="tag-key") + + +def _setup_model_version_tag_conflict(conn): + _insert_registered_model(conn, name="mvt-model", workspace="default") + _insert_registered_model(conn, name="mvt-model", workspace="team-a") + _insert_model_version(conn, name="mvt-model", version=1, workspace="default") + _insert_model_version(conn, name="mvt-model", version=1, workspace="team-a") + _insert_model_version_tag( + conn, workspace="default", name="mvt-model", version=1, key="mv-tag-key" + ) + _insert_model_version_tag( + conn, workspace="team-a", name="mvt-model", version=1, key="mv-tag-key" + ) + + +def _setup_registered_model_alias_conflict(conn): + _insert_registered_model(conn, name="alias-model", workspace="default") + _insert_registered_model(conn, name="alias-model", workspace="team-a") + _insert_registered_model_alias(conn, workspace="default", name="alias-model", alias="latest") + _insert_registered_model_alias(conn, workspace="team-a", name="alias-model", alias="latest") + + +def _setup_evaluation_dataset_conflict(conn): + _insert_evaluation_dataset( + conn, dataset_id="ds-default", name="duplicate-ds", workspace="default" + ) + _insert_evaluation_dataset( + conn, dataset_id="ds-team-a", name="duplicate-ds", workspace="team-a" + ) + + +@pytest.mark.parametrize( + ("setup_conflict", "expected_fragment", "case_slug"), + [ + (_setup_experiment_conflict, "duplicate experiments with the same name", "experiments"), + ( + _setup_registered_model_conflict, + "duplicate registered models with the same name", + "models", + ), + ( + _setup_evaluation_dataset_conflict, + "duplicate evaluation datasets with the same name", + "evaluation_datasets", + ), + ], +) +def test_workspace_migration_downgrade_detects_conflicts( + tmp_path, setup_conflict, expected_fragment, case_slug +): + case_dir = tmp_path / f"conflict_{case_slug}" + case_dir.mkdir() + engine, config = _prepare_database(case_dir) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + setup_conflict(conn) + + with pytest.raises( + RuntimeError, + match=re.escape(expected_fragment), + ): + command.downgrade(config, PREVIOUS_REVISION) + finally: + engine.dispose() + + +@pytest.mark.parametrize( + ("setup_conflict", "table_name", "columns", "case_slug"), + [ + ( + _setup_model_version_conflict, + "model_versions", + ("name", "version"), + "model_versions", + ), + ( + _setup_registered_model_tag_conflict, + "registered_model_tags", + ("name", "key"), + "registered_model_tags", + ), + ( + _setup_model_version_tag_conflict, + "model_version_tags", + ("name", "version", "key"), + "model_version_tags", + ), + ( + _setup_registered_model_alias_conflict, + "registered_model_aliases", + ("name", "alias"), + "registered_model_aliases", + ), + ], +) +def test_workspace_migration_conflict_detection_queries( + tmp_path, setup_conflict, table_name, columns, case_slug +): + case_dir = tmp_path / f"conflict_query_{case_slug}" + case_dir.mkdir() + engine, config = _prepare_database(case_dir) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + setup_conflict(conn) + conflicts = _fetch_conflicts(conn, table_name, columns) + assert conflicts, f"Expected conflicts for {table_name}, found none" + finally: + engine.dispose() diff --git a/tests/resources/db/latest_schema.sql b/tests/resources/db/latest_schema.sql index 6903019a42c37..fecf909f2db82 100644 --- a/tests/resources/db/latest_schema.sql +++ b/tests/resources/db/latest_schema.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage IN ('active', 'deleted')) ) @@ -81,8 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name), - UNIQUE (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -96,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -185,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -194,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -203,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -356,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/store/tracking/test_sqlalchemy_store_schema.py b/tests/store/tracking/test_sqlalchemy_store_schema.py index 37f314a6d0610..cfe9039c40f6c 100644 --- a/tests/store/tracking/test_sqlalchemy_store_schema.py +++ b/tests/store/tracking/test_sqlalchemy_store_schema.py @@ -9,6 +9,10 @@ from alembic.script import ScriptDirectory import mlflow.db + +# Import workspace models temporarily for tests to pass. +# This can be removed once we have a workspace store imported. +import mlflow.store.workspace.dbmodels as _workspace_models # noqa: F401 from mlflow.exceptions import MlflowException from mlflow.store.db.base_sql_model import Base from mlflow.store.db.utils import _get_alembic_config, _verify_schema