Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 188 additions & 25 deletions src/time_machine/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main(argv: Sequence[str] | None = None) -> int:

migrate_parser = subparsers.add_parser(
"migrate",
help="Migrate Python files from freezegun to time-machine",
help="Migrate Python files from freezegun or common.time to time-machine",
)
migrate_parser.add_argument("file", nargs="+")

Expand Down Expand Up @@ -76,7 +76,7 @@ def migrate_file(filename: str) -> int:


def migrate_contents(contents_text: str) -> str:
"""Migrate a single text from freezegun to time-machine."""
"""Migrate a single text from freezegun or common.time to time-machine."""
try:
ast_obj = ast_parse(contents_text)
except SyntaxError:
Expand Down Expand Up @@ -148,16 +148,21 @@ def visit(tree: ast.Module) -> Mapping[Offset, list[TokenFunc]]:
freezegun_import_seen = True
ret[ast_start_offset(node)].append(replace_import)
elif isinstance(node, ast.ImportFrom):
if (
node.module == "freezegun"
and len(node.names) == 1
and (alias := node.names[0]).name == "freeze_time"
and alias.asname is None
):
freeze_time_import_seen = True
ret[ast_start_offset(node)].append(
partial(replace_import_from, node=node)
if node.module in {"freezegun", "common.time"}:
has_unaliased_freeze_time = any(
alias.name == "freeze_time" and alias.asname is None
for alias in node.names
)
if has_unaliased_freeze_time:
freeze_time_import_seen = True
if len(node.names) == 1:
ret[ast_start_offset(node)].append(
partial(replace_import_from, node=node)
)
else:
ret[ast_start_offset(node)].append(
partial(replace_import_from_multi, node=node)
)
elif isinstance(node, ast.FunctionDef):
for decorator in node.decorator_list:
if (
Expand All @@ -181,9 +186,28 @@ def visit(tree: ast.Module) -> Mapping[Offset, list[TokenFunc]]:
ret[ast_start_offset(decorator.func)].append(
partial(switch_to_travel, node=decorator.func)
)
ret[ast_start_offset(decorator)].append(
partial(add_tick_false, node=decorator)
)
# remove tz=...
if _has_tz_keyword(decorator):
for kw in decorator.keywords:
if kw.arg == "tz":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
if _has_auto_tick_seconds_keyword(decorator):
for kw in decorator.keywords:
if kw.arg == "auto_tick_seconds":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
ret[ast_start_offset(decorator)].append(
partial(add_tick_true, node=decorator)
)
elif not _has_tick_keyword(decorator):
ret[ast_start_offset(decorator)].append(
partial(add_tick_false, node=decorator)
)

elif isinstance(node, ast.ClassDef):
if node.decorator_list and looks_like_unittest_class(node):
Expand All @@ -209,9 +233,28 @@ def visit(tree: ast.Module) -> Mapping[Offset, list[TokenFunc]]:
ret[ast_start_offset(decorator.func)].append(
partial(switch_to_travel, node=decorator.func)
)
ret[ast_start_offset(decorator)].append(
partial(add_tick_false, node=decorator)
)
# remove tz=...
if _has_tz_keyword(decorator):
for kw in decorator.keywords:
if kw.arg == "tz":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
if _has_auto_tick_seconds_keyword(decorator):
for kw in decorator.keywords:
if kw.arg == "auto_tick_seconds":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
ret[ast_start_offset(decorator)].append(
partial(add_tick_true, node=decorator)
)
elif not _has_tick_keyword(decorator):
ret[ast_start_offset(decorator)].append(
partial(add_tick_false, node=decorator)
)

elif isinstance(node, ast.With):
for item in node.items:
Expand All @@ -238,19 +281,53 @@ def visit(tree: ast.Module) -> Mapping[Offset, list[TokenFunc]]:
ret[ast_start_offset(context_expr.func)].append(
partial(switch_to_travel, node=context_expr.func)
)
ret[ast_start_offset(context_expr)].append(
partial(add_tick_false, node=context_expr)
)
# remove tz=...
if _has_tz_keyword(context_expr):
for kw in context_expr.keywords:
if kw.arg == "tz":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
# If auto_tick_seconds is present, remove it and add tick=True
if _has_auto_tick_seconds_keyword(context_expr):
for kw in context_expr.keywords:
if kw.arg == "auto_tick_seconds":
ret[ast_start_offset(kw)].append(
partial(remove_keyword, node=kw)
)
break
ret[ast_start_offset(context_expr)].append(
partial(add_tick_true, node=context_expr)
)
# Otherwise, only add tick=False when tick is not already present
elif not _has_tick_keyword(context_expr):
ret[ast_start_offset(context_expr)].append(
partial(add_tick_false, node=context_expr)
)

return ret # type: ignore [return-value]


def _has_tick_keyword(node: ast.Call) -> bool:
return any(kw.arg == "tick" for kw in node.keywords)


def _has_auto_tick_seconds_keyword(node: ast.Call) -> bool:
return any(kw.arg == "auto_tick_seconds" for kw in node.keywords)


def _has_tz_keyword(node: ast.Call) -> bool:
return any(kw.arg == "tz" for kw in node.keywords)


def _only_allowed_keywords(node: ast.Call) -> bool:
# Allow only supported keywords that we know how to migrate
return all(kw.arg in {"tick", "auto_tick_seconds", "tz"} for kw in node.keywords)


def migratable_call(node: ast.Call) -> bool:
return (
len(node.args) == 1
# We could allow tick being set, as long as we didn't then add it
and len(node.keywords) == 0
)
return len(node.args) == 1 and _only_allowed_keywords(node)


def looks_like_unittest_class(node: ast.ClassDef) -> bool:
Expand Down Expand Up @@ -354,6 +431,22 @@ def replace_import_from(tokens: list[Token], i: int, node: ast.ImportFrom) -> No
tokens[i : j + 1] = [Token(name=CODE, src="import time_machine")]


def replace_import_from_multi(
tokens: list[Token], i: int, node: ast.ImportFrom
) -> None:
j = find_last_token(tokens, i, node=node)
remaining_names: list[str] = []
for alias in node.names:
if alias.name == "freeze_time" and alias.asname is None:
continue
name_src = alias.name
if alias.asname:
name_src += f" as {alias.asname}"
remaining_names.append(name_src)
src = f"import time_machine\nfrom {node.module} import {', '.join(remaining_names)}"
tokens[i : j + 1] = [Token(name=CODE, src=src)]


def switch_to_travel(
tokens: list[Token], i: int, node: ast.Attribute | ast.Name
) -> None:
Expand All @@ -365,10 +458,80 @@ def add_tick_false(tokens: list[Token], i: int, node: ast.Call) -> None:
"""
Add `tick=False` to the function call.
"""
# If tick is already provided, do not add
has_tick_kw = any(
isinstance(kw, ast.keyword) and kw.arg == "tick" for kw in node.keywords
)
if has_tick_kw:
return
j = find_last_token(tokens, i, node=node)
tokens.insert(j, Token(name=CODE, src=", tick=False"))


def add_tick_true(tokens: list[Token], i: int, node: ast.Call) -> None:
"""
Add `tick=True` to the function call.
"""
has_tick_kw = any(
isinstance(kw, ast.keyword) and kw.arg == "tick" for kw in node.keywords
)
if has_tick_kw:
return
j = find_last_token(tokens, i, node=node)
tokens.insert(j, Token(name=CODE, src=", tick=True"))


def remove_keyword(tokens: list[Token], i: int, node: ast.keyword) -> None:
"""
Remove an entire keyword argument (e.g., `tz=utc_tz()`), including the
name, equals sign, the value expression, and an adjacent comma (prefer
removing a trailing comma when present; otherwise the preceding comma).
This avoids leaving behind stray commas like `,,`.
"""
j = find_last_token(tokens, i, node=node) # last token of the value expr

# Find the NAME token for the keyword arg to the left of the value
arg_name = getattr(node, "arg", None)
k = i
while k > 0 and not (tokens[k].name == "NAME" and tokens[k].src == arg_name):
k -= 1

# Compute removal bounds
start_remove = k # include the keyword name
end_remove = j + 1 # include the value expression

# Expand right to include spaces and a trailing comma if present
while end_remove < len(tokens) and tokens[end_remove].name == UNIMPORTANT_WS:
end_remove += 1
if end_remove < len(tokens) and tokens[end_remove].src == ",":
end_remove += 1
if end_remove < len(tokens) and tokens[end_remove].name == UNIMPORTANT_WS:
end_remove += 1

# Expand left from the arg name to include `=` and surrounding spaces
p = k + 1
while p < len(tokens) and tokens[p].name == UNIMPORTANT_WS:
p += 1
if p < len(tokens) and tokens[p].src == "=":
p += 1
while p < len(tokens) and tokens[p].name == UNIMPORTANT_WS:
p += 1
# start_remove already at the NAME token

# If we did not include a trailing comma, prefer removing a preceding comma
# and its surrounding spaces (for last-argument cases)
if end_remove == j + 1:
l = start_remove - 1
while l >= 0 and tokens[l].name == UNIMPORTANT_WS:
l -= 1
if l >= 0 and tokens[l].src == ",":
start_remove = l
# consume any whitespace between comma and name
# (the whitespace is between start_remove and original k)

tokens[start_remove:end_remove] = []


# Token functions


Expand Down