Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1dabe8e

Browse files
authoredMar 2, 2025··
feat(context): expose a _copier_phase context variable
1 parent 9736631 commit 1dabe8e

File tree

10 files changed

+204
-14
lines changed

10 files changed

+204
-14
lines changed
 

‎copier/main.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
AnyByStrMutableMapping,
6666
JSONSerializable,
6767
LazyDict,
68+
Phase,
6869
RelativePath,
6970
StrOrPath,
7071
)
@@ -277,10 +278,18 @@ def _external_data(self) -> LazyDict:
277278
Files will only be parsed lazily on 1st access. This helps avoiding
278279
circular dependencies when the file name also comes from a variable.
279280
"""
281+
282+
def _render(path: str) -> str:
283+
with Phase.use(Phase.UNDEFINED):
284+
return self._render_string(path)
285+
286+
# Given those values are lazily rendered on 1st access then cached
287+
# the phase value is irrelevant and could be misleading.
288+
# As a consequence it is explicitely set to "undefined".
280289
return LazyDict(
281290
**{
282291
name: lambda path=path: load_answersfile_data(
283-
self.dst_path, self._render_string(path)
292+
self.dst_path, _render(path)
284293
)
285294
for name, path in self.template.external_data.items()
286295
}
@@ -375,6 +384,7 @@ def _render_context(self) -> AnyByStrMutableMapping:
375384
_copier_conf=conf,
376385
_folder_name=self.subproject.local_abspath.name,
377386
_copier_python=sys.executable,
387+
_copier_phase=Phase.current(),
378388
)
379389

380390
def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
@@ -560,7 +570,9 @@ def answers_relpath(self) -> Path:
560570
"""
561571
path = self.answers_file or self.template.answers_relpath
562572
template = self.jinja_env.from_string(str(path))
563-
return Path(template.render(**self.answers.combined))
573+
return Path(
574+
template.render(_copier_phase=Phase.current(), **self.answers.combined)
575+
)
564576

