Skip to content

Commit c9140c6

Browse files
committed
Add 'when' to input_file directive
1 parent 538f364 commit c9140c6

File tree

9 files changed

+142
-23
lines changed

9 files changed

+142
-23
lines changed

lib/ramble/ramble/application.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,19 +1349,41 @@ def _inputs_and_fetchers(self, workload=None):
13491349

13501350
workload_names = [workload] if workload else self.workloads.keys()
13511351

1352+
# Batch 'when' evaluation to avoid repeat expander calls
1353+
when_satisfied = set()
1354+
for when_set in self.inputs.keys():
1355+
if self.expander.satisfies(when_set, variant_set=self.object_variants):
1356+
when_satisfied.add(when_set)
1357+
13521358
inputs = {}
13531359
for workload_name in workload_names:
13541360
workload = self.workloads[workload_name]
13551361

13561362
for input_file in workload.inputs:
1357-
if input_file not in self.inputs:
1363+
inputs_found = 0
1364+
active_inputs = 0
1365+
input_conf = {}
1366+
for when_set, app_inputs in self.inputs.items():
1367+
if input_file in app_inputs:
1368+
inputs_found += 1
1369+
if when_set in when_satisfied:
1370+
active_inputs += 1
1371+
input_conf = app_inputs[input_file].copy()
1372+
1373+
if not inputs_found:
13581374
logger.die(
13591375
f"Workload {workload_name} references a non-existent input file "
13601376
f"{input_file}.\n"
13611377
f"Make sure this input file is defined before using it in a workload."
13621378
)
1363-
1364-
input_conf = self.inputs[input_file].copy()
1379+
if active_inputs == 0:
1380+
logger.debug(f"Skipping input {input_file}. `When` conditions not satisfied.")
1381+
continue
1382+
elif active_inputs > 1:
1383+
logger.die(
1384+
f"Input files {input_file} are defined with overlapping 'when' "
1385+
f"conditions. Make sure that conditions are mutually exclusive."
1386+
)
13651387

13661388
# Expand input value as it may be a var
13671389
expanded_url = self.expander.expand_var(input_conf["url"])

lib/ramble/ramble/language/application_language.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def input_file(
166166
sha256=None,
167167
extension=None,
168168
expand=True,
169+
when=None,
169170
**kwargs,
170171
):
171172
"""Adds an input file definition to this application
@@ -184,16 +185,23 @@ def input_file(
184185
extension (str): Optiona, the extension to use for the input, if it isn't part of the
185186
file name.
186187
expand (bool): Optional. Whether the input should be expanded or not. Defaults to True
188+
when (list | None): List of when conditions to apply to directive
187189
"""
188190

