Skip to content

changing name from inner splitter to scoped splitter #823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pydra/engine/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,10 @@
and val._node.state.depth()
):
node: Node = val._node
# variables that are part of inner splitters should be treated as a containers
# variables that are part of scoped splitters (previously known as inner splitter)
# should be treated as a containers
if node.state and f"{node.name}.{val._field}" in node.state.splitter:
node.state._inner_container_ndim[f"{node.name}.{val._field}"] = 1
node.state._scoped_container_ndim[f"{node.name}.{val._field}"] = 1

Check warning on line 221 in pydra/engine/node.py

View check run for this annotation

Codecov / codecov/patch

pydra/engine/node.py#L221

Added line #L221 was not covered by tests
# adding task_name: (task.state, [a field from the connection]
if node.name not in upstream_states:
upstream_states[node.name] = (node.state, [val._field])
Expand Down
84 changes: 44 additions & 40 deletions pydra/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class State:
(previous state, input from current state needed the connection)
}

inner_inputs : :obj:`dict`
inputs_previous_states : :obj:`dict`
used to create connections with previous states
``{"{self.name}.input name for current inp": previous state}``
states_ind : :obj:`list` of :obj:`dict`
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
# temporary combiner
self.combiner = combiner
self.container_ndim = container_ndim or {}
self._inner_container_ndim = {}
self._scoped_container_ndim = {}
self._inputs_ind = None
# if other_states, the connections have to be updated
if self.other_states:
Expand Down Expand Up @@ -359,9 +359,9 @@ def prev_state_splitter_rpn_compact(self):

@property
def container_ndim_all(self):
# adding inner_container_ndim to the general container_dimension provided by the users
# adding scoped_container_ndim to the general container_dimension provided by the users
container_ndim_all = deepcopy(self.container_ndim)
for k, v in self._inner_container_ndim.items():
for k, v in self._scoped_container_ndim.items():
container_ndim_all[k] = container_ndim_all.get(k, 1) + v
return container_ndim_all

Expand Down Expand Up @@ -445,18 +445,18 @@ def other_states(self, other_states):
self._other_states = {}

@property
def inner_inputs(self):
def inputs_previous_states(self):
"""specifies connections between fields from the current state
with the specific state from the previous states, uses dictionary
``{input name for current state: the previous state}``
"""
if self.other_states:
_inner_inputs = {}
_inputs_previous_states = {}
for name, (st, inp_l) in self.other_states.items():
if f"_{st.name}" in self.splitter_rpn_compact:
for inp in inp_l:
_inner_inputs[f"{self.name}.{inp}"] = st
return _inner_inputs
_inputs_previous_states[f"{self.name}.{inp}"] = st
return _inputs_previous_states
else:
return {}

Expand Down Expand Up @@ -706,7 +706,7 @@ def set_input_groups(self, state_fields=True):
keys_f, group_for_inputs_f, groups_stack_f, combiner_all = splits_groups(
current_splitter_rpn,
combiner=self.current_combiner,
inner_inputs=self.inner_inputs,
inputs_previous_states=self.inputs_previous_states,
)
self._current_combiner_all = combiner_all
if (
Expand Down Expand Up @@ -768,7 +768,7 @@ def _merge_previous_groups(self):
) = splits_groups(
st.splitter_rpn_final,
combiner=st_combiner,
inner_inputs=st.inner_inputs,
inputs_previous_states=st.inputs_previous_states,
)
self.keys_final += keys_f_st # st.keys_final
if not hasattr(st, "group_for_inputs_final"):
Expand Down Expand Up @@ -884,7 +884,7 @@ def prepare_states_ind(self):
Uses splits.

"""
# removing elements that are connected to inner splitter
# removing elements that are connected to scoped splitter (previously inner splitter)
# (they will be taken into account in splits anyway)
# _comb part will be used in prepare_states_combined_ind
# TODO: need tests in test_Workflow.py
Expand Down Expand Up @@ -986,8 +986,8 @@ def prepare_inputs(self):
"""
Preparing inputs indices, merges input from previous states.

Includes indices for fields from inner splitters
(removes elements connected to the inner splitters fields).
Includes indices for fields from scoped splitters (previously inner splitters)
(removes elements connected to the scoped splitters fields).

"""
if not self.other_states:
Expand All @@ -1004,21 +1004,21 @@ def prepare_inputs(self):
inputs_ind = []

# merging elements that come from previous nodes outputs
# states that are connected to inner splitters are treated differently
# states that are connected to scoped splitters are treated differently
# (already included in inputs_ind)
keys_inp_prev = []
inputs_ind_prev = []
connected_to_inner = []
connected_to_scoped = []
for ii, el in enumerate(self.prev_state_splitter_rpn_compact):
if el in ["*", "."]:
continue
st, inp_l = self.other_states[el[1:]]
inp_l = [f"{self.name}.{inp}" for inp in inp_l]
if set(inp_l).intersection(self.splitter_rpn): # inner splitter
connected_to_inner += [
if set(inp_l).intersection(self.splitter_rpn): # scoped splitter
connected_to_scoped += [
el for el in st.splitter_rpn_final if el not in [".", "*"]
]
else: # previous states that are not connected to inner splitter
else: # previous states that are not connected to scoped splitter
st_ind = range(len(st.states_ind_final))
if inputs_ind_prev:
# in case the prev-state part has scalar parts (not very well tested)
Expand All @@ -1043,9 +1043,9 @@ def prepare_inputs(self):

# iter_splits using inputs from current state/node
self._inputs_ind = list(iter_splits(inputs_ind, keys_inp))
# removing elements that are connected to inner splitter
# removing elements that are connected to scoped splitter
# TODO - add tests to test_workflow.py (not sure if we want to remove it)
for el in connected_to_inner:
for el in connected_to_scoped:
[dict.pop(el) for dict in self._inputs_ind]

def splits(self, splitter_rpn):
Expand All @@ -1064,10 +1064,10 @@ def splits(self, splitter_rpn):
names of input variables

"""
# analysing states from connected tasks if inner_inputs
# analysing states from connected tasks if inputs_previous_states
previous_states_ind = {
f"_{v.name}": (v.ind_l_final, v.keys_final)
for v in self.inner_inputs.values()
for v in self.inputs_previous_states.values()
}

