Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
AnyByStrMutableMapping,
JSONSerializable,
LazyDict,
Phase,
RelativePath,
StrOrPath,
)
Expand Down Expand Up @@ -277,10 +278,18 @@ def _external_data(self) -> LazyDict:
Files will only be parsed lazily on 1st access. This helps avoiding
circular dependencies when the file name also comes from a variable.
"""

def _render(path: str) -> str:
with Phase.use(Phase.UNDEFINED):
return self._render_string(path)

# Given those values are lazily rendered on 1st access then cached
# the phase value is irrelevant and could be misleading.
# As a consequence it is explicitely set to "undefined".
return LazyDict(
**{
name: lambda path=path: load_answersfile_data(
self.dst_path, self._render_string(path)
self.dst_path, _render(path)
)
for name, path in self.template.external_data.items()
}
Expand Down Expand Up @@ -375,6 +384,7 @@ def _render_context(self) -> AnyByStrMutableMapping:
_copier_conf=conf,
_folder_name=self.subproject.local_abspath.name,
_copier_python=sys.executable,
_copier_phase=Phase.current(),
)

def _path_matcher(self, patterns: Iterable[str]) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -560,7 +570,9 @@ def answers_relpath(self) -> Path:
"""
path = self.answers_file or self.template.answers_relpath
template = self.jinja_env.from_string(str(path))
return Path(template.render(**self.answers.combined))
return Path(
template.render(_copier_phase=Phase.current(), **self.answers.combined)
)

@cached_property
def all_exclusions(self) -> Sequence[str]:
Expand Down Expand Up @@ -928,7 +940,8 @@ def run_copy(self) -> None:
"""
self._check_unsafe("copy")
self._print_message(self.template.message_before_copy)
self._ask()
with Phase.use(Phase.PROMPT):
self._ask()
was_existing = self.subproject.local_abspath.exists()
try:
if not self.quiet:
Expand All @@ -937,12 +950,14 @@ def run_copy(self) -> None:
f"\nCopying from template version {self.template.version}",
file=sys.stderr,
)
self._render_template()
with Phase.use(Phase.RENDER):
self._render_template()
if not self.quiet:
# TODO Unify printing tools
print("") # padding space
if not self.skip_tasks:
self._execute_tasks(self.template.tasks)
with Phase.use(Phase.TASKS):
self._execute_tasks(self.template.tasks)
except Exception:
if not was_existing and self.cleanup_on_error:
rmtree(self.subproject.local_abspath)
Expand Down Expand Up @@ -1044,9 +1059,10 @@ def _apply_update(self) -> None: # noqa: C901
) as old_worker:
old_worker.run_copy()
# Run pre-migration tasks
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
with Phase.use(Phase.MIGRATE):
self._execute_tasks(
self.template.migration_tasks("before", self.subproject.template) # type: ignore[arg-type]
)
# Create a Git tree object from the current (possibly dirty) index
# and keep the object reference.
with local.cwd(subproject_top):
Expand Down Expand Up @@ -1120,7 +1136,7 @@ def _apply_update(self) -> None: # noqa: C901
self._git_initialize_repo()
new_copy_head = git("rev-parse", "HEAD").strip()
# Extract diff between temporary destination and real destination
# with some special handling of newly added files in both the poject
# with some special handling of newly added files in both the project
# and the template.
with local.cwd(old_copy):
# Configure borrowing Git objects from the real destination and
Expand Down Expand Up @@ -1265,9 +1281,10 @@ def _apply_update(self) -> None: # noqa: C901
_remove_old_files(subproject_top, compared)

# Run post-migration tasks
self._execute_tasks(
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
)
with Phase.use(Phase.MIGRATE):
self._execute_tasks(
self.template.migration_tasks("after", self.subproject.template) # type: ignore[arg-type]
)

def _git_initialize_repo(self) -> None:
"""Initialize a git repository in the current directory."""
Expand Down
37 changes: 37 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Complex types, annotations, validators."""

from __future__ import annotations

from contextlib import contextmanager
from contextvars import ContextVar
from enum import Enum
from pathlib import Path
from typing import (
Annotated,
Any,
Callable,
Dict,
Iterator,
Literal,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -75,3 +81,34 @@ def __getitem__(self, key: str) -> Any:
if key not in self.done:
self.done[key] = self.pending[key]()
return self.done[key]


class Phase(str, Enum):
"""The known execution phases."""

PROMPT = "prompt"
TASKS = "tasks"
MIGRATE = "migrate"
RENDER = "render"
UNDEFINED = "undefined"

def __str__(self) -> str:
return str(self.value)

@classmethod
@contextmanager
def use(cls, phase: Phase) -> Iterator[None]:
"""Set the current phase for the duration of a context."""
token = _phase.set(phase)
try:
yield
finally:
_phase.reset(token)

@classmethod
def current(cls) -> Phase:
"""Get the current phase."""
return _phase.get()


_phase: ContextVar[Phase] = ContextVar("phase", default=Phase.UNDEFINED)
12 changes: 10 additions & 2 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import json
import warnings
from collections import ChainMap
from collections.abc import Mapping, Sequence
from dataclasses import field
from datetime import datetime
from functools import cached_property
from hashlib import sha512
from os import urandom
from pathlib import Path
from typing import Any, Callable, Literal, Mapping, Sequence
from typing import Any, Callable, Literal

import yaml
from jinja2 import UndefinedError
Expand All @@ -33,6 +34,7 @@
AnyByStrMutableMapping,
LazyDict,
MissingType,
Phase,
StrOrPath,
)

