Skip to content

Commit d71dc04

Browse files
committed
[no ci] WIP: some progress on generalizing pushing indirections
1 parent 92125af commit d71dc04

File tree

1 file changed

+35
-137
lines changed

1 file changed

+35
-137
lines changed

pytato/transform/indirections.py

Lines changed: 35 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2020
THE SOFTWARE.
2121
"""
22+
23+
import sys
2224
from typing import (Any, Dict, Mapping, Tuple, TypeAlias, Iterable,
2325
FrozenSet, Union, Set, List, Optional, Callable)
2426
from pytato.array import (Array, InputArgumentBase, DictOfNamedArrays,
@@ -38,9 +40,13 @@
3840
from immutables import Map
3941
from pytato.utils import are_shape_components_equal
4042

43+
if sys.version >= (3, 11):
44+
zip_equal = lambda *_args: zip(*_args, strict=True)
45+
else:
46+
from more_itertools import zip_equal
47+
4148
_ComposedIndirectionT: TypeAlias = Tuple[Array, ...]
4249
IndexT: TypeAlias = Union[Array, NormalizedSlice]
43-
IndexStackT: TypeAlias = Tuple[IndexT, ...]
4450

4551

4652
def _is_materialized(expr: Array) -> bool:
@@ -53,15 +59,15 @@ def _is_materialized(expr: Array) -> bool:
5359
or bool(expr.tags_of_type(ImplStored)))
5460

5561

56-
def _is_trivial_slice(dim: ShapeComponent, slice_: NormalizedSlice) -> bool:
62+
def _is_trivial_slice(dim: ShapeComponent, slice_: IndexT) -> bool:
5763
"""
5864
Returns *True* only if *slice_* indexes an entire axis of shape *dim* with
5965
a step of 1.
6066
"""
61-
return (slice_.step == 1
67+
return (isinstance(slice_, NormalizedSlice)
68+
and slice_.step == 1
6269
and are_shape_components_equal(slice_.start, 0)
63-
and are_shape_components_equal(slice_.stop, dim)
64-
)
70+
and are_shape_components_equal(slice_.stop, dim))
6571

6672

6773
def _take_along_axis(ary: Array, iaxis: int, idxs: IndexStackT) -> Array:
@@ -427,35 +433,35 @@ class _IndirectionPusher(Mapper):
427433

428434
def __init__(self) -> None:
429435
self.get_reordarable_axes = _LegallyAxisReorderingFinder()
430-
self._cache: Dict[Tuple[ArrayOrNames, Map[int, IndexStackT]],
436+
self._cache: Dict[Tuple[ArrayOrNames, Map[int, IndexT]],
431437
ArrayOrNames] = {}
432438
super().__init__()
433439

434440
def rec(self, # type: ignore[override]
435441
expr: MappedT,
436-
index_stacks: Map[int, IndexStackT]) -> MappedT:
437-
key = (expr, index_stacks)
442+
indices: Tuple[IndexT, ...]) -> MappedT:
443+
assert len(indices) == expr.ndim
444+
key = (expr, indices)
438445
try:
439446
# type-ignore-reason: parametric mapping types aren't a thing in 'typing'
440447
return self._cache[key] # type: ignore[return-value]
441448
except KeyError:
442-
result = Mapper.rec(self, expr, index_stacks)
449+
result = Mapper.rec(self, expr, indices)
443450
self._cache[key] = result
444451
return result # type: ignore[no-any-return]
445452

446453
def __call__(self, # type: ignore[override]
447454
expr: MappedT,
448-
index_stacks: Map[int, IndexStackT]) -> MappedT:
449-
return self.rec(expr, index_stacks)
455+
indices: Map[int, IndexT]) -> MappedT:
456+
return self.rec(expr, indices)
450457

451458
def _map_materialized(self,
452459
expr: Array,
453-
index_stacks: Map[int, IndexStackT]) -> Array:
454-
result = expr
455-
for iaxis, idxs in index_stacks.items():
456-
result = _take_along_axis(result, iaxis, idxs)
457-
458-
return result
460+
indices: Tuple[IndexT, ...]) -> Array:
461+
if all(_is_trivial_slice(dim, idx)
462+
for dim, idx in zip(expr.shape, indices)):
463+
return expr
464+
return expr[*indices]
459465

460466
def map_dict_of_named_arrays(self,
461467
expr: DictOfNamedArrays,
@@ -467,9 +473,12 @@ def map_dict_of_named_arrays(self,
467473

468474
def map_index_lambda(self,
469475
expr: IndexLambda,
470-
index_stacks: Map[int, IndexStackT]
476+
indices: Tuple[IndexT, ...],
471477
) -> Array:
472478
if _is_materialized(expr):
479+
# FIXME: Move this logic to .rec (Why on earth do we need)
480+
# to copy the damn node???
481+
473482
# do not propagate the indexings to the bindings.
474483
expr = IndexLambda(expr.expr,
475484
expr.shape,
@@ -478,9 +487,13 @@ def map_index_lambda(self,
478487
for name, bnd in expr.bindings.items()}),
479488
expr.var_to_reduction_descr,
480489
tags=expr.tags,
481-
axes=expr.axes,
482-
)
483-
return self._map_materialized(expr, index_stacks)
490+
axes=expr.axes,)
491+
return self._map_materialized(expr, indices)
492+
493+
# FIXME:
494+
# This is the money shot. Over here we need to figure out the index
495+
# propagation logic.
496+
484497

485498
iout_axis_to_bnd_axis = _get_iout_axis_to_binding_axis(expr)
486499

@@ -886,128 +899,13 @@ def push_axis_indirections_towards_materialized_nodes(expr: MappedT
886899
) -> MappedT:
887900
"""
888901
Returns a copy of *expr* with the indirections propagated closer to the
889-
materialized nodes. We propagate an indirections only if the indirection in
890-
an :class:`~pytato.array.AdvancedIndexInContiguousAxes` or
891-
:class:`~pytato.array.AdvancedIndexInNoncontiguousAxes` is an indirection
892-
over a single axis.
902+
materialized nodes.
893903
"""
894904
mapper = _IndirectionPusher()
895905

896906
return mapper(expr, Map())
897907

898908

899-
def _get_unbroadcasted_axis_in_indirections(
900-
expr: AdvancedIndexInContiguousAxes) -> Optional[Mapping[int, int]]:
901-
"""
902-
Returns a mapping from the index of an indirection to its *only*
903-
unbroadcasted axis as required by the logic. Returns *None* if no such
904-
mapping exists.
905-
"""
906-
from pytato.utils import partition, get_shape_after_broadcasting
907-
adv_indices, _ = partition(lambda i: isinstance(expr.indices[i],
908-
NormalizedSlice),
909-
range(expr.array.ndim))
910-
i_ary_indices = [i_idx
911-
for i_idx, idx in enumerate(expr.indices)
912-
if isinstance(idx, Array)]
913-
914-
adv_idx_shape = get_shape_after_broadcasting([expr.indices[i_idx]
915-
for i_idx in adv_indices])
916-
917-
if len(adv_idx_shape) != len(i_ary_indices):
918-
return None
919-
920-
i_adv_out_axis_to_candidate_i_arys: Dict[int, Set[int]] = {
921-
idim: set()
922-
for idim, _ in enumerate(adv_idx_shape)
923-
}
924-
925-
for i_ary_idx in i_ary_indices:
926-
ary = expr.indices[i_ary_idx]
927-
assert isinstance(ary, Array)
928-
for iadv_out_axis, i_ary_axis in zip(range(len(adv_idx_shape)-1, -1, -1),
929-
range(ary.ndim-1, -1, -1)):
930-
if are_shape_components_equal(adv_idx_shape[iadv_out_axis],
931-
ary.shape[i_ary_axis]):
932-
i_adv_out_axis_to_candidate_i_arys[iadv_out_axis].add(i_ary_idx)
933-
934-
from itertools import permutations
935-
# FIXME: O(expr.ndim!) complexity, typically ndim <= 4 so this should be fine.
936-
for guess_i_adv_out_axis_to_i_ary in permutations(range(len(i_ary_indices))):
937-
if all(i_ary in i_adv_out_axis_to_candidate_i_arys[i_adv_out]
938-
for i_adv_out, i_ary in enumerate(guess_i_adv_out_axis_to_i_ary)):
939-
# TODO: Return the mapping here...
940-
i_ary_to_unbroadcasted_axis: Dict[int, int] = {}
941-
for guess_i_adv_out_axis, i_ary_idx in enumerate(
942-
guess_i_adv_out_axis_to_i_ary):
943-
ary = expr.indices[i_ary_idx]
944-
assert isinstance(ary, Array)
945-
iunbroadcasted_axis, = [
946-
i_ary_axis
947-
for i_adv_out_axis, i_ary_axis in zip(
948-
range(len(adv_idx_shape)-1, -1, -1),
949-
range(ary.ndim-1, -1, -1))
950-
if i_adv_out_axis == guess_i_adv_out_axis
951-
]
952-
i_ary_to_unbroadcasted_axis[i_ary_idx] = iunbroadcasted_axis
953-
954-
return Map(i_ary_to_unbroadcasted_axis)
955-
956-
return None
957-
958-
959-
class MultiAxisIndirectionsDecoupler(CopyMapper):
960-
def map_contiguous_advanced_index(self,
961-
expr: AdvancedIndexInContiguousAxes
962-
) -> Array:
963-
i_ary_idx_to_unbroadcasted_axis = _get_unbroadcasted_axis_in_indirections(
964-
expr)
965-
966-
if i_ary_idx_to_unbroadcasted_axis is not None:
967-
from pytato.utils import partition
968-
i_adv_indices, _ = partition(lambda idx: isinstance(expr.indices[idx],
969-
NormalizedSlice),
970-
range(len(expr.indices)))
971-
972-
result = self.rec(expr.array)
973-
974-
for iaxis, idx in enumerate(expr.indices):
975-
if isinstance(idx, Array):
976-
from pytato.array import squeeze
977-
axes_to_squeeze = [
978-
idim
979-
for idim in range(expr
980-
.indices[iaxis] # type: ignore[union-attr]
981-
.ndim)
982-
if idim != i_ary_idx_to_unbroadcasted_axis[iaxis]]
983-
if axes_to_squeeze:
984-
idx = squeeze(idx, axis=axes_to_squeeze)
985-
if not (isinstance(idx, NormalizedSlice)
986-
and _is_trivial_slice(expr.array.shape[iaxis], idx)):
987-
result = result[
988-
(slice(None),) * iaxis + (idx, )] # type: ignore[operator]
989-
990-
return result
991-
else:
992-
return super().map_contiguous_advanced_index(expr)
993-
994-
995-
def decouple_multi_axis_indirections_into_single_axis_indirections(
996-
expr: MappedT) -> MappedT:
997-
"""
998-
Returns a copy of *expr* with multiple indirections in an
999-
:class:`~pytato.array.AdvancedIndexInContiguousAxes` decoupled as a
1000-
composition of indexing nodes with single-axis indirections.
1001-
1002-
.. note::
1003-
1004-
This is a dependency preserving transformation. If a decoupling an
1005-
advanced indexing node is not legal, we leave the node unmodified.
1006-
"""
1007-
mapper = MultiAxisIndirectionsDecoupler()
1008-
return mapper(expr)
1009-
1010-
1011909
# {{{ fold indirection constants
1012910

1013911
class _ConstantIndirectionArrayCollector(CombineMapper[FrozenSet[Array]]):

0 commit comments

Comments
 (0)