Skip to content

Commit eea53da

Browse files
committed
feat: Add _copier_conf.operation variable
1 parent 0315674 commit eea53da

File tree

6 files changed

+103
-4
lines changed

6 files changed

+103
-4
lines changed

copier/main.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
_T = TypeVar("_T")
6161

6262

63+
Operation = Literal["copy", "recopy", "update"]
64+
65+
6366
@dataclass(config=ConfigDict(extra="forbid"))
6467
class Worker:
6568
"""Copier process state manager.
@@ -195,6 +198,7 @@ class Worker:
195198
unsafe: bool = False
196199
skip_answered: bool = False
197200
skip_tasks: bool = False
201+
operation: Operation = "copy"
198202

199203
answers: AnswersMap = field(default_factory=AnswersMap, init=False)
200204
_cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False)
@@ -234,7 +238,7 @@ def _cleanup(self) -> None:
234238
for method in self._cleanup_hooks:
235239
method()
236240

237-
def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
241+
def _check_unsafe(self, mode: Operation) -> None:
238242
"""Check whether a template uses unsafe features."""
239243
if self.unsafe:
240244
return
@@ -846,7 +850,7 @@ def run_recopy(self) -> None:
846850
"Cannot recopy because cannot obtain old template references "
847851
f"from `{self.subproject.answers_relpath}`."
848852
)
849-
with replace(self, src_path=self.subproject.template.url) as new_worker:
853+
with replace(self, src_path=self.subproject.template.url, operation="recopy") as new_worker:
850854
new_worker.run_copy()
851855

852856
def run_update(self) -> None:
@@ -896,8 +900,10 @@ def run_update(self) -> None:
896900
print(
897901
f"Updating to template version {self.template.version}", file=sys.stderr
898902
)
899-
self._apply_update()
900-
self._print_message(self.template.message_after_update)
903+
with replace(self, operation="update") as worker:
904+
worker._apply_update()
905+
worker._print_message(worker.template.message_after_update)
906+
self.answers = worker.answers
901907

902908
def _apply_update(self) -> None: # noqa: C901
903909
git = get_git()

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
from .helpers import Spawn
1414

15+
pytest_plugins = [
16+
"tests.templates",
17+
]
18+
1519

1620
@pytest.fixture
1721
def spawn() -> Spawn:

tests/templates.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import shutil
2+
from pathlib import Path
3+
from typing import Generator
4+
5+
import pytest
6+
7+
from .helpers import build_file_tree, git_save
8+
9+
10+
@pytest.fixture(params=("copy", "recopy", "update"))
11+
def operation_context_template(tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest) -> Generator[Path, None, None]:
12+
src = tmp_path_factory.mktemp(f"operation_template_{request.param}")
13+
try:
14+
build_file_tree(
15+
{
16+
(src / f"{{% if _copier_conf.operation == '{request.param}' %}}foo{{% endif %}}"): "foo",
17+
(src / "bar"): "bar",
18+
(src / "{{ _copier_conf.answers_file }}.jinja"): "{{ _copier_answers|to_nice_yaml }}",
19+
}
20+
)
21+
git_save(src, tag="1.0.0")
22+
yield src
23+
finally:
24+
shutil.rmtree(src, ignore_errors=True)
25+
26+
27+
@pytest.fixture
28+
def operation_context_template_v2(operation_context_template: Path) -> Path:
29+
conditional_file = next(iter(operation_context_template.glob("*foo*")))
30+
build_file_tree(
31+
{
32+
conditional_file: "foo_update",
33+
(operation_context_template / "bar"): "bar_update",
34+
}
35+
)
36+
git_save(operation_context_template, tag="2.0.0")
37+
return operation_context_template

tests/test_copy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,13 @@ def test_multiselect_choices_preserve_order(
945945
)
946946
copier.run_copy(str(src), dst, data={"q": ["three", "one", "two"]})
947947
assert yaml.safe_load((dst / "q.yml").read_text()) == ["one", "two", "three"]
948+
949+
950+
def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None:
951+
run_copy(str(operation_context_template), tmp_path)
952+
conditional_file = tmp_path / "foo"
953+
expected = "_copy" in operation_context_template.name
954+
assert conditional_file.exists() is expected
955+
if expected:
956+
assert conditional_file.read_text() == "foo"
957+
assert (tmp_path / "bar").read_text() == "bar"

tests/test_recopy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,23 @@ def test_recopy_works_without_replay(tpl: str, tmp_path: Path) -> None:
7272
# Recopy
7373
run_recopy(tmp_path, skip_answered=True, overwrite=True)
7474
assert (tmp_path / "name.txt").read_text() == "This is my name: Mario."
75+
76+
77+
def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None:
78+
run_copy(str(operation_context_template), tmp_path)
79+
git_save(tmp_path)
80+
conditional_file = tmp_path / "foo"
81+
expected_copy = "_copy" in operation_context_template.name
82+
expected_recopy = "recopy" in operation_context_template.name
83+
assert conditional_file.exists() is expected_copy
84+
assert (tmp_path / "bar").read_text() == "bar"
85+
if expected_copy:
86+
assert conditional_file.read_text() == "foo"
87+
conditional_file.unlink()
88+
(tmp_path / "bar").write_text("baz")
89+
git_save(tmp_path)
90+
run_recopy(str(tmp_path), overwrite=True)
91+
assert conditional_file.exists() is expected_recopy
92+
if expected_recopy:
93+
assert conditional_file.read_text() == "foo"
94+
assert (tmp_path / "bar").read_text() == "bar"

tests/test_updatediff.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
build_file_tree,
2626
git,
2727
git_init,
28+
git_save,
2829
)
2930

3031

@@ -1290,3 +1291,24 @@ def test_update_with_new_file_in_template_and_project_via_migration(
12901291
>>>>>>> after updating
12911292
"""
12921293
)
1294+
1295+
1296+
def test_operation_context(tmp_path: Path, operation_context_template: Path, request: pytest.FixtureRequest) -> None:
1297+
run_copy(str(operation_context_template), tmp_path)
1298+
conditional_file = tmp_path / "foo"
1299+
expected_copy = "_copy" in operation_context_template.name
1300+
expected_update = "update" in operation_context_template.name
1301+
assert conditional_file.exists() is expected_copy
1302+
assert (tmp_path / "bar").read_text() == "bar"
1303+
if expected_copy:
1304+
assert conditional_file.read_text() == "foo"
1305+
git_save(tmp_path)
1306+
request.getfixturevalue("operation_context_template_v2")
1307+
run_update(str(tmp_path), overwrite=True)
1308+
if expected_update:
1309+
assert conditional_file.read_text() == "foo_update"
1310+
elif expected_copy:
1311+
assert conditional_file.read_text() == "foo"
1312+
else:
1313+
assert not conditional_file.exists()
1314+
assert (tmp_path / "bar").read_text() == "bar_update"

0 commit comments

Comments
 (0)