Skip to content

Commit feafec8

Browse files
authored
Merge pull request #213 from djarecka/mnt/state_review
adding cont_dim arguments to specify the container dimension for input variables (closes #181)
2 parents 81af3b8 + abfef7b commit feafec8

File tree

5 files changed

+222
-72
lines changed

5 files changed

+222
-72
lines changed

pydra/engine/helpers_state.py

+68-25
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,15 @@ def iter_splits(iterable, keys):
362362
yield dict(zip(keys, list(flatten(iter, max_depth=1000))))
363363

364364

365-
def input_shape(in1):
366-
"""Get input shape."""
365+
def input_shape(inp, cont_dim=1):
366+
"""Get input shape, depends on the container dimension, if not specify it is assumed to be 1 """
367367
# TODO: have to be changed for inner splitter (sometimes different length)
368-
shape = [len(in1)]
368+
cont_dim -= 1
369+
shape = [len(inp)]
369370
last_shape = None
370-
for value in in1:
371-
if isinstance(value, list):
372-
cur_shape = input_shape(value)
371+
for value in inp:
372+
if isinstance(value, list) and cont_dim > 0:
373+
cur_shape = input_shape(value, cont_dim)
373374
if last_shape is None:
374375
last_shape = cur_shape
375376
elif last_shape != cur_shape:
@@ -383,11 +384,37 @@ def input_shape(in1):
383384
return tuple(shape)
384385

385386

386-
def splits(splitter_rpn, inputs, inner_inputs=None):
387-
"""Split process as specified by an rpn splitter, from left to right."""
387+
def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
388+
"""
389+
Splits input variable as specified by splitter
390+
391+
Parameters
392+
----------
393+
splitter_rpn : list
394+
splitter in RPN notation
395+
inputs: dict
396+
input variables
397+
inner_inputs: dict, optional
398+
inner input specification
399+
cont_dim: dict, optional
400+
container dimension for input variable, specifies how nested is the intput,
401+
if not specified 1 will be used for all inputs (so will not be flatten)
402+
403+
404+
Returns
405+
-------
406+
splitter : list
407+
each element contains indices for inputs
408+
keys: list
409+
names of input variables
410+
411+
"""
412+
388413
stack = []
389414
keys = []
390-
shapes_var = {}
415+
if cont_dim is None:
416+
cont_dim = {}
417+
# analysing states from connected tasks if inner_inputs
391418
if inner_inputs:
392419
previous_states_ind = {
393420
"_{}".format(v.name): (v.ind_l_final, v.keys_final)
@@ -407,9 +434,9 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
407434
op_single,
408435
inputs,
409436
inner_inputs,
410-
shapes_var,
411437
previous_states_ind,
412438
keys_fromLeftSpl,
439+
cont_dim=cont_dim,
413440
)
414441

415442
terms = {}
@@ -418,7 +445,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
418445
shape = {}
419446
# iterating splitter_rpn
420447
for token in splitter_rpn:
421-
if token in [".", "*"]:
448+
if token not in [".", "*"]: # token is one of the input var
449+
# adding variable to the stack
450+
stack.append(token)
451+
else:
452+
# removing Right and Left var from the stack
422453
terms["R"] = stack.pop()
423454
terms["L"] = stack.pop()
424455
# checking if terms are strings, shapes, etc.
@@ -429,10 +460,14 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
429460
trm_val[lr] = previous_states_ind[term][0]
430461
shape[lr] = (len(trm_val[lr]),)
431462
else:
432-
shape[lr] = input_shape(inputs[term])
463+
if term in cont_dim:
464+
shape[lr] = input_shape(
465+
inputs[term], cont_dim=cont_dim[term]
466+
)
467+
else:
468+
shape[lr] = input_shape(inputs[term])
433469
trm_val[lr] = range(reduce(lambda x, y: x * y, shape[lr]))
434470
trm_str[lr] = True
435-
shapes_var[term] = shape[lr]
436471
else:
437472
trm_val[lr], shape[lr] = term
438473
trm_str[lr] = False
@@ -447,6 +482,7 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
447482
)
448483
newshape = shape["R"]
449484
if token == "*":
485+
# TODO: pomyslec
450486
newshape = tuple(list(shape["L"]) + list(shape["R"]))
451487

452488
# creating list with keys
@@ -466,7 +502,6 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
466502
elif trm_str["R"]:
467503
keys = keys + new_keys["R"]
468504

469-
#
470505
newtrm_val = {}
471506
for lr in ["R", "L"]:
472507
# TODO: rewrite once I have more tests
@@ -491,13 +526,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
491526

492527
pushval = (op[token](newtrm_val["L"], newtrm_val["R"]), newshape)
493528
stack.append(pushval)
494-
else: # name of one of the inputs (token not in [".", "*"])
495-
stack.append(token)
496529

497530
val = stack.pop()
498531
if isinstance(val, tuple):
499532
val = val[0]
500-
return val, keys, shapes_var, keys_fromLeftSpl
533+
return val, keys, keys_fromLeftSpl
501534

502535

503536
# dj: TODO: do I need keys?
@@ -636,17 +669,22 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
636669

637670

