Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions alembic/ddl/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,31 @@ def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
):
cnfk.onupdate = "RESTRICT"

def compare_type(
self,
inspector_column: schema.Column[Any],
metadata_column: schema.Column,
) -> bool:
"""Override compare_type to properly detect MySQL native ENUM changes.

This addresses the issue where autogenerate fails to detect when new
values are added to or removed from MySQL native ENUM columns.
"""
metadata_type = metadata_column.type
inspector_type = inspector_column.type

# Check if both columns are MySQL native ENUMs
if isinstance(metadata_type, sqltypes.Enum) and isinstance(
inspector_type, sqltypes.Enum
):
# Compare the actual enum values; order matters for MySQL ENUMs.
# Changing the order of ENUM values is a schema change in MySQL.
if metadata_type.enums != inspector_type.enums:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Federico Caselli (CaselIT) wrote:

could we use set here? I don't think the order matters

View this in Gerrit at https://gerrit.sqlalchemy.org/c/sqlalchemy/alembic/+/6520

return True

# Fall back to default comparison for non-ENUM types
return super().compare_type(inspector_column, metadata_column)


class MariaDBImpl(MySQLImpl):
__dialect__ = "mariadb"
Expand Down
55 changes: 55 additions & 0 deletions tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy import Column
from sqlalchemy import Computed
from sqlalchemy import DATETIME
from sqlalchemy import Enum
from sqlalchemy import exc
from sqlalchemy import Float
from sqlalchemy import func
Expand All @@ -14,10 +15,12 @@
from sqlalchemy import Table
from sqlalchemy import text
from sqlalchemy import TIMESTAMP
from sqlalchemy.dialects.mysql import ENUM as MySQL_ENUM
from sqlalchemy.dialects.mysql import VARCHAR

from alembic import autogenerate
from alembic import op
from alembic import testing
from alembic import util
from alembic.autogenerate import api
from alembic.autogenerate.compare.constraints import _compare_nullable
Expand All @@ -27,6 +30,7 @@
from alembic.testing import combinations
from alembic.testing import config
from alembic.testing import eq_ignore_whitespace
from alembic.testing import is_
from alembic.testing.env import clear_staging_env
from alembic.testing.env import staging_env
from alembic.testing.fixtures import AlterColRoundTripFixture
Expand All @@ -40,6 +44,7 @@
from alembic.autogenerate.compare.types import (
_dialect_impl_compare_type as _compare_type,
)
from alembic.ddl.mysql import MySQLImpl


class MySQLOpTest(TestBase):
Expand Down Expand Up @@ -784,3 +789,53 @@ def test_render_add_index_expr_func(self):
"op.create_index('foo_idx', 't', "
"['x', sa.literal_column('(coalesce(y, 0))')], unique=False)",
)


class MySQLEnumCompareTest(TestBase):
"""Test MySQL native ENUM comparison in autogenerate."""

__only_on__ = "mysql", "mariadb"
__backend__ = True

@testing.fixture()
def connection(self):
with config.db.begin() as conn:
yield conn

@testing.combinations(
(
Enum("A", "B", "C", native_enum=True),
Enum("A", "B", "C", native_enum=True),
False,
),
(
Enum("A", "B", "C", native_enum=True),
Enum("A", "B", "C", "D", native_enum=True),
True,
),
(
Enum("A", "B", "C", "D", native_enum=True),
Enum("A", "B", "C", native_enum=True),
True,
),
(
Enum("A", "B", "C", native_enum=True),
Enum("C", "B", "A", native_enum=True),
True,
),
(MySQL_ENUM("A", "B", "C"), MySQL_ENUM("A", "B", "C"), False),
(MySQL_ENUM("A", "B", "C"), MySQL_ENUM("A", "B", "C", "D"), True),
id_="ssa",
argnames="inspected_type,metadata_type,expected",
)
def test_compare_enum_types(
self, inspected_type, metadata_type, expected, connection
):
impl = MySQLImpl(connection.dialect, connection, False, None, None, {})

is_(
impl.compare_type(
Column("x", inspected_type), Column("x", metadata_type)
),
expected,
)