7
7
import subprocess
8
8
import sys
9
9
from contextlib import suppress
10
+ from contextvars import ContextVar
10
11
from dataclasses import asdict , field , replace
11
12
from filecmp import dircmp
12
- from functools import cached_property , partial
13
+ from functools import cached_property , partial , wraps
13
14
from itertools import chain
14
15
from pathlib import Path
15
16
from shutil import rmtree
65
66
AnyByStrMutableMapping ,
66
67
JSONSerializable ,
67
68
LazyDict ,
69
+ Operation ,
70
+ ParamSpec ,
68
71
Phase ,
69
72
RelativePath ,
70
73
StrOrPath ,
73
76
from .vcs import get_git
74
77
75
78
_T = TypeVar ("_T" )
79
+ _P = ParamSpec ("_P" )
80
+
81
+ _operation : ContextVar [Operation ] = ContextVar ("_operation" )
82
+
83
+
84
+ def as_operation (value : Operation ) -> Callable [[Callable [_P , _T ]], Callable [_P , _T ]]:
85
+ """Decorator to set the current operation context, if not defined already.
86
+
87
+ This value is used to template specific configuration options.
88
+ """
89
+
90
+ def _decorator (func : Callable [_P , _T ]) -> Callable [_P , _T ]:
91
+ @wraps (func )
92
+ def _wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> _T :
93
+ token = _operation .set (_operation .get (value ))
94
+ try :
95
+ return func (* args , ** kwargs )
96
+ finally :
97
+ _operation .reset (token )
98
+
99
+ return _wrapper
100
+
101
+ return _decorator
76
102
77
103
78
104
@dataclass (config = ConfigDict (extra = "forbid" ))
@@ -248,7 +274,7 @@ def _cleanup(self) -> None:
248
274
for method in self ._cleanup_hooks :
249
275
method ()
250
276
251
- def _check_unsafe (self , mode : Literal [ "copy" , "update" ] ) -> None :
277
+ def _check_unsafe (self , mode : Operation ) -> None :
252
278
"""Check whether a template uses unsafe features."""
253
279
if self .unsafe or self .settings .is_trusted (self .template .url ):
254
280
return
@@ -327,8 +353,10 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
327
353
Arguments:
328
354
tasks: The list of tasks to run.
329
355
"""
356
+ operation = _operation .get ()
330
357
for i , task in enumerate (tasks ):
331
358
extra_context = {f"_{ k } " : v for k , v in task .extra_vars .items ()}
359
+ extra_context ["_copier_operation" ] = operation
332
360
333
361
if not cast_to_bool (self ._render_value (task .condition , extra_context )):
334
362
continue
@@ -358,7 +386,7 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
358
386
/ Path (self ._render_string (str (task .working_directory ), extra_context ))
359
387
).absolute ()
360
388
361
- extra_env = {k .upper (): str (v ) for k , v in task . extra_vars .items ()}
389
+ extra_env = {k [ 1 :] .upper (): str (v ) for k , v in extra_context .items ()}
362
390
with local .cwd (working_directory ), local .env (** extra_env ):
363
391
subprocess .run (task_cmd , shell = use_shell , check = True , env = local .env )
364
392
@@ -625,7 +653,14 @@ def _pathjoin(
625
653
@cached_property
626
654
def match_exclude (self ) -> Callable [[Path ], bool ]:
627
655
"""Get a callable to match paths against all exclusions."""
628
- return self ._path_matcher (self .all_exclusions )
656
+ # Include the current operation in the rendering context.
657
+ # Note: This method is a cached property, it needs to be regenerated
658
+ # when reusing an instance in different contexts.
659
+ extra_context = {"_copier_operation" : _operation .get ()}
660
+ return self ._path_matcher (
661
+ self ._render_string (exclusion , extra_context = extra_context )
662
+ for exclusion in self .all_exclusions
663
+ )
629
664
630
665
@cached_property
631
666
def match_skip (self ) -> Callable [[Path ], bool ]:
@@ -928,6 +963,7 @@ def template_copy_root(self) -> Path:
928
963
return self .template .local_abspath / subdir
929
964
930
965
# Main operations
966
+ @as_operation ("copy" )
931
967
def run_copy (self ) -> None :
932
968
"""Generate a subproject from zero, ignoring what was in the folder.
933
969
@@ -938,6 +974,11 @@ def run_copy(self) -> None:
938
974
939
975
See [generating a project][generating-a-project].
940
976
"""
977
+ with suppress (AttributeError ):
978
+ # We might have switched operation context, ensure the cached property
979
+ # is regenerated to re-render templates.
980
+ del self .match_exclude
981
+
941
982
self ._check_unsafe ("copy" )
942
983
self ._print_message (self .template .message_before_copy )
943
984
with Phase .use (Phase .PROMPT ):
@@ -967,6 +1008,7 @@ def run_copy(self) -> None:
967
1008
# TODO Unify printing tools
968
1009
print ("" ) # padding space
969
1010
1011
+ @as_operation ("copy" )
970
1012
def run_recopy (self ) -> None :
971
1013
"""Update a subproject, keeping answers but discarding evolution."""
972
1014
if self .subproject .template is None :
@@ -977,6 +1019,7 @@ def run_recopy(self) -> None:
977
1019
with replace (self , src_path = self .subproject .template .url ) as new_worker :
978
1020
new_worker .run_copy ()
979
1021
1022
+ @as_operation ("update" )
980
1023
def run_update (self ) -> None :
981
1024
"""Update a subproject that was already generated.
982
1025
@@ -1024,6 +1067,11 @@ def run_update(self) -> None:
1024
1067
print (
1025
1068
f"Updating to template version { self .template .version } " , file = sys .stderr
1026
1069
)
1070
+ with suppress (AttributeError ):
1071
+ # We might have switched operation context, ensure the cached property
1072
+ # is regenerated to re-render templates.
1073
+ del self .match_exclude
1074
+
1027
1075
self ._apply_update ()
1028
1076
self ._print_message (self .template .message_after_update )
1029
1077
0 commit comments