Skip to content

Commit 538f364

Browse files
Merge pull request #1064 from dapomeroy/executable_when
Add `when` support to executable directive
2 parents f31dd19 + dd5a25a commit 538f364

File tree

9 files changed

+242
-47
lines changed

9 files changed

+242
-47
lines changed

lib/ramble/ramble/application.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -963,11 +963,27 @@ def _check_object_validators(self):
963963
else:
964964
logger.warn(err_msg)
965965

966+
def _get_filtered_executables(self) -> dict:
967+
"""Returns a dict of executables that satisfy `when` conditions"""
968+
filtered_executables = {}
969+
all_executables = self.executables.copy()
970+
for when_set, executables in all_executables.items():
971+
if self.expander.satisfies(when_set, variant_set=self.object_variants):
972+
for executable in executables:
973+
if executable in filtered_executables:
974+
logger.die(
975+
f"Executable {executable} is defined for overlapping `when` "
976+
"conditions. Ensure conditions are mutually exclusive."
977+
)
978+
filtered_executables.update(executables)
979+
980+
return filtered_executables
981+
966982
def _define_custom_executables(self):
967983
# Define custom executables
968984
if namespace.custom_executables in self.internals:
969985
for name, conf in self.internals[namespace.custom_executables].items():
970-
if name in self.executables or name in self.custom_executables:
986+
if name in self._get_filtered_executables() or name in self.custom_executables:
971987
experiment_namespace = self.expander.expand_var_name("experiment_namespace")
972988
raise ExecutableNameError(
973989
f"In experiment {experiment_namespace} "
@@ -996,11 +1012,18 @@ def _get_executable_graph(self, workload_name):
9961012
builtin_objects.append(obj)
9971013
all_builtins.append(builtins)
9981014

999-
all_executables = self.executables.copy()
1000-
all_executables.update(self.custom_executables)
1015+
filtered_executables = self._get_filtered_executables()
1016+
filtered_executables.update(self.custom_executables)
1017+
1018+
filtered_exec_order = []
1019+
for executable in exec_order:
1020+
if executable in filtered_executables or any(executable in b for b in all_builtins):
1021+
filtered_exec_order.append(executable)
1022+
else:
1023+
logger.debug(f"Skipping executable {executable}. `When` conditions not satisfied.")
10011024

10021025
executable_graph = ramble.graphs.ExecutableGraph(
1003-
exec_order, all_executables, builtin_objects, all_builtins, self
1026+
filtered_exec_order, filtered_executables, builtin_objects, all_builtins, self
10041027
)
10051028

10061029
# Perform executable injection
@@ -1693,9 +1716,10 @@ def _archive_experiments(self, workspace, app_inst=None):
16931716
# Copy all log files from executables
16941717
exec_logs = set()
16951718
workload = self.workloads[self.expander.workload_name]
1719+
filtered_executables = self._get_filtered_executables()
16961720
for exec_name in workload.executables:
1697-
if exec_name in self.executables:
1698-
exec_obj = self.executables[exec_name]
1721+
if exec_name in filtered_executables:
1722+
exec_obj = filtered_executables[exec_name]
16991723
exec_log = self.expander.expand_var(exec_obj.redirect)
17001724
exec_logs.add(exec_log)
17011725

lib/ramble/ramble/language/application_language.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _execute_workload_groups(app):
117117

118118

119119
@application_directive("executables")
120-
def executable(name, template, **kwargs):
120+
def executable(name, template, when=None, **kwargs):
121121
"""Adds an executable to this application
122122
123123
Defines a new executable that can be used to configure workloads and
@@ -141,12 +141,18 @@ def executable(name, template, **kwargs):
141141
both) to capture. Defaults to stdout
142142
run_in_background (bool): Optional, Declare if the command should
143143
run in the background. Defaults to False
144+
when (list | None): List of when conditions to apply to directive
144145
"""
145146

146147
def _execute_executable(app):
147148
from ramble.util.executable import CommandExecutable
148149

149-
app.executables[name] = CommandExecutable(name=name, template=template, **kwargs)
150+
when_list = ramble.language.language_helpers.build_when_list(when, app, name, "executable")
151+
when_set = frozenset(when_list)
152+
if when_set not in app.executables:
153+
app.executables[when_set] = {}
154+
155+
app.executables[when_set][name] = CommandExecutable(name=name, template=template, **kwargs)
150156

151157
return _execute_executable
152158

lib/ramble/ramble/language/language_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _wrapper(*args, **_kwargs):
238238
"modifier_variable",
239239
"package_manager_variable",
240240
"workflow_manager_variable",
241+
"executable",
241242
]:
242243
msg = (
243244
'directive "{0}" cannot be used within a "when"'

lib/ramble/ramble/language/language_helpers.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# except according to those terms.
88

99
import fnmatch
10-
from typing import Any, List
10+
from collections import OrderedDict
11+
from typing import Any, List, Union
1112

1213
from ramble.language.language_base import DirectiveError
1314

@@ -73,16 +74,17 @@ def merge_definitions(
7374
single_type, multiple_type, single_arg_name, multiple_arg_name, directive_name
7475
)
7576

76-
all_types = []
77+
merged_types = []
7778

7879
if single_type:
79-
all_types.append(single_type)
80+
merged_types.append(single_type)
8081

8182
if multiple_type:
82-
expanded_multiple_type = expand_patterns(multiple_type, multiple_pattern_match)
83-
all_types.extend(expanded_multiple_type)
83+
merged_types.extend(multiple_type)
8484

85-
return all_types
85+
merged_types_expanded = expand_patterns(merged_types, multiple_pattern_match)
86+
87+
return merged_types_expanded
8688

8789

8890
def require_definition(
@@ -128,32 +130,52 @@ def require_definition(
128130
)
129131

130132

131-
def expand_patterns(multiple_type: list, multiple_pattern_match: list):
133+
def expand_patterns(merged_types: list, multiple_pattern_match: Union[list, dict]):
132134
"""Expand wildcard patterns within a list of names
133135
134136
This method takes an input list containing wildcard patterns and expands the
135137
wildcard with values matching a list of names. Returns a list containing
136138
matching names and any inputs with zero matches.
137139
140+
If multiple_pattern_match is a dict keyed on 'when', it checks the input
141+
against patterns in all 'when' conditions, without evaluating them, and
142+
returns a list containing names that match under any when condition, and
143+
any inputs with zero matches.
144+
138145
Args:
139-
multiple_types: List of strings for type names, may contain wildcards
140-
multiple_pattern_match: List of strings to match against patterns in multiple_type
146+
merged_types: List of strings for type names, may contain wildcards
147+
multiple_pattern_match: List of strings (optional: nested in when_set
148+
dict) to match against patterns in merged_types
141149
142150
Returns:
143151
List of expanded patterns matching the names list plus patterns
144-
not found in the names list.
152+
not found in the names list.
145153
"""
146-
147-
expanded_patterns = []
148-
for input in multiple_type:
149-
matched_inputs = fnmatch.filter(multiple_pattern_match, input)
150-
if matched_inputs:
151-
for matching_name in matched_inputs:
152-
expanded_patterns.append(matching_name)
154+
expanded_patterns = OrderedDict()
155+
for input in merged_types:
156+
expanded = False
157+
if (
158+
multiple_pattern_match
159+
and isinstance(multiple_pattern_match, dict)
160+
and isinstance(next(iter(multiple_pattern_match)), frozenset)
161+
):
162+
for _, pattern_list in multiple_pattern_match.items():
163+
matched_inputs = fnmatch.filter(pattern_list, input)
164+
if matched_inputs:
165+
expanded = True
166+
for match in matched_inputs:
167+
expanded_patterns[match] = ""
153168
else:
154-
expanded_patterns.append(input)
169+
matched_inputs = fnmatch.filter(multiple_pattern_match, input)
170+
if matched_inputs:
171+
expanded = True
172+
for match in matched_inputs:
173+
expanded_patterns[match] = ""
174+
175+
if not expanded:
176+
expanded_patterns[input] = ""
155177

156-
return expanded_patterns
178+
return list(expanded_patterns.keys())
157179

158180

159181
def build_when_list(

lib/ramble/ramble/test/application.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"mutable_config", "mutable_mock_workspace_path", "mutable_mock_apps_repo"
1717
)
1818

19+
_FS = frozenset()
20+
1921

2022
def basic_exp_dict():
2123
"""To set expander consistently with test_wl2 of builtin.mock/applications/basic"""
@@ -68,15 +70,15 @@ def test_basic_app(mutable_mock_apps_repo):
6870
assert len(basic_inst.workloads["test_wl"].executables) == 1
6971
foo_exec = basic_inst.workloads["test_wl"].find_executable("foo")
7072
assert foo_exec is not None
71-
foo_exec = basic_inst.executables[foo_exec]
73+
foo_exec = basic_inst.executables[_FS][foo_exec]
7274
assert foo_exec.template == ["bar"]
7375
assert not foo_exec.mpi
7476

7577
assert len(basic_inst.workloads["test_wl"].inputs) == 1
7678
example_input = basic_inst.workloads["test_wl"].find_input("input")
7779
assert example_input is not None
7880

79-
assert len(basic_inst.workloads["test_wl"].variables[frozenset()]) == 2
81+
assert len(basic_inst.workloads["test_wl"].variables[_FS]) == 2
8082
possible_vars = basic_inst.workloads["test_wl"].find_variable("my_var")
8183
assert len(possible_vars) == 1
8284
assert possible_vars[0].default == "1.0"
@@ -86,7 +88,7 @@ def test_basic_app(mutable_mock_apps_repo):
8688
assert len(basic_inst.workloads["test_wl2"].executables) == 1
8789
bar_exec = basic_inst.workloads["test_wl2"].find_executable("bar")
8890
assert bar_exec is not None
89-
bar_exec = basic_inst.executables[bar_exec]
91+
bar_exec = basic_inst.executables[_FS][bar_exec]
9092
assert bar_exec.template == ["baz"]
9193
assert bar_exec.mpi
9294

@@ -104,8 +106,8 @@ def test_basic_app(mutable_mock_apps_repo):
104106
assert exec_graph.get_node("bar") is not None
105107
assert exec_graph.get_node("builtin::env_vars") is not None
106108

107-
assert "test_fom" in basic_inst.figures_of_merit[frozenset()][frozenset()]
108-
fom_conf = basic_inst.figures_of_merit[frozenset()][frozenset()]["test_fom"]
109+
assert "test_fom" in basic_inst.figures_of_merit[_FS][_FS]
110+
fom_conf = basic_inst.figures_of_merit[_FS][_FS]["test_fom"]
109111
assert fom_conf["log_file"] == "{log_file}"
110112
assert fom_conf["regex"] == r"(?P<test>[0-9]+\.[0-9]+).*seconds.*" # noqa: W605
111113
assert fom_conf["group_name"] == "test"
@@ -192,7 +194,7 @@ def test_required_builtins(mutable_mock_apps_repo, app):
192194
app_inst.define_variable("application_name", app)
193195

194196
required_builtins = []
195-
for builtin, conf in app_inst.builtins[frozenset()].items():
197+
for builtin, conf in app_inst.builtins[_FS].items():
196198
if conf[app_inst._builtin_required_key]:
197199
required_builtins.append(builtin)
198200

@@ -211,7 +213,7 @@ def test_register_builtin_app(mutable_mock_apps_repo):
211213

212214
required_builtins = []
213215
excluded_builtins = []
214-
for builtin, conf in app_inst.builtins[frozenset()].items():
216+
for builtin, conf in app_inst.builtins[_FS].items():
215217
if conf[app_inst._builtin_required_key]:
216218
required_builtins.append(builtin)
217219
else:

lib/ramble/ramble/test/application_inheritance.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,21 @@
88

99
from ramble.test.application import basic_exp_dict
1010

11+
_FS = frozenset()
12+
1113

1214
def test_basic_inheritance(mutable_mock_apps_repo):
1315
app_inst = mutable_mock_apps_repo.get("basic-inherited")
1416
exp_dict = basic_exp_dict()
1517
app_inst.set_variables(exp_dict, None)
1618
app_inst.define_variable("application_name", "basic-inherited")
1719

18-
assert "foo" in app_inst.executables
19-
assert app_inst.executables["foo"].template == ["bar"]
20-
assert not app_inst.executables["foo"].mpi
21-
assert "bar" in app_inst.executables
22-
assert app_inst.executables["bar"].template == ["baz"]
23-
assert app_inst.executables["bar"].mpi
20+
assert "foo" in app_inst.executables[_FS]
21+
assert app_inst.executables[_FS]["foo"].template == ["bar"]
22+
assert not app_inst.executables[_FS]["foo"].mpi
23+
assert "bar" in app_inst.executables[_FS]
24+
assert app_inst.executables[_FS]["bar"].template == ["baz"]
25+
assert app_inst.executables[_FS]["bar"].mpi
2426

2527
assert "test_wl" in app_inst.workloads
2628
assert app_inst.workloads["test_wl"].executables == ["foo"]
@@ -49,8 +51,8 @@ def test_basic_inheritance(mutable_mock_apps_repo):
4951
assert exec_graph.get_node("foo") is not None
5052
assert exec_graph.get_node("builtin::env_vars") is not None
5153

52-
assert "test_fom" in app_inst.figures_of_merit[frozenset()][frozenset()]
53-
fom_conf = app_inst.figures_of_merit[frozenset()][frozenset()]["test_fom"]
54+
assert "test_fom" in app_inst.figures_of_merit[_FS][_FS]
55+
fom_conf = app_inst.figures_of_merit[_FS][_FS]["test_fom"]
5456
assert fom_conf["log_file"] == "{log_file}"
5557
assert fom_conf["regex"] == r"(?P<test>[0-9]+\.[0-9]+).*seconds.*" # noqa: W605
5658
assert fom_conf["group_name"] == "test"

lib/ramble/ramble/test/application_language.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
ExecutableApplication, # noqa: F405
1818
]
1919

20+
_FS = frozenset()
21+
2022

2123
@deprecation.fail_if_not_removed
2224
@pytest.mark.parametrize("app_class", app_types)
@@ -257,18 +259,16 @@ def test_executable_directive(app_class):
257259

258260
assert hasattr(app_inst, "executables")
259261
for exe_name, conf in test_defs.items():
260-
assert exe_name in app_inst.executables
262+
assert exe_name in app_inst.executables[_FS]
261263
for conf_name, conf_val in conf.items():
262-
assert hasattr(app_inst.executables[exe_name], conf_name)
263-
assert conf_val == getattr(app_inst.executables[exe_name], conf_name)
264+
assert hasattr(app_inst.executables[_FS][exe_name], conf_name)
265+
assert conf_val == getattr(app_inst.executables[_FS][exe_name], conf_name)
264266

265267

266268
@pytest.mark.parametrize("app_class", app_types)
267269
def test_figure_of_merit_directive(app_class):
268270
test_defs = {}
269271

270-
_FS = frozenset()
271-
272272
app_inst = app_class("/not/a/path")
273273
test_defs.update(add_figure_of_merit(app_inst))
274274

0 commit comments

Comments
 (0)