19
19
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
20
THE SOFTWARE.
21
21
"""
22
+
23
+ import sys
22
24
from typing import (Any , Dict , Mapping , Tuple , TypeAlias , Iterable ,
23
25
FrozenSet , Union , Set , List , Optional , Callable )
24
26
from pytato .array import (Array , InputArgumentBase , DictOfNamedArrays ,
38
40
from immutables import Map
39
41
from pytato .utils import are_shape_components_equal
40
42
43
+ if sys .version >= (3 , 11 ):
44
+ zip_equal = lambda * _args : zip (* _args , strict = True )
45
+ else :
46
+ from more_itertools import zip_equal
47
+
41
48
_ComposedIndirectionT : TypeAlias = Tuple [Array , ...]
42
49
IndexT : TypeAlias = Union [Array , NormalizedSlice ]
43
- IndexStackT : TypeAlias = Tuple [IndexT , ...]
44
50
45
51
46
52
def _is_materialized (expr : Array ) -> bool :
@@ -53,15 +59,15 @@ def _is_materialized(expr: Array) -> bool:
53
59
or bool (expr .tags_of_type (ImplStored )))
54
60
55
61
56
- def _is_trivial_slice (dim : ShapeComponent , slice_ : NormalizedSlice ) -> bool :
62
+ def _is_trivial_slice (dim : ShapeComponent , slice_ : IndexT ) -> bool :
57
63
"""
58
64
Returns *True* only if *slice_* indexes an entire axis of shape *dim* with
59
65
a step of 1.
60
66
"""
61
- return (slice_ .step == 1
67
+ return (isinstance (slice_ , NormalizedSlice )
68
+ and slice_ .step == 1
62
69
and are_shape_components_equal (slice_ .start , 0 )
63
- and are_shape_components_equal (slice_ .stop , dim )
64
- )
70
+ and are_shape_components_equal (slice_ .stop , dim ))
65
71
66
72
67
73
def _take_along_axis (ary : Array , iaxis : int , idxs : IndexStackT ) -> Array :
@@ -427,35 +433,35 @@ class _IndirectionPusher(Mapper):
427
433
428
434
def __init__ (self ) -> None :
429
435
self .get_reordarable_axes = _LegallyAxisReorderingFinder ()
430
- self ._cache : Dict [Tuple [ArrayOrNames , Map [int , IndexStackT ]],
436
+ self ._cache : Dict [Tuple [ArrayOrNames , Map [int , IndexT ]],
431
437
ArrayOrNames ] = {}
432
438
super ().__init__ ()
433
439
434
440
def rec (self , # type: ignore[override]
435
441
expr : MappedT ,
436
- index_stacks : Map [int , IndexStackT ]) -> MappedT :
437
- key = (expr , index_stacks )
442
+ indices : Tuple [IndexT , ...]) -> MappedT :
443
+ assert len (indices ) == expr .ndim
444
+ key = (expr , indices )
438
445
try :
439
446
# type-ignore-reason: parametric mapping types aren't a thing in 'typing'
440
447
return self ._cache [key ] # type: ignore[return-value]
441
448
except KeyError :
442
- result = Mapper .rec (self , expr , index_stacks )
449
+ result = Mapper .rec (self , expr , indices )
443
450
self ._cache [key ] = result
444
451
return result # type: ignore[no-any-return]
445
452
446
453
def __call__ (self , # type: ignore[override]
447
454
expr : MappedT ,
448
- index_stacks : Map [int , IndexStackT ]) -> MappedT :
449
- return self .rec (expr , index_stacks )
455
+ indices : Map [int , IndexT ]) -> MappedT :
456
+ return self .rec (expr , indices )
450
457
451
458
def _map_materialized (self ,
452
459
expr : Array ,
453
- index_stacks : Map [int , IndexStackT ]) -> Array :
454
- result = expr
455
- for iaxis , idxs in index_stacks .items ():
456
- result = _take_along_axis (result , iaxis , idxs )
457
-
458
- return result
460
+ indices : Tuple [IndexT , ...]) -> Array :
461
+ if all (_is_trivial_slice (dim , idx )
462
+ for dim , idx in zip (expr .shape , indices )):
463
+ return expr
464
+ return expr [* indices ]
459
465
460
466
def map_dict_of_named_arrays (self ,
461
467
expr : DictOfNamedArrays ,
@@ -467,9 +473,12 @@ def map_dict_of_named_arrays(self,
467
473
468
474
def map_index_lambda (self ,
469
475
expr : IndexLambda ,
470
- index_stacks : Map [ int , IndexStackT ]
476
+ indices : Tuple [ IndexT , ...],
471
477
) -> Array :
472
478
if _is_materialized (expr ):
479
+ # FIXME: Move this logic to .rec (Why on earth do we need)
480
+ # to copy the damn node???
481
+
473
482
# do not propagate the indexings to the bindings.
474
483
expr = IndexLambda (expr .expr ,
475
484
expr .shape ,
@@ -478,9 +487,13 @@ def map_index_lambda(self,
478
487
for name , bnd in expr .bindings .items ()}),
479
488
expr .var_to_reduction_descr ,
480
489
tags = expr .tags ,
481
- axes = expr .axes ,
482
- )
483
- return self ._map_materialized (expr , index_stacks )
490
+ axes = expr .axes ,)
491
+ return self ._map_materialized (expr , indices )
492
+
493
+ # FIXME:
494
+ # This is the money shot. Over here we need to figure out the index
495
+ # propagation logic.
496
+
484
497
485
498
iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis (expr )
486
499
@@ -886,128 +899,13 @@ def push_axis_indirections_towards_materialized_nodes(expr: MappedT
886
899
) -> MappedT :
887
900
"""
888
901
Returns a copy of *expr* with the indirections propagated closer to the
889
- materialized nodes. We propagate an indirections only if the indirection in
890
- an :class:`~pytato.array.AdvancedIndexInContiguousAxes` or
891
- :class:`~pytato.array.AdvancedIndexInNoncontiguousAxes` is an indirection
892
- over a single axis.
902
+ materialized nodes.
893
903
"""
894
904
mapper = _IndirectionPusher ()
895
905
896
906
return mapper (expr , Map ())
897
907
898
908
899
- def _get_unbroadcasted_axis_in_indirections (
900
- expr : AdvancedIndexInContiguousAxes ) -> Optional [Mapping [int , int ]]:
901
- """
902
- Returns a mapping from the index of an indirection to its *only*
903
- unbroadcasted axis as required by the logic. Returns *None* if no such
904
- mapping exists.
905
- """
906
- from pytato .utils import partition , get_shape_after_broadcasting
907
- adv_indices , _ = partition (lambda i : isinstance (expr .indices [i ],
908
- NormalizedSlice ),
909
- range (expr .array .ndim ))
910
- i_ary_indices = [i_idx
911
- for i_idx , idx in enumerate (expr .indices )
912
- if isinstance (idx , Array )]
913
-
914
- adv_idx_shape = get_shape_after_broadcasting ([expr .indices [i_idx ]
915
- for i_idx in adv_indices ])
916
-
917
- if len (adv_idx_shape ) != len (i_ary_indices ):
918
- return None
919
-
920
- i_adv_out_axis_to_candidate_i_arys : Dict [int , Set [int ]] = {
921
- idim : set ()
922
- for idim , _ in enumerate (adv_idx_shape )
923
- }
924
-
925
- for i_ary_idx in i_ary_indices :
926
- ary = expr .indices [i_ary_idx ]
927
- assert isinstance (ary , Array )
928
- for iadv_out_axis , i_ary_axis in zip (range (len (adv_idx_shape )- 1 , - 1 , - 1 ),
929
- range (ary .ndim - 1 , - 1 , - 1 )):
930
- if are_shape_components_equal (adv_idx_shape [iadv_out_axis ],
931
- ary .shape [i_ary_axis ]):
932
- i_adv_out_axis_to_candidate_i_arys [iadv_out_axis ].add (i_ary_idx )
933
-
934
- from itertools import permutations
935
- # FIXME: O(expr.ndim!) complexity, typically ndim <= 4 so this should be fine.
936
- for guess_i_adv_out_axis_to_i_ary in permutations (range (len (i_ary_indices ))):
937
- if all (i_ary in i_adv_out_axis_to_candidate_i_arys [i_adv_out ]
938
- for i_adv_out , i_ary in enumerate (guess_i_adv_out_axis_to_i_ary )):
939
- # TODO: Return the mapping here...
940
- i_ary_to_unbroadcasted_axis : Dict [int , int ] = {}
941
- for guess_i_adv_out_axis , i_ary_idx in enumerate (
942
- guess_i_adv_out_axis_to_i_ary ):
943
- ary = expr .indices [i_ary_idx ]
944
- assert isinstance (ary , Array )
945
- iunbroadcasted_axis , = [
946
- i_ary_axis
947
- for i_adv_out_axis , i_ary_axis in zip (
948
- range (len (adv_idx_shape )- 1 , - 1 , - 1 ),
949
- range (ary .ndim - 1 , - 1 , - 1 ))
950
- if i_adv_out_axis == guess_i_adv_out_axis
951
- ]
952
- i_ary_to_unbroadcasted_axis [i_ary_idx ] = iunbroadcasted_axis
953
-
954
- return Map (i_ary_to_unbroadcasted_axis )
955
-
956
- return None
957
-
958
-
959
- class MultiAxisIndirectionsDecoupler (CopyMapper ):
960
- def map_contiguous_advanced_index (self ,
961
- expr : AdvancedIndexInContiguousAxes
962
- ) -> Array :
963
- i_ary_idx_to_unbroadcasted_axis = _get_unbroadcasted_axis_in_indirections (
964
- expr )
965
-
966
- if i_ary_idx_to_unbroadcasted_axis is not None :
967
- from pytato .utils import partition
968
- i_adv_indices , _ = partition (lambda idx : isinstance (expr .indices [idx ],
969
- NormalizedSlice ),
970
- range (len (expr .indices )))
971
-
972
- result = self .rec (expr .array )
973
-
974
- for iaxis , idx in enumerate (expr .indices ):
975
- if isinstance (idx , Array ):
976
- from pytato .array import squeeze
977
- axes_to_squeeze = [
978
- idim
979
- for idim in range (expr
980
- .indices [iaxis ] # type: ignore[union-attr]
981
- .ndim )
982
- if idim != i_ary_idx_to_unbroadcasted_axis [iaxis ]]
983
- if axes_to_squeeze :
984
- idx = squeeze (idx , axis = axes_to_squeeze )
985
- if not (isinstance (idx , NormalizedSlice )
986
- and _is_trivial_slice (expr .array .shape [iaxis ], idx )):
987
- result = result [
988
- (slice (None ),) * iaxis + (idx , )] # type: ignore[operator]
989
-
990
- return result
991
- else :
992
- return super ().map_contiguous_advanced_index (expr )
993
-
994
-
995
- def decouple_multi_axis_indirections_into_single_axis_indirections (
996
- expr : MappedT ) -> MappedT :
997
- """
998
- Returns a copy of *expr* with multiple indirections in an
999
- :class:`~pytato.array.AdvancedIndexInContiguousAxes` decoupled as a
1000
- composition of indexing nodes with single-axis indirections.
1001
-
1002
- .. note::
1003
-
1004
- This is a dependency preserving transformation. If a decoupling an
1005
- advanced indexing node is not legal, we leave the node unmodified.
1006
- """
1007
- mapper = MultiAxisIndirectionsDecoupler ()
1008
- return mapper (expr )
1009
-
1010
-
1011
909
# {{{ fold indirection constants
1012
910
1013
911
class _ConstantIndirectionArrayCollector (CombineMapper [FrozenSet [Array ]]):
0 commit comments