Skip to content
Open
55 changes: 43 additions & 12 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@
__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]


def _inverse_one(
t: InvertibleTransform, data: Any, map_items: bool | int, unpack_items: bool, log_stats: bool | str
) -> Any:
"""Invert a single transform, delegating directly to nested ``Compose`` objects.

When ``t`` is a ``Compose`` instance its own ``inverse()`` is called so that
the child's ``map_items`` setting is respected. For all other invertible
transforms, ``apply_transform`` is used with ``lazy=False``.

Args:
t: The invertible transform to invert.
data: Data to be inverted.
map_items: Whether to map over list/tuple items (forwarded to
``apply_transform`` for non-``Compose`` transforms).
unpack_items: Whether to unpack data as parameters.
log_stats: Logger name or boolean for logging.

Returns:
The inverted data.
"""
if isinstance(t, Compose):
return t.inverse(data)
return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats)


def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
Expand Down Expand Up @@ -315,20 +340,32 @@ def get_index_of_first(self, predicate):
return None

def flatten(self):
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
"""Return a Composition with a flattened list of transforms.

Nested ``Compose`` objects that share the same ``map_items`` setting as
the parent are inlined. Nested ``Compose`` objects with a *different*
``map_items`` value are kept as-is so their item-mapping behaviour is
preserved at runtime and during inversion.

e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.

"""
new_transforms = []
for t in self.transforms:
if type(t) is Compose: # nopep8
if type(t) is Compose and t.map_items == self.map_items:
new_transforms += t.flatten().transforms
else:
new_transforms.append(t)

return Compose(new_transforms)
return Compose(
new_transforms,
map_items=self.map_items,
unpack_items=self.unpack_items,
log_stats=self.log_stats,
lazy=self._lazy,
overrides=self.overrides,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def __len__(self):
"""Return number of transformations."""
Expand Down Expand Up @@ -365,9 +402,7 @@ def inverse(self, data):
)
# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(
t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
)
data = _inverse_one(t, data, self.map_items, self.unpack_items, self.log_stats)
return data

@staticmethod
Expand Down Expand Up @@ -622,9 +657,7 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)
return data


Expand Down Expand Up @@ -789,8 +822,6 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)

return data
11 changes: 7 additions & 4 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,13 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
# If the transform is a Compose with its own map_items, let it handle list/tuple
# expansion internally so that nested Compose map_items settings are respected.
if not isinstance(transform, transforms.compose.Compose):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
159 changes: 159 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,165 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
self.assertEqual(expected, actual)


class TestNestedComposeMapItems(unittest.TestCase):
"""Tests for nested Compose respecting child map_items (issues #7932, #7565)."""

def test_child_map_items_false_receives_list(self):
"""Parent map_items=True, child map_items=False: child receives list as-is."""

def split(x):
return [x + 1, x + 2]

def sum_list(items):
return sum(items)

# The child Compose(map_items=False) should receive the list from split()
# and pass it as-is to sum_list, rather than the parent expanding the list.
pipeline = mt.Compose([split, mt.Compose([sum_list], map_items=False)])
result = pipeline(10)
self.assertEqual(result, 23) # (10+1) + (10+2) = 23

def test_inverse_respects_child_map_items(self):
"""Inverse path should delegate to child Compose.inverse directly."""
pipeline = mt.Compose([mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False)])
data = torch.randn(1, 4, 4)
result = pipeline(data)
restored = pipeline.inverse(result)
torch.testing.assert_close(data, restored)

def test_parent_no_map_child_map(self):
"""Parent map_items=False, child map_items=True: child maps over items."""

def double(x):
return x * 2

# Parent treats the list as a single value; child maps double() over each item.
pipeline = mt.Compose([mt.Compose([double], map_items=True)], map_items=False)
result = pipeline([1, 2, 3])
self.assertEqual(result, [2, 4, 6])

def test_flatten_preserves_different_map_items(self):
"""flatten() should not merge a child Compose with different map_items."""

