diff --git a/pydra/engine/node.py b/pydra/engine/node.py index c4705d766..a11f693e5 100644 --- a/pydra/engine/node.py +++ b/pydra/engine/node.py @@ -215,9 +215,10 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]: and val._node.state.depth() ): node: Node = val._node - # variables that are part of inner splitters should be treated as a containers + # variables that are part of scoped splitters (previously known as inner splitter) + # should be treated as a containers if node.state and f"{node.name}.{val._field}" in node.state.splitter: - node.state._inner_container_ndim[f"{node.name}.{val._field}"] = 1 + node.state._scoped_container_ndim[f"{node.name}.{val._field}"] = 1 # adding task_name: (task.state, [a field from the connection] if node.name not in upstream_states: upstream_states[node.name] = (node.state, [val._field]) diff --git a/pydra/engine/state.py b/pydra/engine/state.py index b81d85db6..c770a3a26 100644 --- a/pydra/engine/state.py +++ b/pydra/engine/state.py @@ -52,7 +52,7 @@ class State: (previous state, input from current state needed the connection) } - inner_inputs : :obj:`dict` + inputs_previous_states : :obj:`dict` used to create connections with previous states ``{"{self.name}.input name for current inp": previous state}`` states_ind : :obj:`list` of :obj:`dict` @@ -110,7 +110,7 @@ def __init__( # temporary combiner self.combiner = combiner self.container_ndim = container_ndim or {} - self._inner_container_ndim = {} + self._scoped_container_ndim = {} self._inputs_ind = None # if other_states, the connections have to be updated if self.other_states: @@ -359,9 +359,9 @@ def prev_state_splitter_rpn_compact(self): @property def container_ndim_all(self): - # adding inner_container_ndim to the general container_dimension provided by the users + # adding scoped_container_ndim to the general container_dimension provided by the users container_ndim_all = deepcopy(self.container_ndim) - for k, v in self._inner_container_ndim.items(): + for k, v in self._scoped_container_ndim.items(): container_ndim_all[k] = container_ndim_all.get(k, 1) + v return container_ndim_all @@ -445,18 +445,18 @@ def other_states(self, other_states): self._other_states = {} @property - def inner_inputs(self): + def inputs_previous_states(self): """specifies connections between fields from the current state with the specific state from the previous states, uses dictionary ``{input name for current state: the previous state}`` """ if self.other_states: - _inner_inputs = {} + _inputs_previous_states = {} for name, (st, inp_l) in self.other_states.items(): if f"_{st.name}" in self.splitter_rpn_compact: for inp in inp_l: - _inner_inputs[f"{self.name}.{inp}"] = st - return _inner_inputs + _inputs_previous_states[f"{self.name}.{inp}"] = st + return _inputs_previous_states else: return {} @@ -706,7 +706,7 @@ def set_input_groups(self, state_fields=True): keys_f, group_for_inputs_f, groups_stack_f, combiner_all = splits_groups( current_splitter_rpn, combiner=self.current_combiner, - inner_inputs=self.inner_inputs, + inputs_previous_states=self.inputs_previous_states, ) self._current_combiner_all = combiner_all if ( @@ -768,7 +768,7 @@ def _merge_previous_groups(self): ) = splits_groups( st.splitter_rpn_final, combiner=st_combiner, - inner_inputs=st.inner_inputs, + inputs_previous_states=st.inputs_previous_states, ) self.keys_final += keys_f_st # st.keys_final if not hasattr(st, "group_for_inputs_final"): @@ -884,7 +884,7 @@ def prepare_states_ind(self): Uses splits. """ - # removing elements that are connected to inner splitter + # removing elements that are connected to scoped splitter (previously inner splitter) # (they will be taken into account in splits anyway) # _comb part will be used in prepare_states_combined_ind # TODO: need tests in test_Workflow.py @@ -986,8 +986,8 @@ def prepare_inputs(self): """ Preparing inputs indices, merges input from previous states. - Includes indices for fields from inner splitters - (removes elements connected to the inner splitters fields). + Includes indices for fields from scoped splitters (previously inner splitters) + (removes elements connected to the scoped splitters fields). """ if not self.other_states: @@ -1004,21 +1004,21 @@ def prepare_inputs(self): inputs_ind = [] # merging elements that come from previous nodes outputs - # states that are connected to inner splitters are treated differently + # states that are connected to scoped splitters are treated differently # (already included in inputs_ind) keys_inp_prev = [] inputs_ind_prev = [] - connected_to_inner = [] + connected_to_scoped = [] for ii, el in enumerate(self.prev_state_splitter_rpn_compact): if el in ["*", "."]: continue st, inp_l = self.other_states[el[1:]] inp_l = [f"{self.name}.{inp}" for inp in inp_l] - if set(inp_l).intersection(self.splitter_rpn): # inner splitter - connected_to_inner += [ + if set(inp_l).intersection(self.splitter_rpn): # scoped splitter + connected_to_scoped += [ el for el in st.splitter_rpn_final if el not in [".", "*"] ] - else: # previous states that are not connected to inner splitter + else: # previous states that are not connected to scoped splitter st_ind = range(len(st.states_ind_final)) if inputs_ind_prev: # in case the prev-state part has scalar parts (not very well tested) @@ -1043,9 +1043,9 @@ def prepare_inputs(self): # iter_splits using inputs from current state/node self._inputs_ind = list(iter_splits(inputs_ind, keys_inp)) - # removing elements that are connected to inner splitter + # removing elements that are connected to scoped splitter # TODO - add tests to test_workflow.py (not sure if we want to remove it) - for el in connected_to_inner: + for el in connected_to_scoped: [dict.pop(el) for dict in self._inputs_ind] def splits(self, splitter_rpn): @@ -1064,10 +1064,10 @@ def splits(self, splitter_rpn): names of input variables """ - # analysing states from connected tasks if inner_inputs + # analysing states from connected tasks if inputs_previous_states previous_states_ind = { f"_{v.name}": (v.ind_l_final, v.keys_final) - for v in self.inner_inputs.values() + for v in self.inputs_previous_states.values() } # when splitter is a single element (no operators) @@ -1152,17 +1152,17 @@ def _processing_terms(self, term, previous_states_ind): shape = input_shape(self.inputs[term], container_ndim=container_ndim) var_ind = range(prod(shape)) new_keys = [term] - # checking if the term is in inner_inputs - if term in self.inner_inputs: + # checking if the term is in inputs_previous_states + if term in self.inputs_previous_states: # TODO: have to be changed if differ length inner_len = [shape[-1]] * prod(shape[:-1]) # this come from the previous node - outer_ind = self.inner_inputs[term].ind_l + outer_ind = self.inputs_previous_states[term].ind_l var_ind_out = itertools.chain.from_iterable( itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len) ) var_ind = op["."](var_ind_out, var_ind) - new_keys = self.inner_inputs[term].keys_final + new_keys + new_keys = self.inputs_previous_states[term].keys_final + new_keys return shape, var_ind, new_keys @@ -1173,18 +1173,18 @@ def _single_op_splits(self, op_single): container_ndim=self.container_ndim_all.get(op_single, 1), ) val_ind = range(prod(shape)) - if op_single in self.inner_inputs: + if op_single in self.inputs_previous_states: # TODO: have to be changed if differ length inner_len = [shape[-1]] * prod(shape[:-1]) # this come from the previous node - outer_ind = self.inner_inputs[op_single].ind_l + outer_ind = self.inputs_previous_states[op_single].ind_l op_out = itertools.chain.from_iterable( itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len) ) res = op["."](op_out, val_ind) val = res - keys = self.inner_inputs[op_single].keys_final + [op_single] + keys = self.inputs_previous_states[op_single].keys_final + [op_single] return val, keys else: val = op["*"](val_ind) @@ -1585,7 +1585,7 @@ def iter_splits(iterable, keys): def input_shape(inp, container_ndim=1): """Get input shape, depends on the container dimension, if not specify it is assumed to be 1""" - # TODO: have to be changed for inner splitter (sometimes different length) + # TODO: have to be changed for scoped (prev. inner) splitter (sometimes different length) container_ndim -= 1 shape = [len(inp)] last_shape = None @@ -1605,7 +1605,7 @@ def input_shape(inp, container_ndim=1): return tuple(shape) -def splits_groups(splitter_rpn, combiner=None, inner_inputs=None): +def splits_groups(splitter_rpn, combiner=None, inputs_previous_states=None): """splits inputs to groups (axes) and creates stacks for these groups This is used to specify which input can be combined. """ @@ -1617,19 +1617,23 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None): group_count = None if not combiner: combiner = [] - if inner_inputs: + if inputs_previous_states: previous_states_ind = { - f"_{v.name}": v.keys_final for v in inner_inputs.values() + f"_{v.name}": v.keys_final for v in inputs_previous_states.values() + } + inputs_previous_states = { + k: v for k, v in inputs_previous_states.items() if k in splitter_rpn } - inner_inputs = {k: v for k, v in inner_inputs.items() if k in splitter_rpn} else: previous_states_ind = {} - inner_inputs = {} + inputs_previous_states = {} # when splitter is a single element (no operators) if len(splitter_rpn) == 1: op_single = splitter_rpn[0] - return _single_op_splits_groups(op_single, combiner, inner_inputs, groups) + return _single_op_splits_groups( + op_single, combiner, inputs_previous_states, groups + ) # len(splitter_rpn) > 1 # iterating splitter_rpn @@ -1722,7 +1726,7 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None): groups_stack = stack.pop() if isinstance(groups_stack, int): groups_stack = [groups_stack] - if inner_inputs: + if inputs_previous_states: groups_stack = [[], groups_stack] else: groups_stack = [groups_stack] @@ -1739,12 +1743,12 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None): return keys, groups, groups_stack, [] -def _single_op_splits_groups(op_single, combiner, inner_inputs, groups): +def _single_op_splits_groups(op_single, combiner, inputs_previous_states, groups): """splits_groups function if splitter is a singleton""" - if op_single in inner_inputs: + if op_single in inputs_previous_states: # TODO: have to be changed if differ length # TODO: i think I don't want to add here from left part - # keys = inner_inputs[op_single].keys_final + [op_single] + # keys = inputs_previous_states[op_single].keys_final + [op_single] keys = [op_single] groups[op_single], groups_stack = 0, [[], [0]] else: diff --git a/pydra/engine/workflow.py b/pydra/engine/workflow.py index c277d502e..e46cf0c0b 100644 --- a/pydra/engine/workflow.py +++ b/pydra/engine/workflow.py @@ -342,14 +342,14 @@ def _create_graph( graph.node(lf._node.name).state and graph.node(lf._node.name).state.splitter_rpn_final ): - # variables that are part of inner splitters should be + # variables that are part of scoped splitters should be # treated as a containers if ( node.state and f"{node.name}.{field.name}" in node.state._current_splitter_rpn ): - node.state._inner_container_ndim[ + node.state._scoped_container_ndim[ f"{node.name}.{field.name}" ] = 1 # adding task_name: (job.state, [a field from the connection]