638671
def _single_op_splits(
639-
op_single, inputs, inner_inputs, shapes_var, previous_states_ind, keys_fromLeftSpl
672+
op_single,
673+
inputs,
674+
inner_inputs,
675+
previous_states_ind,
676+
keys_fromLeftSpl,
677+
cont_dim=None,
640678
):
641679
if op_single.startswith("_"):
642680
return (
643681
previous_states_ind[op_single][0],
644682
previous_states_ind[op_single][1],
645-
None,
646683
keys_fromLeftSpl,
647684
)
648-
shape = input_shape(inputs[op_single])
649-
shapes_var[op_single] = shape
685+
if cont_dim is None:
686+
cont_dim = {}
687+
shape = input_shape(inputs[op_single], cont_dim=cont_dim.get(op_single, 1))
650688
trmval = range(reduce(lambda x, y: x * y, shape))
651689
if op_single in inner_inputs:
652690
# TODO: have to be changed if differ length
@@ -659,11 +697,11 @@ def _single_op_splits(
659697
res = op["."](op_out, trmval)
660698
val = res
661699
keys = inner_inputs[op_single].keys_final + [op_single]
662-
return val, keys, shapes_var, keys_fromLeftSpl
700+
return val, keys, keys_fromLeftSpl
663701
else:
664702
val = op["*"](trmval)
665703
keys = [op_single]
666-
return val, keys, shapes_var, keys_fromLeftSpl
704+
return val, keys, keys_fromLeftSpl
667705

668706

669707
def _single_op_splits_groups(
@@ -727,10 +765,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
727765
return keys_final, groups_final, groups_stack_final, combiner_all
728766

729767

730-
def map_splits(split_iter, inputs):
768+
def map_splits(split_iter, inputs, cont_dim=None):
731769
"""Get a dictionary of prescribed splits."""
770+
if cont_dim is None:
771+
cont_dim = {}
732772
for split in split_iter:
733-
yield {k: list(flatten(ensure_list(inputs[k])))[v] for k, v in split.items()}
773+
yield {
774+
k: list(flatten(ensure_list(inputs[k]), max_depth=cont_dim.get(k, None)))[v]
775+
for k, v in split.items()
776+
}
734777

735778

736779
# Functions for merging and completing splitters in states.

pydra/engine/state.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def push_new_states(self):
346346
stack = [gr + nr_gr_f for gr in stack]
347347
self.groups_stack_final.append(stack)
348348

349-
def prepare_states(self, inputs):
349+
def prepare_states(self, inputs, cont_dim=None):
350350
"""
351351
Prepare a full list of state indices and state values.
352352
@@ -357,6 +357,11 @@ def prepare_states(self, inputs):
357357
specific elements from inputs that can be used running interfaces
358358
359359
"""
360+
# container dimension for each input, specifies how nested the input is
361+
if cont_dim is None:
362+
self.cont_dim = {}
363+
else:
364+
self.cont_dim = cont_dim
360365
if isinstance(inputs, BaseSpec):
361366
self.inputs = hlpst.inputs_types_to_dict(self.name, inputs)
362367
else:
@@ -366,7 +371,7 @@ def prepare_states(self, inputs):
366371
# I think now this if is never used
367372
if not hasattr(st, "states_ind"):
368373
# dj: should i provide different inputs?
369-
st.prepare_states(self.inputs)
374+
st.prepare_states(self.inputs, cont_dim=cont_dim)
370375
self.inputs.update(st.inputs)
371376
self.prepare_states_ind()
372377
self.prepare_states_val()
@@ -395,8 +400,11 @@ def prepare_states_ind(self):
395400
partial_rpn = hlpst.remove_inp_from_splitter_rpn(
396401
deepcopy(self.splitter_rpn_compact), elements_to_remove
397402
)
398-
values_out_pr, keys_out_pr, _, kL = hlpst.splits(
399-
partial_rpn, self.inputs, inner_inputs=self.inner_inputs
403+
values_out_pr, keys_out_pr, kL = hlpst.splits(
404+
partial_rpn,
405+
self.inputs,
406+
inner_inputs=self.inner_inputs,
407+
cont_dim=self.cont_dim,
400408
)
401409
values_pr = list(values_out_pr)
402410

@@ -429,8 +437,11 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
429437
)
430438
# TODO: create a function for this!!
431439
if combined_rpn:
432-
val_r, key_r, _, _ = hlpst.splits(
433-
combined_rpn, self.inputs, inner_inputs=self.inner_inputs
440+
val_r, key_r, _ = hlpst.splits(
441+
combined_rpn,
442+
self.inputs,
443+
inner_inputs=self.inner_inputs,
444+
cont_dim=self.cont_dim,
434445
)
435446
values = list(val_r)
436447
else:
@@ -464,7 +475,9 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
464475

465476
def prepare_states_val(self):
466477
"""Evaluate states values having states indices."""
467-
self.states_val = list(hlpst.map_splits(self.states_ind, self.inputs))
478+
self.states_val = list(
479+
hlpst.map_splits(self.states_ind, self.inputs, cont_dim=self.cont_dim)
480+
)
468481
return self.states_val
469482

470483
def prepare_inputs(self):
@@ -489,8 +502,11 @@ def prepare_inputs(self):
489502
deepcopy(self.splitter_rpn_compact), elements_to_remove
490503
)
491504
if partial_rpn:
492-
values_inp, keys_inp, _, _ = hlpst.splits(
493-
partial_rpn, self.inputs, inner_inputs=self.inner_inputs
505+
values_inp, keys_inp, _ = hlpst.splits(
506+
partial_rpn,
507+
self.inputs,
508+
inner_inputs=self.inner_inputs,
509+
cont_dim=self.cont_dim,
494510
)
495511
inputs_ind = values_inp
496512
else:

0 commit comments

Comments
 (0)