Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/django_upgrade/fixers/mail_api_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
),
}

MESSAGE_MODULE_NAMES = {"EmailMessage", "EmailMultiAlternatives"}
MESSAGE_MODULE_NAMES = frozenset({"EmailMessage", "EmailMultiAlternatives"})


@fixer.register(ast.Call)
Expand Down
12 changes: 7 additions & 5 deletions src/django_upgrade/fixers/model_field_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ def defined_enumeration_types(module: ast.Module, up_to_line: int) -> set[str]:
}


DJANGO_CHOICES_TYPES = {
"TextChoices",
"IntegerChoices",
"Choices",
}
DJANGO_CHOICES_TYPES = frozenset(
{
"TextChoices",
"IntegerChoices",
"Choices",
}
)


def _is_django_choices_type(
Expand Down
8 changes: 5 additions & 3 deletions src/django_upgrade/fixers/on_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
min_version=(1, 9),
)

RELATION_FIELD_NAMES = frozenset({"ForeignKey", "OneToOneField"})


@fixer.register(ast.ImportFrom)
def visit_ImportFrom(
Expand All @@ -31,7 +33,7 @@ def visit_ImportFrom(
if (
node.module == "django.db.models"
and is_rewritable_import_from(node)
and any(alias.name in {"ForeignKey", "OneToOneField"} for alias in node.names)
and any(alias.name in RELATION_FIELD_NAMES for alias in node.names)
):
yield (
ast_start_offset(node),
Expand Down Expand Up @@ -70,14 +72,14 @@ def visit_Call(
(
(
isinstance(node.func, ast.Attribute)
and node.func.attr in {"ForeignKey", "OneToOneField"}
and node.func.attr in RELATION_FIELD_NAMES
and (models_imported := "models" in state.from_imports["django.db"])
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "models"
)
or (
isinstance(node.func, ast.Name)
and node.func.id in {"ForeignKey", "OneToOneField"}
and node.func.id in RELATION_FIELD_NAMES
and node.func.id in state.from_imports["django.db.models"]
and (models_imported := False) is False # force walrus
)
Expand Down
14 changes: 8 additions & 6 deletions src/django_upgrade/fixers/postgres_float_range_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
min_version=(2, 2),
)

MODULES = {
"django.contrib.postgres.fields",
"django.contrib.postgres.fields.ranges",
"django.contrib.postgres.forms",
"django.contrib.postgres.forms.ranges",
}
MODULES = frozenset(
{
"django.contrib.postgres.fields",
"django.contrib.postgres.fields.ranges",
"django.contrib.postgres.forms",
"django.contrib.postgres.forms.ranges",
}
)
NAME_MAP = {
"FloatRangeField": "DecimalRangeField",
}
Expand Down
4 changes: 3 additions & 1 deletion src/django_upgrade/fixers/request_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
min_version=(2, 2),
)

SPECIAL_HEADERS = frozenset({"CONTENT_LENGTH", "CONTENT_TYPE"})


@fixer.register(ast.Subscript)
def visit_Subscript(
Expand Down Expand Up @@ -109,7 +111,7 @@ def get_header_name(meta_name: str) -> str | None:
http_prefix = "HTTP_"
if meta_name.startswith(http_prefix):
name = meta_name[len(http_prefix) :]
elif meta_name in {"CONTENT_LENGTH", "CONTENT_TYPE"}:
elif meta_name in SPECIAL_HEADERS:
name = meta_name
else:
return None
Expand Down