189191
def _execute_input_file(app):
190-
app.inputs[name] = {
192+
when_list = ramble.language.language_helpers.build_when_list(when, app, name, "input_file")
193+
when_set = frozenset(when_list)
194+
if when_set not in app.inputs:
195+
app.inputs[when_set] = {}
196+
197+
app.inputs[when_set][name] = {
191198
"url": url,
192199
"description": description,
193200
"target_dir": target_dir,
194201
"sha256": sha256,
195202
"extension": extension,
196203
"expand": expand,
204+
"when": when_list,
197205
}
198206

199207
return _execute_input_file

lib/ramble/ramble/language/language_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def _wrapper(*args, **_kwargs):
239239
"package_manager_variable",
240240
"workflow_manager_variable",
241241
"executable",
242+
"input_file",
242243
]:
243244
msg = (
244245
'directive "{0}" cannot be used within a "when"'

lib/ramble/ramble/test/application.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def test_basic_app(mutable_mock_apps_repo):
113113
assert fom_conf["group_name"] == "test"
114114
assert fom_conf["units"] == "s"
115115

116-
assert "input" in basic_inst.inputs
117-
assert basic_inst.inputs["input"]["url"] == "file:///tmp/test_file.log"
118-
assert basic_inst.inputs["input"]["description"] == "Not a file"
116+
assert "input" in basic_inst.inputs[_FS]
117+
assert basic_inst.inputs[_FS]["input"]["url"] == "file:///tmp/test_file.log"
118+
assert basic_inst.inputs[_FS]["input"]["description"] == "Not a file"
119119

120120

121121
@pytest.mark.parametrize("app_name", ["basic", "zlib"])
@@ -399,7 +399,7 @@ def test_set_default_experiment_variables(mutable_mock_apps_repo):
399399

400400
executable_application_instance.internals = {}
401401

402-
executable_application_instance.inputs = {"input": {"target_dir": "."}}
402+
executable_application_instance.inputs[_FS] = {"input": {"target_dir": "."}}
403403
executable_application_instance.variables = {}
404404

405405
executable_application_instance._set_default_experiment_variables()
@@ -424,7 +424,7 @@ def test_define_commands(mutable_mock_apps_repo):
424424

425425
executable_application_instance.internals = {}
426426

427-
executable_application_instance.inputs = {"input": {"target_dir": "."}}
427+
executable_application_instance.inputs[_FS] = {"input": {"target_dir": "."}}
428428
executable_application_instance.variables = {}
429429

430430
exec_graph = executable_application_instance._get_executable_graph("test_wl2")
@@ -495,7 +495,7 @@ def test_derive_variables_for_template_path(mutable_mock_apps_repo):
495495

496496
executable_application_instance.internals = {}
497497

498-
executable_application_instance.inputs = {"input": {"target_dir": "."}}
498+
executable_application_instance.inputs[_FS] = {"input": {"target_dir": "."}}
499499
executable_application_instance.variables = {}
500500

501501
exec_graph = executable_application_instance._get_executable_graph("test_wl2")

lib/ramble/ramble/test/application_inheritance.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def test_basic_inheritance(mutable_mock_apps_repo):
5858
assert fom_conf["group_name"] == "test"
5959
assert fom_conf["units"] == "s"
6060

61-
assert "input" in app_inst.inputs
62-
assert app_inst.inputs["input"]["url"] == "file:///tmp/test_file.log"
63-
assert app_inst.inputs["input"]["description"] == "Not a file"
64-
assert "inherited_input" in app_inst.inputs
65-
assert app_inst.inputs["inherited_input"]["url"] == "file:///tmp/inherited_file.log"
66-
assert app_inst.inputs["inherited_input"]["description"] == "Again, not a file"
61+
assert "input" in app_inst.inputs[_FS]
62+
assert app_inst.inputs[_FS]["input"]["url"] == "file:///tmp/test_file.log"
63+
assert app_inst.inputs[_FS]["input"]["description"] == "Not a file"
64+
assert "inherited_input" in app_inst.inputs[_FS]
65+
assert app_inst.inputs[_FS]["inherited_input"]["url"] == "file:///tmp/inherited_file.log"
66+
assert app_inst.inputs[_FS]["inherited_input"]["description"] == "Again, not a file"
6767

6868
possible_vars = app_inst.workloads["test_wl"].find_variable("my_base_var")
6969
assert len(possible_vars) == 1

lib/ramble/ramble/test/application_language.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,14 +289,14 @@ def test_input_file_directive(app_class):
289289

290290
assert hasattr(app_inst, "inputs")
291291
for input_name, conf in test_defs.items():
292-
assert input_name in app_inst.inputs
292+
assert input_name in app_inst.inputs[_FS]
293293

294294
for conf_name, conf_val in conf.items():
295-
assert conf_name in app_inst.inputs[input_name]
296-
assert app_inst.inputs[input_name][conf_name] == conf_val
295+
assert conf_name in app_inst.inputs[_FS][input_name]
296+
assert app_inst.inputs[_FS][input_name][conf_name] == conf_val
297297

298-
assert "extension" in app_inst.inputs[input_name]
299-
assert "expand" in app_inst.inputs[input_name]
298+
assert "extension" in app_inst.inputs[_FS][input_name]
299+
assert "expand" in app_inst.inputs[_FS][input_name]
300300

301301

302302
@pytest.mark.parametrize("app_class", app_types)

lib/ramble/ramble/test/mirror.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
),
3535
]
3636

37+
_FS = frozenset()
38+
3739

