86 changes: 28 additions & 58 deletions alembic/ddl/impl.py
@@ -175,16 +175,12 @@ def version_table_impl(
)
if version_table_pk:
vt.append_constraint(
PrimaryKeyConstraint(
"version_num", name=f"{version_table}_pkc"
)
PrimaryKeyConstraint("version_num", name=f"{version_table}_pkc")
)

return vt

def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
def requires_recreate_in_batch(self, batch_op: BatchOperationsImpl) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -195,9 +191,7 @@ def requires_recreate_in_batch(
"""
return False

def prep_table_for_batch(
self, batch_impl: ApplyBatchImpl, table: Table
) -> None:
def prep_table_for_batch(self, batch_impl: ApplyBatchImpl, table: Table) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.

@@ -224,9 +218,7 @@ def _exec(
raise TypeError("SQL parameters not allowed with as_sql")

compile_kw: dict[str, Any]
if self.literal_binds and not isinstance(
construct, schema.DDLElement
):
if self.literal_binds and not isinstance(construct, schema.DDLElement):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
@@ -235,8 +227,7 @@ def _exec(
assert isinstance(construct, ClauseElement)
compiled = construct.compile(dialect=self.dialect, **compile_kw)
self.static_output(
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
str(compiled).replace("\t", " ").strip() + self.command_terminator
)
return None
else:
@@ -246,9 +237,7 @@ def _exec(
conn = conn.execution_options(**execution_options)

if params and multiparams is not None:
raise TypeError(
"Can't send params and multiparams at the same time"
)
raise TypeError("Can't send params and multiparams at the same time")

if multiparams:
return conn.execute(construct, multiparams)
@@ -268,9 +257,7 @@ def alter_column(
column_name: str,
*,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefaultType, Literal[False]]
] = False,
server_default: Optional[Union[_ServerDefaultType, Literal[False]]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
@@ -287,8 +274,7 @@ def alter_column(
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
"only make sense for MySQL",
"autoincrement and existing_autoincrement only make sense for MySQL",
stacklevel=3,
)
if nullable is not None:
@@ -410,9 +396,7 @@ def drop_column(
**kw,
) -> None:
self._exec(
base.DropColumn(
table_name, column, schema=schema, if_exists=if_exists
)
base.DropColumn(table_name, column, schema=schema, if_exists=if_exists)
)

def add_constraint(self, const: Any, **kw: Any) -> None:
@@ -431,9 +415,7 @@ def rename_table(
new_table_name: Union[str, quoted_name],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
self._exec(base.RenameTable(old_table_name, new_table_name, schema=schema))

def create_table(self, table: Table, **kw: Any) -> None:
table.dispatch.before_create(
@@ -503,9 +485,7 @@ def bulk_insert(
sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
if not isinstance(v, sqla_compat._literal_bindparam)
else v
)
for k, v in row.items()
@@ -570,9 +550,7 @@ def _column_types_match(
inspector_all_terms = " ".join(
[inspector_params.token0] + inspector_params.tokens
)
metadata_all_terms = " ".join(
[metadata_params.token0] + metadata_params.tokens
)
metadata_all_terms = " ".join([metadata_params.token0] + metadata_params.tokens)

for batch in synonyms:
if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
@@ -582,9 +560,7 @@ def _column_types_match(
return True
return False

def _column_args_match(
self, inspected_params: Params, meta_params: Params
) -> bool:
def _column_args_match(self, inspected_params: Params, meta_params: Params) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
@@ -651,9 +627,7 @@ def correct_for_autogen_constraints(

def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
if existing.type._type_affinity is not new_type._type_affinity:
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
existing_transfer["expr"] = cast(existing_transfer["expr"], new_type)

def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
@@ -665,9 +639,7 @@ def render_ddl_sql_expr(

compile_kw = {"literal_binds": True, "include_table": False}

return str(
expr.compile(dialect=self.dialect, compile_kwargs=compile_kw)
)
return str(expr.compile(dialect=self.dialect, compile_kwargs=compile_kw))

def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
return self.autogen_column_reflect
@@ -688,6 +660,14 @@ def autogen_column_reflect(self, inspector, table, column_info):

"""

def autogen_table_reflect(self, inspector, table):
"""A hook that is called when a Table is reflected from the
database during the autogenerate process.

Dialects can elect to modify the information gathered here.

"""

def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
@@ -780,9 +760,7 @@ def compare_indexes(
This method returns a ``ComparisonResult``.
"""
msg: List[str] = []
unique_msg = self._compare_index_unique(
metadata_index, reflected_index
)
unique_msg = self._compare_index_unique(metadata_index, reflected_index)
if unique_msg:
msg.append(unique_msg)
m_sig = self._create_metadata_constraint_sig(metadata_index)
@@ -803,9 +781,7 @@ def compare_indexes(
)

if m_sig.column_names != r_sig.column_names:
msg.append(
f"expression {r_sig.column_names} to {m_sig.column_names}"
)
msg.append(f"expression {r_sig.column_names} to {m_sig.column_names}")

if msg:
return ComparisonResult.Different(msg)
@@ -824,19 +800,13 @@ def compare_unique_constraint(

This method returns a ``ComparisonResult``.
"""
metadata_tup = self._create_metadata_constraint_sig(
metadata_constraint
)
reflected_tup = self._create_reflected_constraint_sig(
reflected_constraint
)
metadata_tup = self._create_metadata_constraint_sig(metadata_constraint)
reflected_tup = self._create_reflected_constraint_sig(reflected_constraint)

meta_sig = metadata_tup.unnamed
conn_sig = reflected_tup.unnamed
if conn_sig != meta_sig:
return ComparisonResult.Different(
f"expression {conn_sig} to {meta_sig}"
)
return ComparisonResult.Different(f"expression {conn_sig} to {meta_sig}")
else:
return ComparisonResult.Equal()

34 changes: 19 additions & 15 deletions alembic/ddl/sqlite.py
@@ -46,9 +46,7 @@ class SQLiteImpl(DefaultImpl):
see: http://bugs.python.org/issue10740
"""

def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
def requires_recreate_in_batch(self, batch_op: BatchOperationsImpl) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
@@ -60,9 +58,9 @@ def requires_recreate_in_batch(
for op in batch_op.batch:
if op[0] == "add_column":
col = op[1][1]
if isinstance(
col.server_default, schema.DefaultClause
) and isinstance(col.server_default.arg, sql.ClauseElement):
if isinstance(col.server_default, schema.DefaultClause) and isinstance(
col.server_default.arg, sql.ClauseElement
):
return True
elif (
isinstance(col.server_default, Computed)
@@ -160,6 +158,17 @@ def autogen_column_reflect(
):
column_info["default"] = "(%s)" % (column_info["default"],)

def autogen_table_reflect(self, inspector, table):
sql_text = sql.text(
"SELECT sql FROM sqlite_master WHERE name=:name AND type='table'"
)
res = inspector.bind.execute(sql_text, {"name": table.name}).scalar()
if res:
if re.search(r"\bSTRICT\b\s*;?\s*$", res, re.I):
table.kwargs["sqlite_strict"] = True
if re.search(r"\bWITHOUT ROWID\b", res, re.I):
table.kwargs["sqlite_with_rowid"] = False

def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
@@ -169,9 +178,8 @@ def render_ddl_sql_expr(
expr, is_server_default=is_server_default, **kw
)

if (
is_server_default
and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
if is_server_default and self._guess_if_default_is_unparenthesized_sql_expr(
str_expr
):
str_expr = "(%s)" % (str_expr,)
return str_expr
@@ -186,9 +194,7 @@ def cast_for_batch_migrate(
existing.type._type_affinity is not new_type._type_affinity
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
existing_transfer["expr"] = cast(existing_transfer["expr"], new_type)

def correct_for_autogen_constraints(
self,
@@ -201,9 +207,7 @@ def correct_for_autogen_constraints(


@compiles(RenameTable, "sqlite")
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
def visit_rename_table(element: RenameTable, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
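
Putting the SQLite changes together, here is a rough, self-contained sketch (not from the PR) that calls the new hook directly against an in-memory database. It assumes the `DefaultImpl` constructor still takes `(dialect, connection, as_sql, transactional_ddl, output_buffer, context_opts)`, that the SQLite library in use supports STRICT tables (3.37+), and that the inspector is connection-bound, as it is during a real autogenerate run; the table and column names are invented.

```python
# Rough sketch only: calling the new SQLiteImpl.autogen_table_reflect hook
# outside of a full autogenerate run.  Names and the positional constructor
# arguments are assumptions as noted in the text above.
import sqlalchemy as sa

from alembic.ddl.sqlite import SQLiteImpl

engine = sa.create_engine("sqlite://")
with engine.begin() as conn:
    # Requires an SQLite build that supports STRICT tables (3.37+).
    conn.exec_driver_sql(
        "CREATE TABLE account (id INTEGER PRIMARY KEY, name TEXT NOT NULL) "
        "WITHOUT ROWID, STRICT"
    )

with engine.connect() as conn:
    # Connection-bound inspector, so inspector.bind.execute() works as in the hook.
    inspector = sa.inspect(conn)
    reflected = sa.Table("account", sa.MetaData(), autoload_with=conn)

    impl = SQLiteImpl(
        engine.dialect,  # dialect
        conn,            # connection
        False,           # as_sql
        None,            # transactional_ddl
        None,            # output_buffer
        {},              # context_opts
    )
    impl.autogen_table_reflect(inspector, reflected)

    print(reflected.kwargs.get("sqlite_strict"))      # expected: True
    print(reflected.kwargs.get("sqlite_with_rowid"))  # expected: False
```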