diff --git a/docs/api_reference/api_inventory.txt b/docs/api_reference/api_inventory.txt index 7f8ca89707d1a..0899539839a54 100644 --- a/docs/api_reference/api_inventory.txt +++ b/docs/api_reference/api_inventory.txt @@ -528,6 +528,7 @@ mlflow.entities.WebhookStatus.to_proto mlflow.entities.WebhookTestResult mlflow.entities.WebhookTestResult.from_proto mlflow.entities.WebhookTestResult.to_proto +mlflow.entities.Workspace mlflow.entities.assessment.Assessment mlflow.entities.assessment.Expectation mlflow.entities.assessment.Feedback @@ -627,6 +628,7 @@ mlflow.entities.webhook.Webhook mlflow.entities.webhook.WebhookEvent mlflow.entities.webhook.WebhookStatus mlflow.entities.webhook.WebhookTestResult +mlflow.entities.workspace.Workspace mlflow.evaluate mlflow.exceptions.get_error_code mlflow.finalize_logged_model diff --git a/mlflow/entities/__init__.py b/mlflow/entities/__init__.py index ea631158757df..5cc03b222265b 100644 --- a/mlflow/entities/__init__.py +++ b/mlflow/entities/__init__.py @@ -62,6 +62,7 @@ WebhookStatus, WebhookTestResult, ) +from mlflow.entities.workspace import Workspace __all__ = [ "Experiment", @@ -125,6 +126,7 @@ "WebhookEvent", "WebhookStatus", "WebhookTestResult", + "Workspace", ] diff --git a/mlflow/entities/workspace.py b/mlflow/entities/workspace.py new file mode 100644 index 0000000000000..cfd74daf09eaf --- /dev/null +++ b/mlflow/entities/workspace.py @@ -0,0 +1,13 @@ +"""Workspace entity shared between server and stores.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class Workspace: + """Minimal metadata describing a workspace.""" + + name: str + description: str | None = None diff --git a/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py b/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py new file mode 100644 index 0000000000000..9f9ae6218b546 --- /dev/null +++ b/mlflow/store/db_migrations/versions/1b5f0d9ad7c1_add_workspace_columns_and_catalog.py @@ -0,0 +1,732 @@ +"""Add workspace columns and catalog table + +Create Date: 2025-11-18 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import inspect + +# revision identifiers, used by Alembic. +revision = "1b5f0d9ad7c1" +down_revision = "bf29a5ff90ea" +branch_labels = None +depends_on = None + +_NAMING_CONVENTION = { + "pk": "pk_%(table_name)s", + "fk": "fk_%(table_name)s_%(referred_table_name)s_%(column_0_name)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", +} + +_WORKSPACE_TABLES = [ + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + "webhooks", +] + +# Older SQLite migrations emitted unnamed foreign keys. When batch-altering tables we need the +# legacy names so we can drop the constraints deterministically; this mapping gives us the +# aliases for those historical definitions. +_SQLITE_LEGACY_FKS = { + ( + "model_versions", + "registered_models", + ("name",), + ): "fk_model_versions_registered_models_name", + ( + "registered_model_tags", + "registered_models", + ("name",), + ): "fk_registered_model_tags_registered_models_name", + ( + "model_version_tags", + "model_versions", + ("name", "version"), + ): "fk_model_version_tags_model_versions_name", + ( + "registered_model_aliases", + "registered_models", + ("name",), + ): "fk_registered_model_aliases_registered_models_name", +} + + +def _workspace_column(): + return sa.Column( + "workspace", + sa.String(length=63), + nullable=False, + server_default=sa.text("'default'"), + ) + + +def _fetch_mssql_unique_metadata( + conn, dialect_name: str, schema: str | None, table_name: str, for_indexes: bool = False +): + """Fetch unique constraint or index metadata for the given MSSQL table. + + SQLAlchemy's inspector doesn't implement ``get_unique_constraints`` for MSSQL, so we query the + catalog tables directly. When ``for_indexes`` is True we return unique indexes (needed when a + unique constraint is materialized as an index); otherwise we return metadata that looks like + ``get_unique_constraints``. + """ + if dialect_name != "mssql": + return [] + + if for_indexes: + query = sa.text( + """ + SELECT + i.name, + STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) AS column_names + FROM sys.indexes i + JOIN sys.tables t ON i.object_id = t.object_id + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.index_columns ic + ON i.object_id = ic.object_id + AND i.index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id + AND ic.column_id = c.column_id + WHERE i.is_unique = 1 + AND i.is_primary_key = 0 + AND ic.is_included_column = 0 + AND t.name = :table_name + AND (:schema IS NULL OR s.name = :schema) + GROUP BY i.name + """ + ) + else: + query = sa.text( + """ + SELECT + kc.name, + STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) AS column_names + FROM sys.key_constraints kc + JOIN sys.tables t ON kc.parent_object_id = t.object_id + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.index_columns ic + ON kc.parent_object_id = ic.object_id + AND kc.unique_index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id + AND ic.column_id = c.column_id + WHERE kc.type = 'UQ' + AND t.name = :table_name + AND (:schema IS NULL OR s.name = :schema) + GROUP BY kc.name + """ + ) + + result = conn.execute(query, {"table_name": table_name, "schema": schema}).fetchall() + return [ + { + "name": row[0], + "column_names": [col.strip() for col in row[1].split(",") if col] if row[1] else [], + } + for row in result + ] + + +def _with_batch(table_name): + return op.batch_alter_table(table_name, recreate="auto", naming_convention=_NAMING_CONVENTION) + + +# Once the tracking and model registry stores support workspaces, we can remove the +# server_default to ensure the stores are properly setting the workspace. +def upgrade(): + conn = op.get_bind() + inspector = inspect(conn) + dialect_name = conn.dialect.name + schema = op.get_context().version_table_schema or inspector.default_schema_name + + def _get_unique_constraints(table_name: str): + try: + return inspector.get_unique_constraints(table_name) + except NotImplementedError: + if dialect_name != "mssql": + raise + # SQL Server's inspector does not implement get_unique_constraints; fall back to + # querying the catalog tables directly via _fetch_mssql_unique_metadata. + return _fetch_mssql_unique_metadata( + conn, + dialect_name, + schema, + table_name, + for_indexes=False, + ) + + def _get_unique_indexes(table_name: str): + try: + return inspector.get_indexes(table_name) + except NotImplementedError: + if dialect_name != "mssql": + raise + + metadata = ( + _fetch_mssql_unique_metadata( + conn, + dialect_name, + schema, + table_name, + for_indexes=True, + ) + or [] + ) + + for entry in metadata: + entry["unique"] = True + return metadata + + def _detect_unique_on_name(table_name: str): + expected_name = _NAMING_CONVENTION["uq"] % { + "table_name": table_name, + "column_0_name": "name", + } + + for constraint in _get_unique_constraints(table_name) or []: + cols = constraint.get("column_names") or [] + name = constraint.get("name") + if cols == ["name"] or name == expected_name: + return name or expected_name, None + + for index in _get_unique_indexes(table_name) or []: + if index.get("unique") and index.get("column_names") == ["name"]: + return None, index["name"] + + return None, None + + def _collect_foreign_keys(table: str, referred_table: str): + names = [] + for fk in inspector.get_foreign_keys(table): + if fk.get("referred_table") == referred_table: + name = fk.get("name") + if not name and dialect_name == "sqlite": + key = (table, referred_table, tuple(fk.get("constrained_columns") or ())) + name = _SQLITE_LEGACY_FKS.get(key) + if not name: + continue + names.append(name) + return names + + def _create_workspace_indexes_and_catalog(): + op.create_index("idx_experiments_workspace", "experiments", ["workspace"]) + op.create_index("idx_registered_models_workspace", "registered_models", ["workspace"]) + op.create_index( + "idx_experiments_workspace_creation_time", + "experiments", + ["workspace", "creation_time"], + unique=False, + ) + op.create_index("idx_evaluation_datasets_workspace", "evaluation_datasets", ["workspace"]) + op.create_index("idx_webhooks_workspace", "webhooks", ["workspace"]) + + op.create_table( + "workspaces", + sa.Column("name", sa.String(length=63), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("name", name="workspaces_pk"), + ) + + metadata = sa.MetaData() + workspaces_table = sa.Table( + "workspaces", + metadata, + sa.Column("name", sa.String(length=63)), + sa.Column("description", sa.Text()), + ) + + conn.execute( + workspaces_table.insert().values( + name="default", + description="Default workspace for legacy resources", + ) + ) + + experiments_unique_constraint, experiments_unique_index = _detect_unique_on_name("experiments") + registered_models_unique_constraint, registered_models_unique_index = _detect_unique_on_name( + "registered_models" + ) + + fk_model_versions = _collect_foreign_keys("model_versions", "registered_models") + fk_registered_model_tags = _collect_foreign_keys("registered_model_tags", "registered_models") + fk_registered_model_aliases = _collect_foreign_keys( + "registered_model_aliases", "registered_models" + ) + fk_model_version_tags = _collect_foreign_keys("model_version_tags", "model_versions") + + if dialect_name == "sqlite": + # Let SQLite handle the foreign key drops inside the batches so each table is recreated + # only once. + with _with_batch("experiments") as batch_op: + if experiments_unique_constraint: + batch_op.drop_constraint(experiments_unique_constraint, type_="unique") + elif experiments_unique_index: + batch_op.drop_index(experiments_unique_index) + batch_op.add_column(_workspace_column()) + batch_op.create_unique_constraint( + "uq_experiments_workspace_name", + ["workspace", "name"], + ) + + with _with_batch("registered_models") as batch_op: + if registered_models_unique_constraint: + batch_op.drop_constraint(registered_models_unique_constraint, type_="unique") + elif registered_models_unique_index: + batch_op.drop_index(registered_models_unique_index) + batch_op.add_column(_workspace_column()) + batch_op.drop_constraint("registered_model_pk", type_="primary") + batch_op.create_primary_key("registered_model_pk", ["workspace", "name"]) + + with _with_batch("model_versions") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_model_versions: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("model_version_pk", type_="primary") + batch_op.create_primary_key("model_version_pk", ["workspace", "name", "version"]) + batch_op.create_foreign_key( + "fk_model_versions_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_tags") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_registered_model_tags: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("registered_model_tag_pk", type_="primary") + batch_op.create_primary_key("registered_model_tag_pk", ["workspace", "key", "name"]) + batch_op.create_foreign_key( + "fk_registered_model_tags_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + with _with_batch("model_version_tags") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_model_version_tags: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("model_version_tag_pk", type_="primary") + batch_op.create_primary_key( + "model_version_tag_pk", + ["workspace", "key", "name", "version"], + ) + batch_op.create_foreign_key( + "fk_model_version_tags_model_versions", + "model_versions", + ["workspace", "name", "version"], + ["workspace", "name", "version"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_aliases") as batch_op: + batch_op.add_column(_workspace_column()) + for fk_name in fk_registered_model_aliases: + batch_op.drop_constraint(fk_name, type_="foreignkey") + batch_op.drop_constraint("registered_model_alias_pk", type_="primary") + batch_op.create_primary_key( + "registered_model_alias_pk", + ["workspace", "name", "alias"], + ) + batch_op.create_foreign_key( + "fk_registered_model_aliases_registered_models", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + with _with_batch("evaluation_datasets") as batch_op: + batch_op.add_column(_workspace_column()) + + with _with_batch("webhooks") as batch_op: + batch_op.add_column(_workspace_column()) + + _create_workspace_indexes_and_catalog() + + return + + # Non-SQLite dialects can issue direct ALTER TABLE statements, which avoids rebuilding the + # tables. This code duplication is worth the performance gain of not rebuilding the tables. + # We could potentially leverage Alembic's batch_alter_table with recreate="auto" to avoid the + # duplication, but parts of this migration caused the tables to be recreated anyways. + + def _drop_fk_constraints(table_name: str, fk_names: list[str]): + for fk_name in fk_names: + if fk_name: + op.drop_constraint(fk_name, table_name=table_name, type_="foreignkey") + + def _drop_unique_on_name(table_name: str, constraint: str | None, index: str | None): + if constraint: + op.drop_constraint(constraint, table_name=table_name, type_="unique") + elif index: + op.drop_index(index, table_name=table_name) + + _drop_unique_on_name("experiments", experiments_unique_constraint, experiments_unique_index) + op.add_column( + "experiments", + _workspace_column(), + ) + op.create_unique_constraint( + "uq_experiments_workspace_name", + "experiments", + ["workspace", "name"], + ) + + _drop_fk_constraints("model_versions", fk_model_versions) + _drop_fk_constraints("registered_model_tags", fk_registered_model_tags) + _drop_fk_constraints("registered_model_aliases", fk_registered_model_aliases) + _drop_fk_constraints("model_version_tags", fk_model_version_tags) + + _drop_unique_on_name( + "registered_models", + registered_models_unique_constraint, + registered_models_unique_index, + ) + op.add_column( + "registered_models", + _workspace_column(), + ) + op.drop_constraint("registered_model_pk", "registered_models", type_="primary") + op.create_primary_key("registered_model_pk", "registered_models", ["workspace", "name"]) + + op.add_column( + "model_versions", + _workspace_column(), + ) + op.drop_constraint("model_version_pk", "model_versions", type_="primary") + op.create_primary_key("model_version_pk", "model_versions", ["workspace", "name", "version"]) + op.create_foreign_key( + "fk_model_versions_registered_models", + "model_versions", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + op.add_column( + "registered_model_tags", + _workspace_column(), + ) + op.drop_constraint("registered_model_tag_pk", "registered_model_tags", type_="primary") + op.create_primary_key( + "registered_model_tag_pk", + "registered_model_tags", + ["workspace", "key", "name"], + ) + op.create_foreign_key( + "fk_registered_model_tags_registered_models", + "registered_model_tags", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ) + + op.add_column( + "model_version_tags", + _workspace_column(), + ) + op.drop_constraint("model_version_tag_pk", "model_version_tags", type_="primary") + op.create_primary_key( + "model_version_tag_pk", + "model_version_tags", + ["workspace", "key", "name", "version"], + ) + op.create_foreign_key( + "fk_model_version_tags_model_versions", + "model_version_tags", + "model_versions", + ["workspace", "name", "version"], + ["workspace", "name", "version"], + onupdate="CASCADE", + ) + + op.add_column( + "registered_model_aliases", + _workspace_column(), + ) + op.drop_constraint("registered_model_alias_pk", "registered_model_aliases", type_="primary") + op.create_primary_key( + "registered_model_alias_pk", + "registered_model_aliases", + ["workspace", "name", "alias"], + ) + op.create_foreign_key( + "fk_registered_model_aliases_registered_models", + "registered_model_aliases", + "registered_models", + ["workspace", "name"], + ["workspace", "name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + for table in ["evaluation_datasets", "webhooks"]: + op.add_column( + table, + _workspace_column(), + ) + + _create_workspace_indexes_and_catalog() + + +def downgrade(): + conn = op.get_bind() + dialect_name = conn.dialect.name + + def _assert_no_workspace_conflicts( + table_name: str, + columns: tuple[str, ...], + resource_description: str, + ): + table = sa.Table(table_name, sa.MetaData(), autoload_with=conn) + group_columns = [table.c[column] for column in columns] + stmt = sa.select(*group_columns).group_by(*group_columns).having(sa.func.count() > 1) + conflicts = conn.execute(stmt).fetchall() + if conflicts: + formatted_conflicts = ", ".join( + "; ".join(f"{column}={value!r}" for column, value in zip(columns, row)) + for row in conflicts[:5] + ) + if len(conflicts) > 5: + formatted_conflicts += ", ..." + raise RuntimeError( + "Downgrade aborted: merging workspaces would create duplicate " + f"{resource_description}. Resolve the following conflicts by deleting or renaming " + f"the affected resources and retry: {formatted_conflicts}" + ) + + def _move_resources_to_default_workspace(table_name: str): + table = sa.Table(table_name, sa.MetaData(), autoload_with=conn) + conn.execute( + table.update().where(table.c.workspace != "default").values(workspace="default") + ) + + conflict_specs = [ + ("experiments", ("name",), "experiments with the same name"), + ("registered_models", ("name",), "registered models with the same name"), + ( + "evaluation_datasets", + ("name",), + "evaluation datasets with the same name", + ), + ( + "model_versions", + ("name", "version"), + "model versions with the same model name and version", + ), + ( + "registered_model_tags", + ("name", "key"), + "registered model tags with the same model name and key", + ), + ( + "model_version_tags", + ("name", "version", "key"), + "model version tags with the same model name, version, and key", + ), + ( + "registered_model_aliases", + ("name", "alias"), + "registered model aliases with the same model name and alias", + ), + ] + + for table_name, columns, description in conflict_specs: + _assert_no_workspace_conflicts(table_name, columns, description) + + for table in _WORKSPACE_TABLES: + _move_resources_to_default_workspace(table) + + drop_index_kwargs = {} if dialect_name == "mysql" else {"if_exists": True} + op.drop_index( + "idx_experiments_workspace_creation_time", table_name="experiments", **drop_index_kwargs + ) + op.drop_index( + "idx_registered_models_workspace", table_name="registered_models", **drop_index_kwargs + ) + op.drop_index("idx_experiments_workspace", table_name="experiments", **drop_index_kwargs) + op.drop_index( + "idx_evaluation_datasets_workspace", table_name="evaluation_datasets", **drop_index_kwargs + ) + op.drop_index("idx_webhooks_workspace", table_name="webhooks", **drop_index_kwargs) + + if dialect_name == "sqlite": + with _with_batch("model_version_tags") as batch_op: + batch_op.drop_constraint("fk_model_version_tags_model_versions", type_="foreignkey") + batch_op.drop_constraint("model_version_tag_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("model_version_tag_pk", ["key", "name", "version"]) + batch_op.create_foreign_key( + "model_version_tags_mv_fkey", + "model_versions", + ["name", "version"], + ["name", "version"], + onupdate="CASCADE", + ) + + with _with_batch("registered_model_aliases") as batch_op: + batch_op.drop_constraint( + "fk_registered_model_aliases_registered_models", type_="foreignkey" + ) + batch_op.drop_constraint("registered_model_alias_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_alias_pk", ["name", "alias"]) + batch_op.create_foreign_key( + "registered_model_alias_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + + with _with_batch("registered_model_tags") as batch_op: + batch_op.drop_constraint( + "fk_registered_model_tags_registered_models", type_="foreignkey" + ) + batch_op.drop_constraint("registered_model_tag_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_tag_pk", ["key", "name"]) + batch_op.create_foreign_key( + "registered_model_tags_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + + with _with_batch("model_versions") as batch_op: + batch_op.drop_constraint("fk_model_versions_registered_models", type_="foreignkey") + batch_op.drop_constraint("model_version_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("model_version_pk", ["name", "version"]) + batch_op.create_foreign_key( + "model_versions_name_fkey", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + + with _with_batch("registered_models") as batch_op: + batch_op.drop_constraint("registered_model_pk", type_="primary") + batch_op.drop_column("workspace") + batch_op.create_primary_key("registered_model_pk", ["name"]) + + with _with_batch("experiments") as batch_op: + batch_op.drop_constraint("uq_experiments_workspace_name", type_="unique") + batch_op.drop_column("workspace") + batch_op.create_unique_constraint("uq_experiments_name", ["name"]) + + with _with_batch("evaluation_datasets") as batch_op: + batch_op.drop_column("workspace") + + with _with_batch("webhooks") as batch_op: + batch_op.drop_column("workspace") + + op.drop_table("workspaces") + return + + op.drop_constraint( + "fk_model_version_tags_model_versions", "model_version_tags", type_="foreignkey" + ) + op.drop_constraint( + "fk_registered_model_aliases_registered_models", + "registered_model_aliases", + type_="foreignkey", + ) + op.drop_constraint( + "fk_registered_model_tags_registered_models", "registered_model_tags", type_="foreignkey" + ) + op.drop_constraint("fk_model_versions_registered_models", "model_versions", type_="foreignkey") + + op.drop_constraint("uq_experiments_workspace_name", table_name="experiments", type_="unique") + + op.drop_constraint("registered_model_alias_pk", "registered_model_aliases", type_="primary") + op.drop_constraint("model_version_tag_pk", "model_version_tags", type_="primary") + op.drop_constraint("registered_model_tag_pk", "registered_model_tags", type_="primary") + op.drop_constraint("model_version_pk", "model_versions", type_="primary") + op.drop_constraint("registered_model_pk", "registered_models", type_="primary") + + if dialect_name == "mssql": + # SQL Server binds defaults via named constraints. If we try to drop the column while a + # default is attached, the prior downgrade operations can leave behind those constraints, + # causing drop_column to fail. Clear the defaults explicitly first. + for table in _WORKSPACE_TABLES: + op.alter_column( + table_name=table, + column_name="workspace", + existing_type=sa.String(length=63), + existing_nullable=False, + server_default=None, + ) + + op.drop_column("model_version_tags", "workspace") + op.drop_column("registered_model_aliases", "workspace") + op.drop_column("registered_model_tags", "workspace") + op.drop_column("model_versions", "workspace") + op.drop_column("registered_models", "workspace") + op.drop_column("experiments", "workspace") + op.drop_column("evaluation_datasets", "workspace") + op.drop_column("webhooks", "workspace") + + op.create_primary_key("registered_model_pk", "registered_models", ["name"]) + op.create_primary_key("model_version_pk", "model_versions", ["name", "version"]) + op.create_primary_key("registered_model_tag_pk", "registered_model_tags", ["key", "name"]) + op.create_primary_key("model_version_tag_pk", "model_version_tags", ["key", "name", "version"]) + op.create_primary_key( + "registered_model_alias_pk", "registered_model_aliases", ["name", "alias"] + ) + + op.create_foreign_key( + "model_versions_name_fkey", + "model_versions", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + op.create_foreign_key( + "registered_model_tags_name_fkey", + "registered_model_tags", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ) + op.create_foreign_key( + "registered_model_alias_name_fkey", + "registered_model_aliases", + "registered_models", + ["name"], + ["name"], + onupdate="CASCADE", + ondelete="CASCADE", + ) + op.create_foreign_key( + "model_version_tags_mv_fkey", + "model_version_tags", + "model_versions", + ["name", "version"], + ["name", "version"], + onupdate="CASCADE", + ) + + op.drop_table("workspaces") + + op.create_unique_constraint("uq_experiments_name", "experiments", ["name"]) diff --git a/mlflow/store/model_registry/dbmodels/models.py b/mlflow/store/model_registry/dbmodels/models.py index 2075c50162e99..36d963e65e47f 100644 --- a/mlflow/store/model_registry/dbmodels/models.py +++ b/mlflow/store/model_registry/dbmodels/models.py @@ -1,3 +1,4 @@ +import sqlalchemy as sa from cryptography.fernet import Fernet from sqlalchemy import ( BigInteger, @@ -32,12 +33,20 @@ from mlflow.environment_variables import MLFLOW_WEBHOOK_SECRET_ENCRYPTION_KEY from mlflow.store.db.base_sql_model import Base from mlflow.utils.time import get_current_time_millis +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME class SqlRegisteredModel(Base): __tablename__ = "registered_models" - name = Column(String(256), unique=True, nullable=False) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) creation_time = Column(BigInteger, default=get_current_time_millis) @@ -45,7 +54,7 @@ class SqlRegisteredModel(Base): description = Column(String(5000), nullable=True) - __table_args__ = (PrimaryKeyConstraint("name", name="registered_model_pk"),) + __table_args__ = (PrimaryKeyConstraint("workspace", "name", name="registered_model_pk"),) def __repr__(self): return ( @@ -76,7 +85,14 @@ def to_mlflow_entity(self): class SqlModelVersion(Base): __tablename__ = "model_versions" - name = Column(String(256), ForeignKey("registered_models.name", onupdate="cascade")) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) version = Column(Integer, nullable=False) @@ -107,7 +123,14 @@ class SqlModelVersion(Base): "SqlRegisteredModel", backref=backref("model_versions", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("name", "version", name="model_version_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ), + PrimaryKeyConstraint("workspace", "name", "version", name="model_version_pk"), + ) # entity mappers def to_mlflow_entity(self): @@ -132,7 +155,14 @@ def to_mlflow_entity(self): class SqlRegisteredModelTag(Base): __tablename__ = "registered_model_tags" - name = Column(String(256), ForeignKey("registered_models.name", onupdate="cascade")) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) key = Column(String(250), nullable=False) @@ -143,7 +173,14 @@ class SqlRegisteredModelTag(Base): "SqlRegisteredModel", backref=backref("registered_model_tags", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("key", "name", name="registered_model_tag_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ), + PrimaryKeyConstraint("workspace", "key", "name", name="registered_model_tag_pk"), + ) def __repr__(self): return f"" @@ -156,7 +193,14 @@ def to_mlflow_entity(self): class SqlModelVersionTag(Base): __tablename__ = "model_version_tags" - name = Column(String(256)) + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + + name = Column(String(256), nullable=False) version = Column(Integer) @@ -164,20 +208,18 @@ class SqlModelVersionTag(Base): value = Column(Text, nullable=True) - # linked entities - model_version = relationship( - "SqlModelVersion", - foreign_keys=[name, version], - backref=backref("model_version_tags", cascade="all"), - ) - __table_args__ = ( - PrimaryKeyConstraint("key", "name", "version", name="model_version_tag_pk"), ForeignKeyConstraint( - ("name", "version"), - ("model_versions.name", "model_versions.version"), + ["workspace", "name", "version"], + ["model_versions.workspace", "model_versions.name", "model_versions.version"], onupdate="cascade", ), + PrimaryKeyConstraint("workspace", "key", "name", "version", name="model_version_tag_pk"), + ) + + # linked entities + model_version = relationship( + "SqlModelVersion", backref=backref("model_version_tags", cascade="all") ) def __repr__(self): @@ -190,15 +232,15 @@ def to_mlflow_entity(self): class SqlRegisteredModelAlias(Base): __tablename__ = "registered_model_aliases" - name = Column( - String(256), - ForeignKey( - "registered_models.name", - onupdate="cascade", - ondelete="cascade", - name="registered_model_alias_name_fkey", - ), + + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), ) + + name = Column(String(256), nullable=False) alias = Column(String(256), nullable=False) version = Column(Integer, nullable=False) @@ -207,7 +249,16 @@ class SqlRegisteredModelAlias(Base): "SqlRegisteredModel", backref=backref("registered_model_aliases", cascade="all") ) - __table_args__ = (PrimaryKeyConstraint("name", "alias", name="registered_model_alias_pk"),) + __table_args__ = ( + ForeignKeyConstraint( + ["workspace", "name"], + ["registered_models.workspace", "registered_models.name"], + onupdate="cascade", + ondelete="cascade", + name="registered_model_alias_registered_model_fkey", + ), + PrimaryKeyConstraint("workspace", "name", "alias", name="registered_model_alias_pk"), + ) def __repr__(self): return f"" @@ -247,6 +298,12 @@ def process_result_value(self, value, dialect): class SqlWebhook(Base): __tablename__ = "webhooks" + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) webhook_id = Column(String(256), nullable=False) name = Column(String(256), nullable=False) description = Column(String(1000), nullable=True) @@ -261,6 +318,7 @@ class SqlWebhook(Base): PrimaryKeyConstraint("webhook_id", name="webhook_pk"), Index("idx_webhooks_status", "status"), Index("idx_webhooks_name", "name"), + Index("idx_webhooks_workspace", "workspace"), ) def __repr__(self): diff --git a/mlflow/store/model_registry/sqlalchemy_store.py b/mlflow/store/model_registry/sqlalchemy_store.py index 499fd713c3a23..f63fc3a37128e 100644 --- a/mlflow/store/model_registry/sqlalchemy_store.py +++ b/mlflow/store/model_registry/sqlalchemy_store.py @@ -61,6 +61,7 @@ _validate_webhook_name, _validate_webhook_url, ) +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME _logger = logging.getLogger(__name__) @@ -194,12 +195,14 @@ def create_registered_model(self, name, tags=None, description=None, deployment_ creation_time=creation_time, last_updated_time=creation_time, description=description, + workspace=DEFAULT_WORKSPACE_NAME, ) tags_dict = {} for tag in tags or []: tags_dict[tag.key] = tag.value registered_model.registered_model_tags = [ - SqlRegisteredModelTag(key=key, value=value) for key, value in tags_dict.items() + SqlRegisteredModelTag(workspace=DEFAULT_WORKSPACE_NAME, key=key, value=value) + for key, value in tags_dict.items() ] session.add(registered_model) session.flush() @@ -696,7 +699,11 @@ def set_registered_model_tag(self, name, tag): with self.ManagedSessionMaker() as session: # check if registered model exists self._get_registered_model(session, name) - session.merge(SqlRegisteredModelTag(name=name, key=tag.key, value=tag.value)) + session.merge( + SqlRegisteredModelTag( + workspace=DEFAULT_WORKSPACE_NAME, name=name, key=tag.key, value=tag.value + ) + ) def delete_registered_model_tag(self, name, key): """ @@ -803,12 +810,14 @@ def next_version(sql_registered_model): run_id=run_id, run_link=run_link, description=description, + workspace=DEFAULT_WORKSPACE_NAME, ) tags_dict = {} for tag in tags or []: tags_dict[tag.key] = tag.value model_version.model_version_tags = [ - SqlModelVersionTag(key=key, value=value) for key, value in tags_dict.items() + SqlModelVersionTag(workspace=DEFAULT_WORKSPACE_NAME, key=key, value=value) + for key, value in tags_dict.items() ] session.add_all([sql_registered_model, model_version]) session.flush() @@ -1193,7 +1202,13 @@ def set_model_version_tag(self, name, version, tag): # check if model version exists self._get_sql_model_version(session, name, version) session.merge( - SqlModelVersionTag(name=name, version=version, key=tag.key, value=tag.value) + SqlModelVersionTag( + workspace=DEFAULT_WORKSPACE_NAME, + name=name, + version=version, + key=tag.key, + value=tag.value, + ) ) def delete_model_version_tag(self, name, version, key): @@ -1248,7 +1263,11 @@ def set_registered_model_alias(self, name, alias, version): with self.ManagedSessionMaker() as session: # check if model version exists self._get_sql_model_version(session, name, version) - session.merge(SqlRegisteredModelAlias(name=name, alias=alias, version=version)) + session.merge( + SqlRegisteredModelAlias( + workspace=DEFAULT_WORKSPACE_NAME, name=name, alias=alias, version=version + ) + ) def delete_registered_model_alias(self, name, alias): """ diff --git a/mlflow/store/tracking/dbmodels/models.py b/mlflow/store/tracking/dbmodels/models.py index 75d84b60cd1b0..0d38bdb02c01e 100644 --- a/mlflow/store/tracking/dbmodels/models.py +++ b/mlflow/store/tracking/dbmodels/models.py @@ -61,6 +61,7 @@ from mlflow.tracing.utils import generate_assessment_id from mlflow.utils.mlflow_tags import MLFLOW_USER, _get_run_name_from_tags from mlflow.utils.time import get_current_time_millis +from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME SourceTypes = [ SourceType.to_string(SourceType.NOTEBOOK), @@ -94,10 +95,19 @@ class SqlExperiment(Base): """ Experiment ID: `Integer`. *Primary Key* for ``experiment`` table. """ - name = Column(String(256), unique=True, nullable=False) + name = Column(String(256), nullable=False) """ - Experiment name: `String` (limit 256 characters). Defined as *Unique* and *Non null* in - table schema. + Experiment name: `String` (limit 256 characters). Unique per workspace (see + ``uq_experiments_workspace_name``) and *Non null* in the table schema. + """ + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + """ + Workspace identifier for this experiment. Defaults to ``'default'`` for legacy rows. """ artifact_location = Column(String(256), nullable=True) """ @@ -124,6 +134,7 @@ class SqlExperiment(Base): name="experiments_lifecycle_stage", ), PrimaryKeyConstraint("experiment_id", name="experiment_pk"), + UniqueConstraint("workspace", "name", name="uq_experiments_workspace_name"), ) def __repr__(self): @@ -1301,6 +1312,16 @@ class SqlEvaluationDataset(Base): *Primary Key* for ``evaluation_datasets`` table. """ + workspace = Column( + String(63), + nullable=False, + default=DEFAULT_WORKSPACE_NAME, + server_default=sa.text(f"'{DEFAULT_WORKSPACE_NAME}'"), + ) + """ + Workspace name that scopes this dataset. Defaults to ``'default'`` for legacy rows. + """ + name = Column(String(255), nullable=False) """ Dataset name: `String` (limit 255 characters). *Non null* in table schema. @@ -1355,6 +1376,7 @@ class SqlEvaluationDataset(Base): PrimaryKeyConstraint("dataset_id", name="evaluation_datasets_pk"), Index("index_evaluation_datasets_name", "name"), Index("index_evaluation_datasets_created_time", "created_time"), + Index("idx_evaluation_datasets_workspace", "workspace"), ) def to_mlflow_entity(self): diff --git a/mlflow/store/workspace/__init__.py b/mlflow/store/workspace/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/mlflow/store/workspace/dbmodels/__init__.py b/mlflow/store/workspace/dbmodels/__init__.py new file mode 100644 index 0000000000000..ee895f8ba3782 --- /dev/null +++ b/mlflow/store/workspace/dbmodels/__init__.py @@ -0,0 +1,3 @@ +from mlflow.store.workspace.dbmodels.models import SqlWorkspace + +__all__ = ["SqlWorkspace"] diff --git a/mlflow/store/workspace/dbmodels/models.py b/mlflow/store/workspace/dbmodels/models.py new file mode 100644 index 0000000000000..354d0fccb8b98 --- /dev/null +++ b/mlflow/store/workspace/dbmodels/models.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sqlalchemy as sa +from sqlalchemy import Column, String, Text + +from mlflow.entities.workspace import Workspace +from mlflow.store.db.base_sql_model import Base + + +class SqlWorkspace(Base): + __tablename__ = "workspaces" + + name = Column(String(63), nullable=False) + description = Column(Text, nullable=True) + + __table_args__ = (sa.PrimaryKeyConstraint("name", name="workspaces_pk"),) + + def __repr__(self) -> str: # pragma: no cover + return f"" + + def to_mlflow_entity(self) -> Workspace: + return Workspace(name=self.name, description=self.description) diff --git a/mlflow/utils/workspace_utils.py b/mlflow/utils/workspace_utils.py new file mode 100644 index 0000000000000..3d6a81d00cca6 --- /dev/null +++ b/mlflow/utils/workspace_utils.py @@ -0,0 +1,5 @@ +"""Utility helpers for workspace-aware database schema defaults.""" + +DEFAULT_WORKSPACE_NAME = "default" + +__all__ = ["DEFAULT_WORKSPACE_NAME"] diff --git a/tests/db/check_migration.py b/tests/db/check_migration.py index 271e8f6c473f5..36a378dcbef57 100644 --- a/tests/db/check_migration.py +++ b/tests/db/check_migration.py @@ -49,6 +49,16 @@ SqlModelVersionTag.__tablename__, ] SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" +WORKSPACE_TABLES = { + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + "webhooks", +} class Model(mlflow.pyfunc.PythonModel): @@ -72,7 +82,7 @@ def log_everything(): client.create_registered_model( registered_model_name, tags={"tag": "registered_model"}, description="description" ) - client.create_model_version( + model_version = client.create_model_version( registered_model_name, model_info.model_uri, run_id=run.info.run_id, @@ -80,6 +90,41 @@ def log_everything(): run_link="run_link", description="description", ) + client.set_registered_model_alias( + name=registered_model_name, + alias="prod", + version=model_version.version, + ) + # Create an additional experiment/model to ensure workspace backfills cover multiple resources. + mlflow.create_experiment(uuid.uuid4().hex) + client.create_registered_model(uuid.uuid4().hex) + client.create_webhook( + name=f"migration-webhook-{uuid.uuid4().hex}", + url="https://example.com/hook", + events=["model_version.created"], + description="workspace-migration-check", + ) + engine = sa.create_engine(os.environ["MLFLOW_TRACKING_URI"]) + metadata = sa.MetaData() + evaluation_datasets_table = sa.Table( + "evaluation_datasets", + metadata, + autoload_with=engine, + ) + with engine.begin() as conn: + conn.execute( + sa.insert(evaluation_datasets_table).values( + dataset_id=uuid.uuid4().hex, + name="workspace-migration-dataset", + schema="{}", + profile="{}", + digest=uuid.uuid4().hex, + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + ) + ) def connect_to_mlflow_db(): @@ -113,6 +158,10 @@ def post_migration(): df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn) df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl") pd.testing.assert_frame_equal(df_actual[df_expected.columns], df_expected) + for table in WORKSPACE_TABLES: + df = pd.read_sql(sa.text(f"SELECT DISTINCT workspace FROM {table}"), conn) + assert not df["workspace"].isna().any(), f"{table} contains NULL workspace values" + assert set(df["workspace"]) == {"default"}, f"{table} contains non-default workspaces" if __name__ == "__main__": diff --git a/tests/db/schemas/mssql.sql b/tests/db/schemas/mssql.sql index 966b5eca98a61..6e0746da4f3a9 100644 --- a/tests/db/schemas/mssql.sql +++ b/tests/db/schemas/mssql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255) COLLATE "SQL_Latin1_General_CP1_CI_AS", last_updated_by VARCHAR(255) COLLATE "SQL_Latin1_General_CP1_CI_AS", + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,7 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32) COLLATE "SQL_Latin1_General_CP1_CI_AS", creation_time BIGINT, last_update_time BIGINT, - CONSTRAINT experiment_pk PRIMARY KEY (experiment_id) + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name) ) @@ -79,7 +82,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000) COLLATE "SQL_Latin1_General_CP1_CI_AS", - CONSTRAINT registered_model_pk PRIMARY KEY (name) + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -93,10 +97,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, + description VARCHAR COLLATE "SQL_Latin1_General_CP1_CI_AS", + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, experiment_id INTEGER NOT NULL, @@ -180,8 +192,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", run_link VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", storage_location VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS", - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - CONSTRAINT "FK__model_vers__name__6B24EA82" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -189,8 +202,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -198,8 +212,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, value VARCHAR(5000) COLLATE "SQL_Latin1_General_CP1_CI_AS", name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - CONSTRAINT "FK__registered__name__6EF57B66" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -346,8 +361,9 @@ CREATE TABLE model_version_tags ( value VARCHAR COLLATE "SQL_Latin1_General_CP1_CI_AS", name VARCHAR(256) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - CONSTRAINT "FK__model_version_ta__71D1E811" FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) COLLATE "SQL_Latin1_General_CP1_CI_AS" DEFAULT ('default') NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/mysql.sql b/tests/db/schemas/mysql.sql index 413f1f3f855d6..8d92510577a20 100644 --- a/tests/db/schemas/mysql.sql +++ b/tests/db/schemas/mysql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (dataset_id) ) @@ -37,8 +38,10 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (experiment_id), - CONSTRAINT experiments_lifecycle_stage CHECK ((`lifecycle_stage` in (_utf8mb4'active',_utf8mb4'deleted'))) + CONSTRAINT experiments_lifecycle_stage CHECK ((`lifecycle_stage` in (_utf8mb4'active',_utf8mb4'deleted'))), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name) ) @@ -80,7 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - PRIMARY KEY (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name) ) @@ -94,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -182,8 +194,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - PRIMARY KEY (name, version), - CONSTRAINT model_versions_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -191,8 +204,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -200,8 +214,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - PRIMARY KEY (key, name), - CONSTRAINT registered_model_tags_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -354,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - PRIMARY KEY (key, name, version), - CONSTRAINT model_version_tags_ibfk_1 FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/postgresql.sql b/tests/db/schemas/postgresql.sql index ffeb4527e26b5..9081fd88a546c 100644 --- a/tests/db/schemas/postgresql.sql +++ b/tests/db/schemas/postgresql.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - CONSTRAINT experiments_name_key UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage::text = ANY (ARRAY['active'::character varying, 'deleted'::character varying]::text[])) ) @@ -81,7 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name) + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -95,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -184,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - CONSTRAINT model_versions_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -193,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -202,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - CONSTRAINT registered_model_tags_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -353,8 +367,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - CONSTRAINT model_version_tags_name_version_fkey FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default'::character varying NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/schemas/sqlite.sql b/tests/db/schemas/sqlite.sql index c7b87e3461731..721286a886e81 100644 --- a/tests/db/schemas/sqlite.sql +++ b/tests/db/schemas/sqlite.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage IN ('active', 'deleted')) ) @@ -81,8 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name), - UNIQUE (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -96,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -185,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -194,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -203,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -356,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/db/test_schema.py b/tests/db/test_schema.py index 15758546d6caa..1bcec568d7ada 100644 --- a/tests/db/test_schema.py +++ b/tests/db/test_schema.py @@ -1,11 +1,19 @@ import difflib +import logging import re from pathlib import Path from typing import NamedTuple import pytest -from sqlalchemy import create_engine -from sqlalchemy.schema import CreateTable, MetaData +from sqlalchemy import create_engine, inspect +from sqlalchemy.schema import CreateTable, MetaData, UniqueConstraint + +_logger = logging.getLogger(__name__) + +_DIALECT_REFLECTED_UNIQUE_CONSTRAINTS = { + "mysql": {"uq_experiments_workspace_name"}, + "mssql": {"uq_experiments_workspace_name"}, +} import mlflow from mlflow.environment_variables import MLFLOW_TRACKING_URI @@ -25,6 +33,7 @@ def dump_schema(db_uri): engine = create_engine(db_uri) created_tables_metadata = MetaData() created_tables_metadata.reflect(bind=engine) + _reattach_missing_unique_constraints(engine, created_tables_metadata) # Write out table schema as described in # https://docs.sqlalchemy.org/en/13/faq/metadata_schema.html#how-can-i-get-the-create-table-drop-table-output-as-a-string lines = [] @@ -34,6 +43,70 @@ def dump_schema(db_uri): return "\n".join(lines) +def _reattach_missing_unique_constraints(engine, metadata): + constraint_names = _DIALECT_REFLECTED_UNIQUE_CONSTRAINTS.get(engine.dialect.name) + if not constraint_names: + return + inspector = inspect(engine) + for table in metadata.sorted_tables: + existing_unique_columns = { + tuple(constraint.columns.keys()) + for constraint in table.constraints + if isinstance(constraint, UniqueConstraint) + } + # Not all dialects reflect `UniqueConstraint` objects the same way. MySQL reports + # them as indexes; MSSQL doesn't implement `get_unique_constraints` at all. We + # normalize the reflection results via `_get_unique_constraints` so the same code + # path can reattach missing `UniqueConstraint`s across dialects. + for unique in _get_unique_constraints(inspector, engine.dialect.name, table.name): + name = unique.get("name") + columns = tuple(unique.get("column_names") or ()) + duplicates_index = unique.get("duplicates_index") + if not columns or name not in constraint_names: + continue + if engine.dialect.name == "mysql" and not duplicates_index: + # MySQL exposes unique constraints as unique indexes. SQLAlchemy treats those as + # indexes during reflection, so the reflected metadata lacks the original + # `UniqueConstraint`. Only recreate constraints that are backed by an actual + # unique index reported via `duplicates_index`. + continue + if columns in existing_unique_columns: + continue + missing_columns = tuple(column for column in columns if column not in table.c) + if missing_columns: + _logger.warning( + "Skipping recreation of unique constraint '%s' on table '%s' due to " + "missing columns: %s", + name, + table.name, + ", ".join(missing_columns), + ) + continue + constraint = UniqueConstraint(*[table.c[column] for column in columns], name=name) + table.append_constraint(constraint) + existing_unique_columns.add(columns) + + +def _get_unique_constraints(inspector, dialect, table_name): + try: + unique_constraints = inspector.get_unique_constraints(table_name) + except NotImplementedError: + unique_constraints = None + if unique_constraints is None: + unique_constraints = [] + if not unique_constraints: + unique_constraints = [ + { + "name": index.get("name"), + "column_names": index.get("column_names"), + "duplicates_index": index.get("unique"), + } + for index in inspector.get_indexes(table_name) + if index.get("unique") + ] + return unique_constraints + + class _CreateTable(NamedTuple): table: str columns: str diff --git a/tests/db/test_workspace_migration.py b/tests/db/test_workspace_migration.py new file mode 100644 index 0000000000000..30272de2ae5c5 --- /dev/null +++ b/tests/db/test_workspace_migration.py @@ -0,0 +1,1073 @@ +import os +import re +from contextlib import contextmanager + +import pytest +import sqlalchemy as sa +from alembic import command + +from mlflow.store.db.utils import _get_alembic_config +from mlflow.store.tracking.dbmodels.initial_models import Base as InitialBase + +_LEGACY_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + +_LEGACY_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + +_LEGACY_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("alias"), + sa.column("version"), + sa.column("name"), +) + +_LEGACY_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), +) + +_WORKSPACE_TABLES = ( + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", +) + +_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + +_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + +_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("workspace"), + sa.column("name"), + sa.column("alias"), + sa.column("version"), +) + +_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), + sa.column("workspace"), +) + +REVISION = "1b5f0d9ad7c1" +PREVIOUS_REVISION = "bf29a5ff90ea" + +DB_URI = os.environ.get("MLFLOW_TRACKING_URI") +USE_EXTERNAL_DB = DB_URI is not None and not DB_URI.startswith("sqlite") + + +@pytest.fixture(scope="session", autouse=True) +def _upgrade_external_db_to_head_after_suite(): + """ + When running under Docker (i.e., with a shared external DB), make sure the DB ends up on the + latest revision once this module finishes. Individual tests intentionally downgrade to the + pre-workspace schema, so without this hook, the subsequent suites in the database workflow would + run against an outdated schema after new migrations land. + """ + yield + if USE_EXTERNAL_DB: + config = _get_alembic_config(DB_URI) + command.upgrade(config, "head") + + +@contextmanager +def _identity_insert(conn, table_name: str): + if conn.dialect.name != "mssql": + yield + return + + conn.execute(sa.text(f"SET IDENTITY_INSERT {table_name} ON")) + try: + yield + finally: + conn.execute(sa.text(f"SET IDENTITY_INSERT {table_name} OFF")) + + +def _insert_table_row(conn, table, **values): + conn.execute(sa.insert(table).values(**values)) + + +def _assert_workspace_column(inspector, table_name: str, expected_default: str): + columns = inspector.get_columns(table_name) + workspace = next((col for col in columns if col["name"] == "workspace"), None) + assert workspace is not None, f"{table_name} lacks workspace column" + assert not workspace.get("nullable", False) + default_value = _get_workspace_default(workspace) + assert default_value == expected_default + + +def _assert_workspace_columns(inspector, expected_default: str = "default"): + for table in _WORKSPACE_TABLES: + _assert_workspace_column(inspector, table, expected_default) + + +def _has_index(inspector, table: str, index_name: str, columns: list[str]): + indexes = inspector.get_indexes(table) + return any( + index["name"] == index_name and index.get("column_names") == columns for index in indexes + ) + + +def _prepare_database(tmp_path): + if USE_EXTERNAL_DB: + engine = sa.create_engine(DB_URI) + with engine.begin() as conn: + metadata = sa.MetaData() + metadata.reflect(bind=conn) + metadata.drop_all(bind=conn) + InitialBase.metadata.create_all(conn) + config = _get_alembic_config(DB_URI) + else: + db_path = tmp_path / "workspace_migration.sqlite" + url = f"sqlite:///{db_path}" + engine = sa.create_engine(url) + InitialBase.metadata.create_all(engine) + config = _get_alembic_config(url) + command.upgrade(config, PREVIOUS_REVISION) + return engine, config + + +def _seed_pre_workspace_entities(conn): + # This intentionally uses raw SQL matching the legacy schema (no workspace columns) + # so the migration under test is fully responsible for adding/backfilling the new + # fields. The helper insert functions below operate on the post-migration schema + # and therefore cannot be reused here. + with _identity_insert(conn, "experiments"): + conn.execute( + sa.text( + """ + INSERT INTO experiments ( + experiment_id, + name, + artifact_location, + lifecycle_stage, + creation_time, + last_update_time + ) + VALUES ( + :experiment_id, + :name, + :artifact_location, + :lifecycle_stage, + :creation_time, + :last_update_time + ) + """ + ), + { + "experiment_id": 1, + "name": "exp-default", + "artifact_location": "path", + "lifecycle_stage": "active", + "creation_time": 0, + "last_update_time": 0, + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO runs ( + run_uuid, + name, + source_type, + source_name, + entry_point_name, + user_id, + status, + start_time, + end_time, + source_version, + lifecycle_stage, + artifact_uri, + experiment_id + ) + VALUES ( + :run_uuid, + :name, + :source_type, + :source_name, + :entry_point_name, + :user_id, + :status, + :start_time, + :end_time, + :source_version, + :lifecycle_stage, + :artifact_uri, + :experiment_id + ) + """ + ), + { + "run_uuid": "run-default", + "name": "upgrade-validation-run", + "source_type": "LOCAL", + "source_name": "script.py", + "entry_point_name": "main", + "user_id": "user", + "status": "FINISHED", + "start_time": 0, + "end_time": 1, + "source_version": "abc123", + "lifecycle_stage": "active", + "artifact_uri": "path/artifacts", + "experiment_id": 1, + }, + ) + conn.execute( + sa.text( + """ + INSERT INTO registered_models (name, creation_time, last_updated_time, description) + VALUES (:name, :creation_time, :last_updated_time, :description) + """ + ), + {"name": "rm-default", "creation_time": 0, "last_updated_time": 0, "description": "desc"}, + ) + conn.execute( + sa.text( + """ + INSERT INTO model_versions ( + name, + version, + creation_time, + last_updated_time, + user_id, + current_stage, + description, + source, + run_id, + status, + status_message, + run_link, + storage_location + ) + VALUES ( + :name, + :version, + :creation_time, + :last_updated_time, + :user_id, + :current_stage, + :description, + :source, + :run_id, + :status, + :status_message, + :run_link, + :storage_location + ) + """ + ), + { + "name": "rm-default", + "version": 1, + "creation_time": 0, + "last_updated_time": 0, + "user_id": "user", + "current_stage": "None", + "description": "desc", + "source": "source", + "run_id": "run-id", + "status": "READY", + "status_message": "message", + "run_link": "link", + "storage_location": "location", + }, + ) + _insert_table_row( + conn, + _LEGACY_REGISTERED_MODEL_TAGS, + key="tag", + value="value", + name="rm-default", + ) + _insert_table_row( + conn, + _LEGACY_MODEL_VERSION_TAGS, + key="tag", + value="value", + name="rm-default", + version=1, + ) + _insert_table_row( + conn, + _LEGACY_REGISTERED_MODEL_ALIASES, + alias="alias", + version=1, + name="rm-default", + ) + _insert_table_row( + conn, + _LEGACY_EVALUATION_DATASETS, + dataset_id="ds-default", + name="Dataset", + schema="schema", + profile="profile", + digest="digest", + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + ) + + +def _get_workspace_default(column_info): + default = column_info.get("default") or column_info.get("server_default") + if default is None: + return None + value = str(default).strip() + if value.startswith("(") and value.endswith(")"): + value = value[1:-1] + value = value.strip() + value = value.strip("'\"") + if "::" in value: + value = value.split("::", 1)[0] + return value.strip("'\"") + + +def _add_workspace(conn, name: str, description: str): + conn.execute( + sa.text("INSERT INTO workspaces (name, description) VALUES (:name, :description)"), + {"name": name, "description": description}, + ) + + +def _insert_experiment( + conn, + *, + experiment_id: int, + name: str, + workspace: str, + artifact_location: str = "path", + lifecycle_stage: str = "active", +): + with _identity_insert(conn, "experiments"): + conn.execute( + sa.text( + """ + INSERT INTO experiments ( + experiment_id, + name, + artifact_location, + lifecycle_stage, + creation_time, + last_update_time, + workspace + ) + VALUES ( + :experiment_id, + :name, + :artifact_location, + :lifecycle_stage, + :creation_time, + :last_update_time, + :workspace + ) + """ + ), + { + "experiment_id": experiment_id, + "name": name, + "artifact_location": artifact_location, + "lifecycle_stage": lifecycle_stage, + "creation_time": 0, + "last_update_time": 0, + "workspace": workspace, + }, + ) + + +def _insert_run( + conn, + *, + run_uuid: str, + experiment_id: int, + name: str = "run", + artifact_uri: str = "path/artifacts", +): + conn.execute( + sa.text( + """ + INSERT INTO runs ( + run_uuid, + name, + source_type, + source_name, + entry_point_name, + user_id, + status, + start_time, + end_time, + source_version, + lifecycle_stage, + artifact_uri, + experiment_id + ) + VALUES ( + :run_uuid, + :name, + :source_type, + :source_name, + :entry_point_name, + :user_id, + :status, + :start_time, + :end_time, + :source_version, + :lifecycle_stage, + :artifact_uri, + :experiment_id + ) + """ + ), + { + "run_uuid": run_uuid, + "name": name, + "source_type": "LOCAL", + "source_name": "script.py", + "entry_point_name": "main", + "user_id": "user", + "status": "FINISHED", + "start_time": 0, + "end_time": 1, + "source_version": "abc123", + "lifecycle_stage": "active", + "artifact_uri": artifact_uri, + "experiment_id": experiment_id, + }, + ) + + +def _insert_registered_model( + conn, + *, + name: str, + workspace: str, + description: str = "desc", + creation_time: int = 0, +): + conn.execute( + sa.text( + """ + INSERT INTO registered_models ( + name, + creation_time, + last_updated_time, + description, + workspace + ) + VALUES ( + :name, + :creation_time, + :last_updated_time, + :description, + :workspace + ) + """ + ), + { + "name": name, + "creation_time": creation_time, + "last_updated_time": creation_time, + "description": description, + "workspace": workspace, + }, + ) + + +def _insert_model_version( + conn, + *, + name: str, + version: int, + workspace: str, + run_id: str = "run-id", + storage_location: str = "location", +): + conn.execute( + sa.text( + """ + INSERT INTO model_versions ( + name, + version, + creation_time, + last_updated_time, + user_id, + current_stage, + description, + source, + run_id, + status, + status_message, + run_link, + storage_location, + workspace + ) + VALUES ( + :name, + :version, + :creation_time, + :last_updated_time, + :user_id, + :current_stage, + :description, + :source, + :run_id, + :status, + :status_message, + :run_link, + :storage_location, + :workspace + ) + """ + ), + { + "name": name, + "version": version, + "creation_time": 0, + "last_updated_time": 0, + "user_id": "user", + "current_stage": "None", + "description": "desc", + "source": "source", + "run_id": run_id, + "status": "READY", + "status_message": "message", + "run_link": "link", + "storage_location": storage_location, + "workspace": workspace, + }, + ) + + +_REGISTERED_MODEL_TAGS = sa.table( + "registered_model_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), +) + + +def _insert_registered_model_tag( + conn, + *, + workspace: str, + name: str, + key: str, + value: str = "value", +): + _insert_table_row( + conn, + _REGISTERED_MODEL_TAGS, + workspace=workspace, + key=key, + value=value, + name=name, + ) + + +_MODEL_VERSION_TAGS = sa.table( + "model_version_tags", + sa.column("workspace"), + sa.column("key"), + sa.column("value"), + sa.column("name"), + sa.column("version"), +) + + +def _insert_model_version_tag( + conn, + *, + workspace: str, + name: str, + version: int, + key: str, + value: str = "value", +): + _insert_table_row( + conn, + _MODEL_VERSION_TAGS, + workspace=workspace, + key=key, + value=value, + name=name, + version=version, + ) + + +_REGISTERED_MODEL_ALIASES = sa.table( + "registered_model_aliases", + sa.column("workspace"), + sa.column("name"), + sa.column("alias"), + sa.column("version"), +) + + +def _insert_registered_model_alias( + conn, + *, + workspace: str, + name: str, + alias: str, + version: int = 1, +): + _insert_table_row( + conn, + _REGISTERED_MODEL_ALIASES, + workspace=workspace, + name=name, + alias=alias, + version=version, + ) + + +_EVALUATION_DATASETS = sa.table( + "evaluation_datasets", + sa.column("dataset_id"), + sa.column("name"), + sa.column("schema"), + sa.column("profile"), + sa.column("digest"), + sa.column("created_time"), + sa.column("last_update_time"), + sa.column("created_by"), + sa.column("last_updated_by"), + sa.column("workspace"), +) + + +def _insert_evaluation_dataset( + conn, + *, + dataset_id: str, + workspace: str, + name: str = "Dataset", + digest: str = "digest", +): + conn.execute( + sa.insert(_EVALUATION_DATASETS).values( + dataset_id=dataset_id, + name=name, + schema="schema", + profile="profile", + digest=digest, + created_time=0, + last_update_time=0, + created_by="user", + last_updated_by="user", + workspace=workspace, + ) + ) + + +def _fetch_conflicts(conn, table_name: str, columns: tuple[str, ...]): + metadata = sa.MetaData() + table = sa.Table(table_name, metadata, autoload_with=conn) + group_columns = [table.c[column] for column in columns] + stmt = sa.select(*group_columns).group_by(*group_columns).having(sa.func.count() > 1) + return conn.execute(stmt).fetchall() + + +def test_workspace_migration_upgrade_adds_columns_and_backfills(tmp_path): + engine, config = _prepare_database(tmp_path) + try: + with engine.begin() as conn: + _seed_pre_workspace_entities(conn) + + command.upgrade(config, REVISION) + inspector = sa.inspect(engine) + + _assert_workspace_columns(inspector, "default") + + with engine.connect() as conn: + assert conn.execute( + sa.text( + "SELECT experiment_id, name, workspace FROM experiments ORDER BY experiment_id" + ) + ).fetchall() == [(1, "exp-default", "default")] + + assert conn.execute( + sa.text("SELECT run_uuid, experiment_id FROM runs ORDER BY run_uuid") + ).fetchall() == [("run-default", 1)] + + assert conn.execute( + sa.text("SELECT name, workspace FROM registered_models") + ).fetchall() == [("rm-default", "default")] + + assert conn.execute( + sa.text("SELECT name, version, workspace FROM model_versions") + ).fetchall() == [("rm-default", 1, "default")] + + assert conn.execute( + sa.select( + _REGISTERED_MODEL_TAGS.c.workspace, + _REGISTERED_MODEL_TAGS.c.name, + _REGISTERED_MODEL_TAGS.c.key, + ) + ).fetchall() == [("default", "rm-default", "tag")] + + assert conn.execute( + sa.select( + _MODEL_VERSION_TAGS.c.workspace, + _MODEL_VERSION_TAGS.c.name, + _MODEL_VERSION_TAGS.c.version, + _MODEL_VERSION_TAGS.c.key, + ) + ).fetchall() == [("default", "rm-default", 1, "tag")] + + assert conn.execute( + sa.text("SELECT workspace, name, alias FROM registered_model_aliases") + ).fetchall() == [("default", "rm-default", "alias")] + + assert conn.execute( + sa.text("SELECT dataset_id, workspace FROM evaluation_datasets") + ).fetchall() == [("ds-default", "default")] + + assert conn.execute( + sa.text("SELECT name, description FROM workspaces ORDER BY name") + ).fetchall() == [("default", "Default workspace for legacy resources")] + + pk_registered_models = inspector.get_pk_constraint("registered_models") + assert pk_registered_models["constrained_columns"] == ["workspace", "name"] + + pk_model_versions = inspector.get_pk_constraint("model_versions") + assert pk_model_versions["constrained_columns"] == [ + "workspace", + "name", + "version", + ] + + pk_registered_model_tags = inspector.get_pk_constraint("registered_model_tags") + assert pk_registered_model_tags["constrained_columns"] == [ + "workspace", + "key", + "name", + ] + + pk_model_version_tags = inspector.get_pk_constraint("model_version_tags") + assert pk_model_version_tags["constrained_columns"] == [ + "workspace", + "key", + "name", + "version", + ] + + pk_model_aliases = inspector.get_pk_constraint("registered_model_aliases") + assert pk_model_aliases["constrained_columns"] == [ + "workspace", + "name", + "alias", + ] + + try: + unique_experiments = inspector.get_unique_constraints("experiments") + except NotImplementedError: + if inspector.bind.dialect.name == "mssql": + unique_experiments = None + else: + raise + if unique_experiments is not None: + assert any( + {"workspace", "name"} == set(constraint.get("column_names", [])) + for constraint in unique_experiments + ) + + fk_model_versions = inspector.get_foreign_keys("model_versions") + assert any( + fk.get("constrained_columns") == ["workspace", "name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_model_versions + ) + + assert _has_index(inspector, "experiments", "idx_experiments_workspace", ["workspace"]) + assert _has_index( + inspector, + "experiments", + "idx_experiments_workspace_creation_time", + ["workspace", "creation_time"], + ) + assert _has_index( + inspector, "registered_models", "idx_registered_models_workspace", ["workspace"] + ) + assert _has_index( + inspector, "evaluation_datasets", "idx_evaluation_datasets_workspace", ["workspace"] + ) + finally: + engine.dispose() + + +def test_workspace_migration_downgrade_reverts_schema(tmp_path): + engine, config = _prepare_database(tmp_path) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + _insert_experiment(conn, experiment_id=1, name="exp-default", workspace="default") + _insert_run( + conn, + run_uuid="run-default", + experiment_id=1, + name="downgrade-validation-run", + ) + _insert_experiment(conn, experiment_id=2, name="exp-team-a", workspace="team-a") + + command.downgrade(config, PREVIOUS_REVISION) + inspector = sa.inspect(engine) + + tables = inspector.get_table_names() + assert "workspaces" not in tables + + for table in ( + "experiments", + "registered_models", + "model_versions", + "registered_model_tags", + "model_version_tags", + "registered_model_aliases", + "evaluation_datasets", + ): + column_names = {col["name"] for col in inspector.get_columns(table)} + assert "workspace" not in column_names + + with engine.connect() as conn: + assert conn.execute( + sa.text("SELECT experiment_id, name FROM experiments ORDER BY experiment_id") + ).fetchall() == [(1, "exp-default"), (2, "exp-team-a")] + assert conn.execute( + sa.text("SELECT run_uuid, experiment_id FROM runs ORDER BY run_uuid") + ).fetchall() == [("run-default", 1)] + + pk_registered_models = inspector.get_pk_constraint("registered_models") + assert pk_registered_models["constrained_columns"] == ["name"] + + pk_model_versions = inspector.get_pk_constraint("model_versions") + assert pk_model_versions["constrained_columns"] == ["name", "version"] + + pk_registered_model_tags = inspector.get_pk_constraint("registered_model_tags") + assert pk_registered_model_tags["constrained_columns"] == ["key", "name"] + + pk_model_version_tags = inspector.get_pk_constraint("model_version_tags") + assert pk_model_version_tags["constrained_columns"] == ["key", "name", "version"] + + pk_registered_model_aliases = inspector.get_pk_constraint("registered_model_aliases") + assert pk_registered_model_aliases["constrained_columns"] == ["name", "alias"] + + try: + unique_experiments = inspector.get_unique_constraints("experiments") + except NotImplementedError: + if inspector.bind.dialect.name == "mssql": + unique_experiments = None + else: + raise + if unique_experiments is not None: + assert any( + set(constraint.get("column_names", [])) == {"name"} + for constraint in unique_experiments + ) + + fk_model_versions = inspector.get_foreign_keys("model_versions") + assert any( + fk.get("constrained_columns") == ["name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_model_versions + ) + + fk_registered_model_tags = inspector.get_foreign_keys("registered_model_tags") + assert any( + fk.get("constrained_columns") == ["name"] + and fk.get("referred_table") == "registered_models" + for fk in fk_registered_model_tags + ) + + fk_model_version_tags = inspector.get_foreign_keys("model_version_tags") + assert any( + fk.get("constrained_columns") == ["name", "version"] + and fk.get("referred_table") == "model_versions" + for fk in fk_model_version_tags + ) + finally: + engine.dispose() + + +def _setup_experiment_conflict(conn): + _insert_experiment(conn, experiment_id=1, name="duplicate-exp", workspace="default") + _insert_run(conn, run_uuid="run-exp-default", experiment_id=1) + _insert_experiment(conn, experiment_id=2, name="duplicate-exp", workspace="team-a") + + +def _setup_registered_model_conflict(conn): + _insert_registered_model(conn, name="duplicate-model", workspace="default") + _insert_registered_model(conn, name="duplicate-model", workspace="team-a") + + +def _setup_model_version_conflict(conn): + _insert_registered_model(conn, name="mv-model", workspace="default") + _insert_registered_model(conn, name="mv-model", workspace="team-a") + _insert_model_version(conn, name="mv-model", version=1, workspace="default") + _insert_model_version(conn, name="mv-model", version=1, workspace="team-a") + + +def _setup_registered_model_tag_conflict(conn): + _insert_registered_model(conn, name="tag-model", workspace="default") + _insert_registered_model(conn, name="tag-model", workspace="team-a") + _insert_registered_model_tag(conn, workspace="default", name="tag-model", key="tag-key") + _insert_registered_model_tag(conn, workspace="team-a", name="tag-model", key="tag-key") + + +def _setup_model_version_tag_conflict(conn): + _insert_registered_model(conn, name="mvt-model", workspace="default") + _insert_registered_model(conn, name="mvt-model", workspace="team-a") + _insert_model_version(conn, name="mvt-model", version=1, workspace="default") + _insert_model_version(conn, name="mvt-model", version=1, workspace="team-a") + _insert_model_version_tag( + conn, workspace="default", name="mvt-model", version=1, key="mv-tag-key" + ) + _insert_model_version_tag( + conn, workspace="team-a", name="mvt-model", version=1, key="mv-tag-key" + ) + + +def _setup_registered_model_alias_conflict(conn): + _insert_registered_model(conn, name="alias-model", workspace="default") + _insert_registered_model(conn, name="alias-model", workspace="team-a") + _insert_registered_model_alias(conn, workspace="default", name="alias-model", alias="latest") + _insert_registered_model_alias(conn, workspace="team-a", name="alias-model", alias="latest") + + +def _setup_evaluation_dataset_conflict(conn): + _insert_evaluation_dataset( + conn, dataset_id="ds-default", name="duplicate-ds", workspace="default" + ) + _insert_evaluation_dataset( + conn, dataset_id="ds-team-a", name="duplicate-ds", workspace="team-a" + ) + + +@pytest.mark.parametrize( + ("setup_conflict", "expected_fragment", "case_slug"), + [ + (_setup_experiment_conflict, "duplicate experiments with the same name", "experiments"), + ( + _setup_registered_model_conflict, + "duplicate registered models with the same name", + "models", + ), + ( + _setup_evaluation_dataset_conflict, + "duplicate evaluation datasets with the same name", + "evaluation_datasets", + ), + ], +) +def test_workspace_migration_downgrade_detects_conflicts( + tmp_path, setup_conflict, expected_fragment, case_slug +): + case_dir = tmp_path / f"conflict_{case_slug}" + case_dir.mkdir() + engine, config = _prepare_database(case_dir) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + setup_conflict(conn) + + with pytest.raises( + RuntimeError, + match=re.escape(expected_fragment), + ): + command.downgrade(config, PREVIOUS_REVISION) + finally: + engine.dispose() + + +@pytest.mark.parametrize( + ("setup_conflict", "table_name", "columns", "case_slug"), + [ + ( + _setup_model_version_conflict, + "model_versions", + ("name", "version"), + "model_versions", + ), + ( + _setup_registered_model_tag_conflict, + "registered_model_tags", + ("name", "key"), + "registered_model_tags", + ), + ( + _setup_model_version_tag_conflict, + "model_version_tags", + ("name", "version", "key"), + "model_version_tags", + ), + ( + _setup_registered_model_alias_conflict, + "registered_model_aliases", + ("name", "alias"), + "registered_model_aliases", + ), + ], +) +def test_workspace_migration_conflict_detection_queries( + tmp_path, setup_conflict, table_name, columns, case_slug +): + case_dir = tmp_path / f"conflict_query_{case_slug}" + case_dir.mkdir() + engine, config = _prepare_database(case_dir) + try: + command.upgrade(config, REVISION) + with engine.begin() as conn: + _add_workspace(conn, "team-a", "Team A") + setup_conflict(conn) + conflicts = _fetch_conflicts(conn, table_name, columns) + assert conflicts, f"Expected conflicts for {table_name}, found none" + finally: + engine.dispose() diff --git a/tests/resources/db/latest_schema.sql b/tests/resources/db/latest_schema.sql index 6903019a42c37..fecf909f2db82 100644 --- a/tests/resources/db/latest_schema.sql +++ b/tests/resources/db/latest_schema.sql @@ -26,6 +26,7 @@ CREATE TABLE evaluation_datasets ( last_update_time BIGINT, created_by VARCHAR(255), last_updated_by VARCHAR(255), + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT evaluation_datasets_pk PRIMARY KEY (dataset_id) ) @@ -37,8 +38,9 @@ CREATE TABLE experiments ( lifecycle_stage VARCHAR(32), creation_time BIGINT, last_update_time BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT experiment_pk PRIMARY KEY (experiment_id), - UNIQUE (name), + CONSTRAINT uq_experiments_workspace_name UNIQUE (workspace, name), CONSTRAINT experiments_lifecycle_stage CHECK (lifecycle_stage IN ('active', 'deleted')) ) @@ -81,8 +83,8 @@ CREATE TABLE registered_models ( creation_time BIGINT, last_updated_time BIGINT, description VARCHAR(5000), - CONSTRAINT registered_model_pk PRIMARY KEY (name), - UNIQUE (name) + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_pk PRIMARY KEY (workspace, name) ) @@ -96,10 +98,18 @@ CREATE TABLE webhooks ( creation_timestamp BIGINT, last_updated_timestamp BIGINT, deleted_timestamp BIGINT, + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, CONSTRAINT webhook_pk PRIMARY KEY (webhook_id) ) +CREATE TABLE workspaces ( + name VARCHAR(63) NOT NULL, + description TEXT, + CONSTRAINT workspaces_pk PRIMARY KEY (name) +) + + CREATE TABLE datasets ( dataset_uuid VARCHAR(36) NOT NULL, experiment_id INTEGER NOT NULL, @@ -185,8 +195,9 @@ CREATE TABLE model_versions ( status_message VARCHAR(500), run_link VARCHAR(500), storage_location VARCHAR(500), - CONSTRAINT model_version_pk PRIMARY KEY (name, version), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_pk PRIMARY KEY (workspace, name, version), + CONSTRAINT fk_model_versions_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -194,8 +205,9 @@ CREATE TABLE registered_model_aliases ( alias VARCHAR(256) NOT NULL, version INTEGER NOT NULL, name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_alias_pk PRIMARY KEY (name, alias), - CONSTRAINT registered_model_alias_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON DELETE CASCADE ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_alias_pk PRIMARY KEY (workspace, name, alias), + CONSTRAINT fk_registered_model_aliases_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -203,8 +215,9 @@ CREATE TABLE registered_model_tags ( key VARCHAR(250) NOT NULL, value VARCHAR(5000), name VARCHAR(256) NOT NULL, - CONSTRAINT registered_model_tag_pk PRIMARY KEY (key, name), - FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT registered_model_tag_pk PRIMARY KEY (workspace, key, name), + CONSTRAINT fk_registered_model_tags_registered_models FOREIGN KEY(workspace, name) REFERENCES registered_models (workspace, name) ON UPDATE CASCADE ) @@ -356,8 +369,9 @@ CREATE TABLE model_version_tags ( value TEXT, name VARCHAR(256) NOT NULL, version INTEGER NOT NULL, - CONSTRAINT model_version_tag_pk PRIMARY KEY (key, name, version), - FOREIGN KEY(name, version) REFERENCES model_versions (name, version) ON UPDATE CASCADE + workspace VARCHAR(63) DEFAULT 'default' NOT NULL, + CONSTRAINT model_version_tag_pk PRIMARY KEY (workspace, key, name, version), + CONSTRAINT fk_model_version_tags_model_versions FOREIGN KEY(workspace, name, version) REFERENCES model_versions (workspace, name, version) ON UPDATE CASCADE ) diff --git a/tests/store/tracking/test_sqlalchemy_store_schema.py b/tests/store/tracking/test_sqlalchemy_store_schema.py index 37f314a6d0610..cfe9039c40f6c 100644 --- a/tests/store/tracking/test_sqlalchemy_store_schema.py +++ b/tests/store/tracking/test_sqlalchemy_store_schema.py @@ -9,6 +9,10 @@ from alembic.script import ScriptDirectory import mlflow.db + +# Import workspace models temporarily for tests to pass. +# This can be removed once we have a workspace store imported. +import mlflow.store.workspace.dbmodels as _workspace_models # noqa: F401 from mlflow.exceptions import MlflowException from mlflow.store.db.base_sql_model import Base from mlflow.store.db.utils import _get_alembic_config, _verify_schema