Expand Down Expand Up @@ -464,7 +466,13 @@ def render_value(
else value
)
try:
return template.render({**self.answers.combined, **(extra_answers or {})})
return template.render(
{
**self.answers.combined,
**(extra_answers or {}),
"_copier_phase": Phase.current(),
}
)
except UndefinedError as error:
raise UserMessageError(str(error)) from error

Expand Down
10 changes: 10 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ variable:

The name of the project root directory.

### `_copier_phase`

The current phase, one of `"prompt"`,`"tasks"`, `"migrate"` or `"render"`.

!!! note

There is also an additional `"undefined"` phase used when not in any phase.
You may encounter this phase when rendering outside of those phases,
when rendering lazily (and the phase notion can be irrelevant) or when testing.

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_answersfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,31 @@ def test_external_data_with_umlaut(
copier.run_copy(str(src), dst, defaults=True, overwrite=True)
answers = load_answersfile_data(dst, ".copier-answers.yml")
assert answers["ext_umlaut"] == "äöü"


def test_undefined_phase_in_external_data(
tmp_path_factory: pytest.TempPathFactory,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

build_file_tree(
{
(src / "copier.yml"): (
"""\
_external_data:
data: '{{ _copier_phase }}.yml'
key: "{{ _external_data.data.key }}"
"""
),
(src / "{{ _copier_conf.answers_file }}.jinja"): (
"{{ _copier_answers|to_nice_yaml }}"
),
}
)
git_save(src, tag="v1")

(dst / "undefined.yml").write_text("key: value")

copier.run_copy(str(src), dst, defaults=True, overwrite=True)
answers = load_answersfile_data(dst, ".copier-answers.yml")
assert answers["key"] == "value"
18 changes: 18 additions & 0 deletions tests/test_answersfile_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,21 @@ def test_answersfile_templating_with_message_before_copy(
assert answers["module_name"] == "mymodule"
assert (dst / "result.txt").exists()
assert (dst / "result.txt").read_text() == "mymodule"


def test_answersfile_templating_phase(tmp_path_factory: pytest.TempPathFactory) -> None:
"""
Ensure `_copier_phase` is available while render `answers_relpath`.
Not because it is directly useful, but because some extensions might need it.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
src / "copier.yml": """\
_answers_file: ".copier-answers-{{ _copier_phase }}.yml"
""",
src / "{{ _copier_conf.answers_file }}.jinja": "",
}
)
copier.run_copy(str(src), dst, overwrite=True, unsafe=True)
assert (dst / ".copier-answers-render.yml").exists()
8 changes: 8 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,3 +1041,11 @@ def test_templated_choices(tmp_path_factory: pytest.TempPathFactory, spec: str)
)
copier.run_copy(str(src), dst, data={"q": "two"})
assert yaml.safe_load((dst / "q.txt").read_text()) == "two"


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree({src / "{{ _copier_phase }}.jinja": "{{ _copier_phase }}"})
copier.run_copy(str(src), dst)
assert (dst / "render").exists()
assert (dst / "render").read_text() == "render"
29 changes: 29 additions & 0 deletions tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,32 @@ def test_migration_jinja_variables(
assert f"{variable}={value}" in vars
else:
assert f"{variable}=" in vars


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
"""Test that the Phase variable is properly set."""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

with local.cwd(src):
build_file_tree(
{
**COPIER_ANSWERS_FILE,
"copier.yml": (
"""\
_migrations:
- touch {{ _copier_phase }}
"""
),
}
)
git_save(tag="v1")
with local.cwd(dst):
run_copy(src_path=str(src))
git_save()

with local.cwd(src):
git("tag", "v2")
with local.cwd(dst):
run_update(defaults=True, overwrite=True, unsafe=True)

assert (dst / "migrate").is_file()
16 changes: 16 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,19 @@ def test_os_specific_tasks(
monkeypatch.setattr("copier.main.OS", os)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / filename).exists()


def test_copier_phase_variable(tmp_path_factory: pytest.TempPathFactory) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): (
"""
_tasks:
- touch {{ _copier_phase }}
"""
)
}
)
copier.run_copy(str(src), dst, unsafe=True)
assert (dst / "tasks").exists()
19 changes: 19 additions & 0 deletions tests/test_templated_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,22 @@ def test_multiselect_choices_with_templated_default_value(
"python_version": "3.11",
"github_runner_python_version": ["3.11"],
}


def test_copier_phase_variable(
tmp_path_factory: pytest.TempPathFactory,
spawn: Spawn,
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): """\
phase:
type: str
default: "{{ _copier_phase }}"
"""
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
expect_prompt(tui, "phase", "str")
tui.expect_exact("prompt")
Loading