Skip to content

Commit fbb39d6

Browse files
authored
Merge pull request #480 from jw-96/splitters
Support for splitters containing a single item list
2 parents 7fb3d17 + eb17a4d commit fbb39d6

File tree

2 files changed

+92
-70
lines changed

2 files changed

+92
-70
lines changed

pydra/engine/helpers_state.py

+80-70
Original file line numberDiff line numberDiff line change
@@ -56,81 +56,91 @@ def _ordering(
5656
if type(el) is tuple:
5757
# checking if the splitter dont contain splitter from previous nodes
5858
# i.e. has str "_NA", etc.
59-
if type(el[0]) is str and el[0].startswith("_"):
60-
node_nm = el[0][1:]
61-
if node_nm not in other_states and state_fields:
62-
raise PydraStateError(
63-
"can't ask for splitter from {}, other nodes that are connected: {}".format(
64-
node_nm, other_states.keys()
59+
if len(el) == 1:
60+
# treats .split(("x",)) like .split("x")
61+
el = el[0]
62+
_ordering(el, i, output_splitter, current_sign, other_states, state_fields)
63+
else:
64+
if type(el[0]) is str and el[0].startswith("_"):
65+
node_nm = el[0][1:]
66+
if node_nm not in other_states and state_fields:
67+
raise PydraStateError(
68+
"can't ask for splitter from {}, other nodes that are connected: {}".format(
69+
node_nm, other_states.keys()
70+
)
6571
)
66-
)
67-
elif state_fields:
68-
splitter_mod = add_name_splitter(
69-
splitter=other_states[node_nm][0].splitter_final, name=node_nm
70-
)
71-
el = (splitter_mod, el[1])
72-
if other_states[node_nm][0].other_states:
73-
other_states.update(other_states[node_nm][0].other_states)
74-
if type(el[1]) is str and el[1].startswith("_"):
75-
node_nm = el[1][1:]
76-
if node_nm not in other_states and state_fields:
77-
raise PydraStateError(
78-
"can't ask for splitter from {}, other nodes that are connected: {}".format(
79-
node_nm, other_states.keys()
72+
elif state_fields:
73+
splitter_mod = add_name_splitter(
74+
splitter=other_states[node_nm][0].splitter_final, name=node_nm
8075
)
81-
)
82-
elif state_fields:
83-
splitter_mod = add_name_splitter(
84-
splitter=other_states[node_nm][0].splitter_final, name=node_nm
85-
)
86-
el = (el[0], splitter_mod)
87-
if other_states[node_nm][0].other_states:
88-
other_states.update(other_states[node_nm][0].other_states)
89-
_iterate_list(
90-
el,
91-
".",
92-
other_states,
93-
output_splitter=output_splitter,
94-
state_fields=state_fields,
95-
)
76+
el = (splitter_mod, el[1])
77+
if other_states[node_nm][0].other_states:
78+
other_states.update(other_states[node_nm][0].other_states)
79+
if type(el[1]) is str and el[1].startswith("_"):
80+
node_nm = el[1][1:]
81+
if node_nm not in other_states and state_fields:
82+
raise PydraStateError(
83+
"can't ask for splitter from {}, other nodes that are connected: {}".format(
84+
node_nm, other_states.keys()
85+
)
86+
)
87+
elif state_fields:
88+
splitter_mod = add_name_splitter(
89+
splitter=other_states[node_nm][0].splitter_final, name=node_nm
90+
)
91+
el = (el[0], splitter_mod)
92+
if other_states[node_nm][0].other_states:
93+
other_states.update(other_states[node_nm][0].other_states)
94+
_iterate_list(
95+
el,
96+
".",
97+
other_states,
98+
output_splitter=output_splitter,
99+
state_fields=state_fields,
100+
)
96101
elif type(el) is list:
97-
if type(el[0]) is str and el[0].startswith("_"):
98-
node_nm = el[0][1:]
99-
if node_nm not in other_states and state_fields:
100-
raise PydraStateError(
101-
"can't ask for splitter from {}, other nodes that are connected: {}".format(
102-
node_nm, other_states.keys()
102+
if len(el) == 1:
103+
# treats .split(["x"]) like .split("x")
104+
el = el[0]
105+
_ordering(el, i, output_splitter, current_sign, other_states, state_fields)
106+
else:
107+
if type(el[0]) is str and el[0].startswith("_"):
108+
node_nm = el[0][1:]
109+
if node_nm not in other_states and state_fields:
110+
raise PydraStateError(
111+
"can't ask for splitter from {}, other nodes that are connected: {}".format(
112+
node_nm, other_states.keys()
113+
)
103114
)
104-
)
105-
elif state_fields:
106-
splitter_mod = add_name_splitter(
107-
splitter=other_states[node_nm][0].splitter_final, name=node_nm
108-
)
109-
el[0] = splitter_mod
110-
if other_states[node_nm][0].other_states:
111-
other_states.update(other_states[node_nm][0].other_states)
112-
if type(el[1]) is str and el[1].startswith("_"):
113-
node_nm = el[1][1:]
114-
if node_nm not in other_states and state_fields:
115-
raise PydraStateError(
116-
"can't ask for splitter from {}, other nodes that are connected: {}".format(
117-
node_nm, other_states.keys()
115+
elif state_fields:
116+
splitter_mod = add_name_splitter(
117+
splitter=other_states[node_nm][0].splitter_final, name=node_nm
118118
)
119-
)
120-
elif state_fields:
121-
splitter_mod = add_name_splitter(
122-
splitter=other_states[node_nm][0].splitter_final, name=node_nm
123-
)
124-
el[1] = splitter_mod
125-
if other_states[node_nm][0].other_states:
126-
other_states.update(other_states[node_nm][0].other_states)
127-
_iterate_list(
128-
el,
129-
"*",
130-
other_states,
131-
output_splitter=output_splitter,
132-
state_fields=state_fields,
133-
)
119+
el[0] = splitter_mod
120+
if other_states[node_nm][0].other_states:
121+
other_states.update(other_states[node_nm][0].other_states)
122+
if type(el[1]) is str and el[1].startswith("_"):
123+
node_nm = el[1][1:]
124+
if node_nm not in other_states and state_fields:
125+
raise PydraStateError(
126+
"can't ask for splitter from {}, other nodes that are connected: {}".format(
127+
node_nm, other_states.keys()
128+
)
129+
)
130+
elif state_fields:
131+
splitter_mod = add_name_splitter(
132+
splitter=other_states[node_nm][0].splitter_final, name=node_nm
133+
)
134+
el[1] = splitter_mod
135+
if other_states[node_nm][0].other_states:
136+
other_states.update(other_states[node_nm][0].other_states)
137+
_iterate_list(
138+
el,
139+
"*",
140+
other_states,
141+
output_splitter=output_splitter,
142+
state_fields=state_fields,
143+
)
134144
elif type(el) is str:
135145
if el.startswith("_"):
136146
node_nm = el[1:]

pydra/engine/tests/test_helpers_state.py

+12
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ def __init__(
3131
"splitter, keys_exp, groups_exp, grstack_exp",
3232
[
3333
("a", ["a"], {"a": 0}, [[0]]),
34+
(["a"], ["a"], {"a": 0}, [[0]]),
35+
(("a",), ["a"], {"a": 0}, [[0]]),
3436
(("a", "b"), ["a", "b"], {"a": 0, "b": 0}, [[0]]),
3537
(["a", "b"], ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
38+
([["a", "b"]], ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
39+
((["a", "b"],), ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
3640
((["a", "b"], "c"), ["a", "b", "c"], {"a": 0, "b": 1, "c": [0, 1]}, [[0, 1]]),
3741
([("a", "b"), "c"], ["a", "b", "c"], {"a": 0, "b": 0, "c": 1}, [[0, 1]]),
3842
([["a", "b"], "c"], ["a", "b", "c"], {"a": 0, "b": 1, "c": 2}, [[0, 1, 2]]),
@@ -58,6 +62,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
5862
"keys_final_exp, groups_final_exp, grstack_final_exp",
5963
[
6064
("a", ["a"], ["a"], [], {}, []),
65+
(["a"], ["a"], ["a"], [], {}, []),
66+
(("a",), ["a"], ["a"], [], {}, []),
6167
(("a", "b"), ["a"], ["a", "b"], [], {}, [[]]),
6268
(("a", "b"), ["b"], ["a", "b"], [], {}, [[]]),
6369
(["a", "b"], ["b"], ["b"], ["a"], {"a": 0}, [[0]]),
@@ -69,6 +75,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
6975
([("a", "b"), "c"], ["a"], ["a", "b"], ["c"], {"c": 0}, [[0]]),
7076
([("a", "b"), "c"], ["b"], ["a", "b"], ["c"], {"c": 0}, [[0]]),
7177
([("a", "b"), "c"], ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
78+
([[("a", "b"), "c"]], ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
79+
(([("a", "b"), "c"],), ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
7280
],
7381
)
7482
def test_splits_groups_comb(
@@ -94,6 +102,8 @@ def test_splits_groups_comb(
94102
"splitter, cont_dim, values, keys, splits",
95103
[
96104
("a", None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
105+
(["a"], None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
106+
(("a",), None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
97107
(
98108
("a", "v"),
99109
None,
@@ -468,6 +478,8 @@ def test_splits_2(splitter_rpn, inner_inputs, values, keys, splits):
468478
(["a", ("b", ["c", "d"])], ["a", "b", "c", "d", "*", ".", "*"]),
469479
((["a", "b"], "c"), ["a", "b", "*", "c", "."]),
470480
((["a", "b"], ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
481+
(([["a", "b"]], ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
482+
(((["a", "b"],), ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
471483
([("a", "b"), ("c", "d")], ["a", "b", ".", "c", "d", ".", "*"]),
472484
],
473485
)

0 commit comments

Comments
 (0)