Skip to content

Commit 6066903

Browse files
Merge pull request #994 from dapomeroy/register_phase_when
Add 'when' to register_phase directive
2 parents 48d0ddf + f435b7f commit 6066903

File tree

6 files changed

+114
-5
lines changed

6 files changed

+114
-5
lines changed

lib/ramble/ramble/application.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,11 @@ def get_pipeline_phases(self, pipeline, phase_filters=None):
461461
final_added_index = None
462462
if pipeline in self._pipeline_graphs:
463463
for idx, phase in enumerate(self._pipeline_graphs[pipeline].walk()):
464-
for phase_filter in phase_filters:
465-
if fnmatch.fnmatch(phase.key, phase_filter):
466-
phases.add(phase)
467-
final_added_index = idx
464+
if self.expander.satisfies(phase.when, variant_set=self.object_variants):
465+
for phase_filter in phase_filters:
466+
if fnmatch.fnmatch(phase.key, phase_filter):
467+
phases.add(phase)
468+
final_added_index = idx
468469

469470
include_phase_deps = ramble.config.get("config:include_phase_dependencies")
470471
if include_phase_deps:

lib/ramble/ramble/language/language_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def _wrapper(*args, **_kwargs):
213213
"required_package",
214214
"define_compiler",
215215
"package_manager_config",
216+
"register_phase",
216217
]:
217218
msg = (
218219
'directive "{0}" cannot be used within a "when"'

lib/ramble/ramble/language/shared_language.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def define_compiler(
138138
compiler (str): Package name to use for compilation
139139
package_manager (str): Glob supported pattern to match package managers
140140
this compiler applies to
141+
when (list | None): List of when conditions to apply to directive
141142
"""
142143

143144
def _execute_define_compiler(obj):
@@ -184,6 +185,7 @@ def software_spec(
184185
compiler (str): Package name to use as compiler for compiling this package
185186
package_manager (str): Glob supported pattern to match package managers
186187
this package applies to
188+
when (list | None): List of when conditions to apply to directive
187189
"""
188190

189191
def _execute_software_spec(obj):
@@ -221,6 +223,7 @@ def package_manager_config(name, config, package_manager=None, when=None, **kwar
221223
name (str): Name of this configuration
222224
config (str): Configuration option to set
223225
package_manager (str): Name of the package manager this config should be used with
226+
when (list | None): List of when conditions to apply to directive
224227
"""
225228

226229
def _execute_package_manager_config(obj):
@@ -251,6 +254,7 @@ def required_package(name, package_manager=None, when=None, **kwargs):
251254
Args:
252255
name (str): Name of required package
253256
package_manager (str): Glob package manager name to apply this required package to
257+
when (list | None): List of when conditions to apply to directive
254258
"""
255259

256260
def _execute_required_package(obj):
@@ -419,7 +423,7 @@ def _store_builtin(obj):
419423

420424

421425
@shared_directive("phase_definitions")
422-
def register_phase(name, pipeline=None, run_before=None, run_after=None, **kwargs):
426+
def register_phase(name, pipeline=None, run_before=None, run_after=None, when=None, **kwargs):
423427
"""Register a phase
424428
425429
Phases are portions of a pipeline that will execute when
@@ -436,6 +440,7 @@ def register_phase(name, pipeline=None, run_before=None, run_after=None, **kwarg
436440
pipeline (str): The name of the pipeline this phase should be registered into.
437441
run_before (list(str) | None): A list of phase names this phase should run before
438442
run_after (list(str) | None): A list of phase names this phase should run after
443+
when (list | None): List of when conditions to apply to directive
439444
"""
440445
if run_before is None:
441446
run_before = []
@@ -445,6 +450,10 @@ def register_phase(name, pipeline=None, run_before=None, run_after=None, **kwarg
445450
def _execute_register_phase(obj):
446451
import ramble.util.graph
447452

453+
when_list = ramble.language.language_helpers.build_when_list(
454+
when, obj, name, "register_phase"
455+
)
456+
448457
if pipeline not in obj._pipelines:
449458
raise ramble.language.language_base.DirectiveError(
450459
"Directive register_phase was "
@@ -490,6 +499,8 @@ def _execute_register_phase(obj):
490499
for after in run_after:
491500
phase_node.order_after(after)
492501

502+
phase_node.when = when_list
503+
493504
obj.phase_definitions[pipeline][name] = phase_node
494505

495506
return _execute_register_phase

lib/ramble/ramble/test/when.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2022-2025 The Ramble Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
# option. This file may not be copied, modified, or distributed
7+
# except according to those terms.
8+
9+
import pytest
10+
11+
import ramble.workspace
12+
from ramble.main import RambleCommand
13+
14+
pytestmark = pytest.mark.usefixtures("mutable_mock_workspace_path", "mutable_mock_apps_repo")
15+
16+
config = RambleCommand("config")
17+
workspace = RambleCommand("workspace")
18+
19+
20+
def test_register_phase_when(request):
21+
ws_name = request.node.name
22+
23+
global_args = ["-w", ws_name]
24+
25+
with ramble.workspace.create(ws_name) as ws:
26+
workspace(
27+
"manage",
28+
"experiments",
29+
"when-directives",
30+
"--wf",
31+
"test_wl",
32+
"-v",
33+
"n_ranks=1",
34+
"-v",
35+
"n_nodes=1",
36+
"-v",
37+
"processes_per_node=1",
38+
global_args=global_args,
39+
)
40+
41+
config("add", "variants:register_phase_when:true", global_args=global_args)
42+
43+
ws._re_read()
44+
output = workspace("setup", "--dry-run", global_args=global_args)
45+
46+
assert "Test Phase" in output
47+
48+
config("remove", "variants:register_phase_when:true", global_args=global_args)
49+
config("add", "variants:register_phase_when:false", global_args=global_args)
50+
51+
ws._re_read()
52+
output = workspace("setup", "--dry-run", global_args=global_args)
53+
54+
assert "Test Phase" not in output

lib/ramble/ramble/util/graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(self, key, attribute=None, obj_inst=None):
2828
self._order_before = []
2929
self._order_after = []
3030
self.obj_inst = obj_inst
31+
self.when = []
3132

3233
def set_attribute(self, attr):
3334
"""Sets the attribute of a graph node
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2022-2025 The Ramble Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6+
# option. This file may not be copied, modified, or distributed
7+
# except according to those terms.
8+
9+
from ramble.appkit import *
10+
11+
12+
class WhenDirectives(ExecutableApplication):
13+
name = "when-directives"
14+
15+
executable("test_exec", "echo '{test_variable}'", use_mpi=False)
16+
17+
workload("test_wl", executable="test_exec")
18+
19+
with default_args(workload="test_wl"):
20+
workload_variable(
21+
"test_variable",
22+
default="Test",
23+
description="Variable to print for testing",
24+
)
25+
26+
variant(
27+
"register_phase_when",
28+
default=False,
29+
values=[True, False],
30+
description="Register additional phase using when",
31+
)
32+
33+
with when("+register_phase_when"):
34+
register_phase(
35+
"test_phase",
36+
pipeline="setup",
37+
run_before=["get_inputs"],
38+
)
39+
40+
def _test_phase(self, workspace, app_inst):
41+
print("Test Phase")

0 commit comments

Comments
 (0)