565577
@cached_property
566578
def all_exclusions(self) -> Sequence[str]:
@@ -928,7 +940,8 @@ def run_copy(self) -> None:
928940
"""
929941
self._check_unsafe("copy")
930942
self._print_message(self.template.message_before_copy)
931-
self._ask()
943+
with Phase.use(Phase.PROMPT):
944+
self._ask()
932945
was_existing = self.subproject.local_abspath.exists()
933946
try:
934947
if not self.quiet:
@@ -937,12 +950,14 @@ def run_copy(self) -> None:
937950
f"\nCopying from template version {self.template.version}",
938951
file=sys.stderr,
939952
)
940-
self._render_template()
953+
with Phase.use(Phase.RENDER):
954+
self._render_template()
941955
if not self.quiet:
942956
# TODO Unify printing tools
943957
print("") # padding space
944958
if not self.skip_tasks:
945-
self._execute_tasks(self.template.tasks)
959+
with Phase.use(Phase.TASKS):
960+
self._execute_tasks(self.template.tasks)
946961
except Exception:
947962
if not was_existing and self.cleanup_on_error:
948963
rmtree(self.subproject.local_abspath)
@@ -1044,9 +1059,10 @@ def _apply_update(self) -> None: # noqa: C901
10441059
) as old_worker:
10451060
old_worker.run_copy()
10461061
# Run pre-migration tasks
1047-
self._execute_tasks(
1048-
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
1049-
)
1062+
with Phase.use(Phase.MIGRATE):
1063+
self._execute_tasks(
1064+
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
1065+
)
10501066
# Create a Git tree object from the current (possibly dirty) index
10511067
# and keep the object reference.
10521068
with local.cwd(subproject_top):
@@ -1120,7 +1136,7 @@ def _apply_update(self) -> None: # noqa: C901
11201136
self._git_initialize_repo()
11211137
new_copy_head = git("rev-parse", "HEAD").strip()
11221138
# Extract diff between temporary destination and real destination
1123-
# with some special handling of newly added files in both the poject
1139+
# with some special handling of newly added files in both the project
11241140
# and the template.
11251141
with local.cwd(old_copy):
11261142
# Configure borrowing Git objects from the real destination and
@@ -1265,9 +1281,10 @@ def _apply_update(self) -> None: # noqa: C901
12651281
_remove_old_files(subproject_top, compared)
12661282

12671283
# Run post-migration tasks
1268-
self._execute_tasks(
1269-
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
1270-
)
1284+
with Phase.use(Phase.MIGRATE):
1285+
self._execute_tasks(
1286+
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
1287+
)
12711288

12721289
def _git_initialize_repo(self) -> None:
12731290
"""Initialize a git repository in the current directory."""

‎copier/types.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
"""Complex types, annotations, validators."""
22

3+
from __future__ import annotations
4+
5+
from contextlib import contextmanager
6+
from contextvars import ContextVar
7+
from enum import Enum
38
from pathlib import Path
49
from typing import (
510
Annotated,
611
Any,
712
Callable,
813
Dict,
14+
Iterator,
915
Literal,
1016
Mapping,
1117
MutableMapping,
@@ -75,3 +81,34 @@ def __getitem__(self, key: str) -> Any:
7581
if key not in self.done:
7682
self.done[key] = self.pending[key]()
7783
return self.done[key]
84+
85+
86+
class Phase(str, Enum):
87+
"""The known execution phases."""
88+
89+
PROMPT = "prompt"
90+
TASKS = "tasks"
91+
MIGRATE = "migrate"
92+
RENDER = "render"
93+
UNDEFINED = "undefined"
94+
95+
def __str__(self) -> str:
96+
return str(self.value)
97+
98+
@classmethod
99+
@contextmanager
100+
def use(cls, phase: Phase) -> Iterator[None]:
101+
"""Set the current phase for the duration of a context."""
102+
token = _phase.set(phase)
103+
try:
104+
yield
105+
finally:
106+
_phase.reset(token)
107+
108+
@classmethod
109+
def current(cls) -> Phase:
110+
"""Get the current phase."""
111+
return _phase.get()
112+
113+
114+
_phase: ContextVar[Phase] = ContextVar("phase", default=Phase.UNDEFINED)

‎copier/user_data.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import json
66
import warnings
77
from collections import ChainMap
8+
from collections.abc import Mapping, Sequence
89
from dataclasses import field
910
from datetime import datetime
1011
from functools import cached_property
1112
from hashlib import sha512
1213
from os import urandom
1314
from pathlib import Path
14-
from typing import Any, Callable, Literal, Mapping, Sequence
15+
from typing import Any, Callable, Literal
1516

1617
import yaml
1718
from jinja2 import UndefinedError
@@ -33,6 +34,7 @@
3334
AnyByStrMutableMapping,
3435
LazyDict,
3536
MissingType,
37+
Phase,
3638
StrOrPath,
3739
)
3840

@@ -464,7 +466,13 @@ def render_value(
464466
else value
465467
)
466468
try:
467-
return template.render({**self.answers.combined, **(extra_answers or {})})
469+
return template.render(
470+
{
471+
**self.answers.combined,
472+
**(extra_answers or {}),
473+
"_copier_phase": Phase.current(),
474+
}
475+
)
468476
except UndefinedError as error:
469477
raise UserMessageError(str(error)) from error
470478

‎docs/creating.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,16 @@ variable:
136136

137137
The name of the project root directory.
138138

139+
### `_copier_phase`
140+
141+
The current phase, one of `"prompt"`,`"tasks"`, `"migrate"` or `"render"`.
142+
143+
!!! note
144+
145+
There is also an additional `"undefined"` phase used when not in any phase.
146+
You may encounter this phase when rendering outside of those phases,
147+
when rendering lazily (and the phase notion can be irrelevant) or when testing.
148+
139149
## Variables (context-specific)
140150

141151
Some rendering contexts provide variables unique to them:

‎tests/test_answersfile.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,31 @@ def test_external_data_with_umlaut(
250250
copier.run_copy(str(src), dst, defaults=True, overwrite=True)
251251
answers = load_answersfile_data(dst, ".copier-answers.yml")
252252
assert answers["ext_umlaut"] == "äöü"
253+
254+
255+
def test_undefined_phase_in_external_data(
256+
tmp_path_factory: pytest.TempPathFactory,
257+
) -> None:
258+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
259+
260+
build_file_tree(
261+
{
262+
(src / "copier.yml"): (
263+
"""\
264+
_external_data:
265+
data: '{{ _copier_phase }}.yml'
266+
key: "{{ _external_data.data.key }}"
267+
"""
268+
),
269+
(src / "{{ _copier_conf.answers_file }}.jinja"): (
270+
"{{ _copier_answers|to_nice_yaml }}"
271+
),
272+
}
273+
)
274+
git_save(src, tag="v1")
275+
276+
(dst / "undefined.yml").write_text("key: value")
277+
278+
copier.run_copy(str(src), dst, defaults=True, overwrite=True)
279+
answers = load_answersfile_data(dst, ".copier-answers.yml")
280+
assert answers["key"] == "value"

‎tests/test_answersfile_templating.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,21 @@ def test_answersfile_templating_with_message_before_copy(
118118
assert answers["module_name"] == "mymodule"
119119
assert (dst / "result.txt").exists()
120120
assert (dst / "result.txt").read_text() == "mymodule"
121+
122+
123+
def test_answersfile_templating_phase(tmp_path_factory: pytest.TempPathFactory) -> None:
124+
"""
125+
Ensure `_copier_phase` is available while render `answers_relpath`.
126+
Not because it is directly useful, but because some extensions might need it.
127+
"""
128+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
129+
build_file_tree(
130+
{
131+
src / "copier.yml": """\
132+
_answers_file: ".copier-answers-{{ _copier_phase }}.yml"
133+
""",
134+
src / "{{ _copier_conf.answers_file }}.jinja": "",
135+
}
136+
)
137+
copier.run_copy(str(src), dst, overwrite=True, unsafe=True)
138+
assert (dst / ".copier-answers-render.yml").exists()

‎tests/test_copy.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,3 +1041,11 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str)
10411041
)
10421042
copier.run_copy(str(src), dst, data={"q": "two"})
10431043
assert yaml.safe_load((dst / "q.txt").read_text()) == "two"
1044+
1045+
1046+
def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
1047+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
1048+
build_file_tree({src / "{{ _copier_phase }}.jinja": "{{ _copier_phase }}"})
1049+
copier.run_copy(str(src), dst)
1050+
assert (dst / "render").exists()
1051+
assert (dst / "render").read_text() == "render"

‎tests/test_migrations.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,32 @@ def test_migration_jinja_variables(
521521
assert f"{variable}={value}" in vars
522522
else:
523523
assert f"{variable}=" in vars
524+
525+
526+
def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
527+
"""Test that the Phase variable is properly set."""
528+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
529+
530+
with local.cwd(src):
531+
build_file_tree(
532+
{
533+
**COPIER_ANSWERS_FILE,
534+
"copier.yml": (
535+
"""\
536+
_migrations:
537+
- touch {{ _copier_phase }}
538+
"""
539+
),
540+
}
541+
)
542+
git_save(tag="v1")
543+
with local.cwd(dst):
544+
run_copy(src_path=str(src))
545+
git_save()
546+
547+
with local.cwd(src):
548+
git("tag", "v2")
549+
with local.cwd(dst):
550+
run_update(defaults=True, overwrite=True, unsafe=True)
551+
552+
assert (dst / "migrate").is_file()

‎tests/test_tasks.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,19 @@ def test_os_specific_tasks(
164164
monkeypatch.setattr("copier.main.OS", os)
165165
copier.run_copy(str(src), dst, unsafe=True)
166166
assert (dst / filename).exists()
167+
168+
169+
def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
170+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
171+
build_file_tree(
172+
{
173+
(src / "copier.yml"): (
174+
"""
175+
_tasks:
176+
- touch {{ _copier_phase }}
177+
"""
178+
)
179+
}
180+
)
181+
copier.run_copy(str(src), dst, unsafe=True)
182+
assert (dst / "tasks").exists()

‎tests/test_templated_prompt.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value(
563563
"python_version": "3.11",
564564
"github_runner_python_version": ["3.11"],
565565
}
566+
567+
568+
def test_copier_phase_variable(
569+
tmp_path_factory: pytest.TempPathFactory,
570+
spawn: Spawn,
571+
) -> None:
572+
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
573+
build_file_tree(
574+
{
575+
(src / "copier.yml"): """\
576+
phase:
577+
type: str
578+
default: "{{ _copier_phase }}"
579+
"""
580+
}
581+
)
582+
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
583+
expect_prompt(tui, "phase", "str")
584+
tui.expect_exact("prompt")

0 commit comments

Comments
 (0)
Please sign in to comment.