def noop(x):
return x

parent = mt.Compose([noop, mt.Compose([noop, noop], map_items=False), noop])
flat = parent.flatten()
# The inner Compose(map_items=False) should NOT be flattened
self.assertEqual(len(flat.transforms), 3)
self.assertIsInstance(flat.transforms[1], mt.Compose)

def test_multiple_children_with_mixed_map_items(self):
"""Multiple internal Composes with different map_items should be handled correctly."""

def add_one(items):
if isinstance(items, list):
return [x + 1 for x in items]
return items + 1

def multiply_two(items):
if isinstance(items, list):
return [x * 2 for x in items]
return items * 2

# Parent with map_items=False processes the entire input as one unit
# Child 1 (map_items=True) will map over each item in what it receives
# Child 2 (map_items=False) will process the entire thing
pipeline = mt.Compose(
[mt.Compose([add_one], map_items=True), mt.Compose([multiply_two], map_items=False)], map_items=False
)

# Input [1, 2, 3]
# First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
# Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
result = pipeline([1, 2, 3])
self.assertEqual(result, [4, 6, 8])

def test_flatten_with_multiple_children_preserves_both(self):
"""flatten() should preserve child with different map_items but flatten child with same."""

def noop(x):
return x

parent = mt.Compose(
[
noop,
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
mt.Compose([noop, noop], map_items=False), # Different, will be preserved
noop,
]
)
flat = parent.flatten()
# First nested Compose(map_items=True) will be flattened into parent
# Second nested Compose(map_items=False) will be preserved
# Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
self.assertEqual(len(flat.transforms), 5)
# Check that the preserved one is at the correct position
self.assertIsInstance(flat.transforms[3], mt.Compose)
self.assertEqual(flat.transforms[3].map_items, False)

def test_three_level_nesting_respects_different_map_items(self):
"""Three-level nesting with different map_items at each level."""

def add_one(x):
return x + 1

# Level 1 (outermost): map_items=True (default)
# Level 2: map_items=False
# Level 3: map_items=True (same as level 2, so will be flattened into level 2)
innermost = mt.Compose([add_one], map_items=True)
middle = mt.Compose([add_one, innermost], map_items=False)
outer = mt.Compose([middle])

# Test with a simple value
# outer has map_items=True (default), middle has map_items=False
# So middle should be preserved and receive the input as-is
result = outer(5)
# outer(5) -> maps to middle -> middle(5) with map_items=False
# middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
# innermost(6) -> add_one(6) = 7
self.assertEqual(result, 7)

def test_inverse_with_multiple_children_different_map_items(self):
"""Inverse should work correctly with multiple children having different map_items."""
pipeline = mt.Compose(
[mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False), mt.Compose([mt.Flip(0)], map_items=True)]
)
data = torch.randn(2, 4, 4)
result = pipeline(data)
restored = pipeline.inverse(result)
torch.testing.assert_close(data, restored)

def test_flatten_with_mixed_same_and_different_map_items(self):
"""flatten() should merge children with same map_items but preserve those with different."""

def noop(x):
return x

# Parent has map_items=True (default)
# Child 1 has map_items=True (same as parent) -> should be flattened
# Child 2 has map_items=False (different from parent) -> should NOT be flattened
parent = mt.Compose(
[
noop,
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved
noop,
]
)
flat = parent.flatten()
# After flatten:
# - noop (preserved)
# - 2 noops from first Compose (flattened because map_items=True matches parent)
# - Compose([noop, noop], map_items=False) (preserved because different)
# - noop (preserved)
# Total: 5 transforms
self.assertEqual(len(flat.transforms), 5)
self.assertIsInstance(flat.transforms[3], mt.Compose)
self.assertEqual(flat.transforms[3].map_items, False)


class TestComposeCallableInput(unittest.TestCase):

def test_value_error_when_not_sequence(self):
Expand Down
Loading