3840
class MockFetcher:
3941
"""Mock fetcher object which implements the necessary functionality for
@@ -85,15 +87,15 @@ def create_archive(archive_dir, app_class):
8587
with open(archive_name, "rb") as f:
8688
bytes = f.read()
8789
conf["fetcher"].digest = hashlib.sha256(bytes).hexdigest()
88-
app_class.inputs[conf["input_name"]]["sha256"] = conf["fetcher"].digest
90+
app_class.inputs[_FS][conf["input_name"]]["sha256"] = conf["fetcher"].digest
8991
else:
9092
with open(input_name, "w+") as f:
9193
f.write("Input file\n")
9294

9395
with open(input_name, "rb") as f:
9496
bytes = f.read()
9597
conf["fetcher"].digest = hashlib.sha256(bytes).hexdigest()
96-
app_class.inputs[conf["input_name"]]["sha256"] = conf["fetcher"].digest
98+
app_class.inputs[_FS][conf["input_name"]]["sha256"] = conf["fetcher"].digest
9799

98100

99101
def check_mirror(mirror_path, app_name, app_class):

lib/ramble/ramble/test/when.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,3 +1008,49 @@ def test_executable_errors_when_overlapping_conditions(request):
10081008

10091009
captured = workspace("setup", global_args=global_args)
10101010
assert "test_exec_def_when is defined for overlapping `when` conditions" in captured
1011+
1012+
1013+
@pytest.mark.parametrize(
1014+
"input_when,expected_input_file",
1015+
[
1016+
(False, "input1_false"),
1017+
(True, "input1_true"),
1018+
],
1019+
)
1020+
def test_input_when(request, input_when, expected_input_file):
1021+
ws_name = request.node.name.replace("[", "_").replace("]", "_")
1022+
1023+
global_args = ["-w", ws_name]
1024+
1025+
with ramble.workspace.create(ws_name) as ws:
1026+
workspace(
1027+
"manage",
1028+
"experiments",
1029+
"when-directives",
1030+
"--wf",
1031+
"test_inputs_wl",
1032+
"-v",
1033+
"n_ranks=1",
1034+
"-v",
1035+
"n_nodes=1",
1036+
"-v",
1037+
"processes_per_node=1",
1038+
global_args=global_args,
1039+
)
1040+
1041+
config("add", f"variants:input_when:{input_when}", global_args=global_args)
1042+
1043+
ws._re_read()
1044+
workspace("setup", "--dry-run", global_args=global_args)
1045+
1046+
log_file = os.path.join(
1047+
ws.path,
1048+
"logs",
1049+
"setup.latest.out",
1050+
)
1051+
1052+
with open(log_file) as f:
1053+
data = f.read()
1054+
1055+
assert expected_input_file in data
1056+
assert ("input2" in data) == input_when

var/ramble/repos/builtin.mock/applications/when-directives/application.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# option. This file may not be copied, modified, or distributed
77
# except according to those terms.
88

9+
import os
10+
911
from ramble.appkit import *
1012

1113

@@ -213,3 +215,41 @@ def test_builtin_when(self):
213215
"exec_when_wl",
214216
executables=["test_exec", "test_exec_skipped", "test_exec_def_when"],
215217
)
218+
219+
# for input_file()
220+
cwd = os.getcwd()
221+
variant(
222+
"input_when",
223+
default=False,
224+
values=[True, False],
225+
description="Add input file using when",
226+
)
227+
228+
input_file(
229+
"test-input1",
230+
url=f"file://{cwd}/input1_false",
231+
description="Test input when false",
232+
when=["~input_when"],
233+
)
234+
235+
input_file(
236+
"test-input1",
237+
url=f"file://{cwd}/input1_true",
238+
expand=False,
239+
description="Test input when true",
240+
when=["+input_when"],
241+
)
242+
243+
input_file(
244+
"test-input2",
245+
url=f"file://{cwd}/input2",
246+
expand=False,
247+
description="Test input skipped when false",
248+
when=["+input_when"],
249+
)
250+
251+
workload(
252+
"test_inputs_wl",
253+
executables=["test_exec"],
254+
inputs=["test-input1", "test-input2"],
255+
)

0 commit comments

Comments
 (0)