# when splitter is a single element (no operators)
Expand Down Expand Up @@ -1152,17 +1152,17 @@ def _processing_terms(self, term, previous_states_ind):
shape = input_shape(self.inputs[term], container_ndim=container_ndim)
var_ind = range(prod(shape))
new_keys = [term]
# checking if the term is in inner_inputs
if term in self.inner_inputs:
# checking if the term is in inputs_previous_states
if term in self.inputs_previous_states:
# TODO: have to be changed if differ length
inner_len = [shape[-1]] * prod(shape[:-1])
# this come from the previous node
outer_ind = self.inner_inputs[term].ind_l
outer_ind = self.inputs_previous_states[term].ind_l
var_ind_out = itertools.chain.from_iterable(
itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len)
)
var_ind = op["."](var_ind_out, var_ind)
new_keys = self.inner_inputs[term].keys_final + new_keys
new_keys = self.inputs_previous_states[term].keys_final + new_keys

return shape, var_ind, new_keys

Expand All @@ -1173,18 +1173,18 @@ def _single_op_splits(self, op_single):
container_ndim=self.container_ndim_all.get(op_single, 1),
)
val_ind = range(prod(shape))
if op_single in self.inner_inputs:
if op_single in self.inputs_previous_states:
# TODO: have to be changed if differ length
inner_len = [shape[-1]] * prod(shape[:-1])

# this come from the previous node
outer_ind = self.inner_inputs[op_single].ind_l
outer_ind = self.inputs_previous_states[op_single].ind_l
op_out = itertools.chain.from_iterable(
itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len)
)
res = op["."](op_out, val_ind)
val = res
keys = self.inner_inputs[op_single].keys_final + [op_single]
keys = self.inputs_previous_states[op_single].keys_final + [op_single]
return val, keys
else:
val = op["*"](val_ind)
Expand Down Expand Up @@ -1585,7 +1585,7 @@ def iter_splits(iterable, keys):

def input_shape(inp, container_ndim=1):
"""Get input shape, depends on the container dimension, if not specify it is assumed to be 1"""
# TODO: have to be changed for inner splitter (sometimes different length)
# TODO: have to be changed for scoped (prev. inner) splitter (sometimes different length)
container_ndim -= 1
shape = [len(inp)]
last_shape = None
Expand All @@ -1605,7 +1605,7 @@ def input_shape(inp, container_ndim=1):
return tuple(shape)


def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
def splits_groups(splitter_rpn, combiner=None, inputs_previous_states=None):
"""splits inputs to groups (axes) and creates stacks for these groups
This is used to specify which input can be combined.
"""
Expand All @@ -1617,19 +1617,23 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
group_count = None
if not combiner:
combiner = []
if inner_inputs:
if inputs_previous_states:
previous_states_ind = {
f"_{v.name}": v.keys_final for v in inner_inputs.values()
f"_{v.name}": v.keys_final for v in inputs_previous_states.values()
}
inputs_previous_states = {
k: v for k, v in inputs_previous_states.items() if k in splitter_rpn
}
inner_inputs = {k: v for k, v in inner_inputs.items() if k in splitter_rpn}
else:
previous_states_ind = {}
inner_inputs = {}
inputs_previous_states = {}

# when splitter is a single element (no operators)
if len(splitter_rpn) == 1:
op_single = splitter_rpn[0]
return _single_op_splits_groups(op_single, combiner, inner_inputs, groups)
return _single_op_splits_groups(
op_single, combiner, inputs_previous_states, groups
)

# len(splitter_rpn) > 1
# iterating splitter_rpn
Expand Down Expand Up @@ -1722,7 +1726,7 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
groups_stack = stack.pop()
if isinstance(groups_stack, int):
groups_stack = [groups_stack]
if inner_inputs:
if inputs_previous_states:
groups_stack = [[], groups_stack]
else:
groups_stack = [groups_stack]
Expand All @@ -1739,12 +1743,12 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
return keys, groups, groups_stack, []


def _single_op_splits_groups(op_single, combiner, inner_inputs, groups):
def _single_op_splits_groups(op_single, combiner, inputs_previous_states, groups):
"""splits_groups function if splitter is a singleton"""
if op_single in inner_inputs:
if op_single in inputs_previous_states:
# TODO: have to be changed if differ length
# TODO: i think I don't want to add here from left part
# keys = inner_inputs[op_single].keys_final + [op_single]
# keys = inputs_previous_states[op_single].keys_final + [op_single]
keys = [op_single]
groups[op_single], groups_stack = 0, [[], [0]]
else:
Expand Down
4 changes: 2 additions & 2 deletions pydra/engine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,14 @@ def _create_graph(
graph.node(lf._node.name).state
and graph.node(lf._node.name).state.splitter_rpn_final
):
# variables that are part of inner splitters should be
# variables that are part of scoped splitters should be
# treated as a containers
if (
node.state
and f"{node.name}.{field.name}"
in node.state._current_splitter_rpn
):
node.state._inner_container_ndim[
node.state._scoped_container_ndim[
f"{node.name}.{field.name}"
] = 1
# adding task_name: (job.state, [a field from the connection]
Expand Down
Loading