@@ -362,14 +362,15 @@ def iter_splits(iterable, keys):
362
362
yield dict (zip (keys , list (flatten (iter , max_depth = 1000 ))))
363
363
364
364
365
- def input_shape (in1 ):
366
- """Get input shape. """
365
+ def input_shape (inp , cont_dim = 1 ):
366
+ """Get input shape, depends on the container dimension, if not specify it is assumed to be 1 """
367
367
# TODO: have to be changed for inner splitter (sometimes different length)
368
- shape = [len (in1 )]
368
+ cont_dim -= 1
369
+ shape = [len (inp )]
369
370
last_shape = None
370
- for value in in1 :
371
- if isinstance (value , list ):
372
- cur_shape = input_shape (value )
371
+ for value in inp :
372
+ if isinstance (value , list ) and cont_dim > 0 :
373
+ cur_shape = input_shape (value , cont_dim )
373
374
if last_shape is None :
374
375
last_shape = cur_shape
375
376
elif last_shape != cur_shape :
@@ -383,11 +384,37 @@ def input_shape(in1):
383
384
return tuple (shape )
384
385
385
386
386
- def splits (splitter_rpn , inputs , inner_inputs = None ):
387
- """Split process as specified by an rpn splitter, from left to right."""
387
+ def splits (splitter_rpn , inputs , inner_inputs = None , cont_dim = None ):
388
+ """
389
+ Splits input variable as specified by splitter
390
+
391
+ Parameters
392
+ ----------
393
+ splitter_rpn : list
394
+ splitter in RPN notation
395
+ inputs: dict
396
+ input variables
397
+ inner_inputs: dict, optional
398
+ inner input specification
399
+ cont_dim: dict, optional
400
+ container dimension for input variable, specifies how nested is the intput,
401
+ if not specified 1 will be used for all inputs (so will not be flatten)
402
+
403
+
404
+ Returns
405
+ -------
406
+ splitter : list
407
+ each element contains indices for inputs
408
+ keys: list
409
+ names of input variables
410
+
411
+ """
412
+
388
413
stack = []
389
414
keys = []
390
- shapes_var = {}
415
+ if cont_dim is None :
416
+ cont_dim = {}
417
+ # analysing states from connected tasks if inner_inputs
391
418
if inner_inputs :
392
419
previous_states_ind = {
393
420
"_{}" .format (v .name ): (v .ind_l_final , v .keys_final )
@@ -407,9 +434,9 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
407
434
op_single ,
408
435
inputs ,
409
436
inner_inputs ,
410
- shapes_var ,
411
437
previous_states_ind ,
412
438
keys_fromLeftSpl ,
439
+ cont_dim = cont_dim ,
413
440
)
414
441
415
442
terms = {}
@@ -418,7 +445,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
418
445
shape = {}
419
446
# iterating splitter_rpn
420
447
for token in splitter_rpn :
421
- if token in ["." , "*" ]:
448
+ if token not in ["." , "*" ]: # token is one of the input var
449
+ # adding variable to the stack
450
+ stack .append (token )
451
+ else :
452
+ # removing Right and Left var from the stack
422
453
terms ["R" ] = stack .pop ()
423
454
terms ["L" ] = stack .pop ()
424
455
# checking if terms are strings, shapes, etc.
@@ -429,10 +460,14 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
429
460
trm_val [lr ] = previous_states_ind [term ][0 ]
430
461
shape [lr ] = (len (trm_val [lr ]),)
431
462
else :
432
- shape [lr ] = input_shape (inputs [term ])
463
+ if term in cont_dim :
464
+ shape [lr ] = input_shape (
465
+ inputs [term ], cont_dim = cont_dim [term ]
466
+ )
467
+ else :
468
+ shape [lr ] = input_shape (inputs [term ])
433
469
trm_val [lr ] = range (reduce (lambda x , y : x * y , shape [lr ]))
434
470
trm_str [lr ] = True
435
- shapes_var [term ] = shape [lr ]
436
471
else :
437
472
trm_val [lr ], shape [lr ] = term
438
473
trm_str [lr ] = False
@@ -447,6 +482,7 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
447
482
)
448
483
newshape = shape ["R" ]
449
484
if token == "*" :
485
+ # TODO: pomyslec
450
486
newshape = tuple (list (shape ["L" ]) + list (shape ["R" ]))
451
487
452
488
# creating list with keys
@@ -466,7 +502,6 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
466
502
elif trm_str ["R" ]:
467
503
keys = keys + new_keys ["R" ]
468
504
469
- #
470
505
newtrm_val = {}
471
506
for lr in ["R" , "L" ]:
472
507
# TODO: rewrite once I have more tests
@@ -491,13 +526,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
491
526
492
527
pushval = (op [token ](newtrm_val ["L" ], newtrm_val ["R" ]), newshape )
493
528
stack .append (pushval )
494
- else : # name of one of the inputs (token not in [".", "*"])
495
- stack .append (token )
496
529
497
530
val = stack .pop ()
498
531
if isinstance (val , tuple ):
499
532
val = val [0 ]
500
- return val , keys , shapes_var , keys_fromLeftSpl
533
+ return val , keys , keys_fromLeftSpl
501
534
502
535
503
536
# dj: TODO: do I need keys?
@@ -636,17 +669,22 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
636
669
637
670
638
671
def _single_op_splits (
639
- op_single , inputs , inner_inputs , shapes_var , previous_states_ind , keys_fromLeftSpl
672
+ op_single ,
673
+ inputs ,
674
+ inner_inputs ,
675
+ previous_states_ind ,
676
+ keys_fromLeftSpl ,
677
+ cont_dim = None ,
640
678
):
641
679
if op_single .startswith ("_" ):
642
680
return (
643
681
previous_states_ind [op_single ][0 ],
644
682
previous_states_ind [op_single ][1 ],
645
- None ,
646
683
keys_fromLeftSpl ,
647
684
)
648
- shape = input_shape (inputs [op_single ])
649
- shapes_var [op_single ] = shape
685
+ if cont_dim is None :
686
+ cont_dim = {}
687
+ shape = input_shape (inputs [op_single ], cont_dim = cont_dim .get (op_single , 1 ))
650
688
trmval = range (reduce (lambda x , y : x * y , shape ))
651
689
if op_single in inner_inputs :
652
690
# TODO: have to be changed if differ length
@@ -659,11 +697,11 @@ def _single_op_splits(
659
697
res = op ["." ](op_out , trmval )
660
698
val = res
661
699
keys = inner_inputs [op_single ].keys_final + [op_single ]
662
- return val , keys , shapes_var , keys_fromLeftSpl
700
+ return val , keys , keys_fromLeftSpl
663
701
else :
664
702
val = op ["*" ](trmval )
665
703
keys = [op_single ]
666
- return val , keys , shapes_var , keys_fromLeftSpl
704
+ return val , keys , keys_fromLeftSpl
667
705
668
706
669
707
def _single_op_splits_groups (
@@ -727,10 +765,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
727
765
return keys_final , groups_final , groups_stack_final , combiner_all
728
766
729
767
730
- def map_splits (split_iter , inputs ):
768
+ def map_splits (split_iter , inputs , cont_dim = None ):
731
769
"""Get a dictionary of prescribed splits."""
770
+ if cont_dim is None :
771
+ cont_dim = {}
732
772
for split in split_iter :
733
- yield {k : list (flatten (ensure_list (inputs [k ])))[v ] for k , v in split .items ()}
773
+ yield {
774
+ k : list (flatten (ensure_list (inputs [k ]), max_depth = cont_dim .get (k , None )))[v ]
775
+ for k , v in split .items ()
776
+ }
734
777
735
778
736
779
# Functions for merging and completing splitters in states.
0 commit comments