diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 95653ffbd4..6c3fc104c9 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -37,6 +37,31 @@ __all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"] +def _inverse_one( + t: InvertibleTransform, data: Any, map_items: bool | int, unpack_items: bool, log_stats: bool | str +) -> Any: + """Invert a single transform, delegating directly to nested ``Compose`` objects. + + When ``t`` is a ``Compose`` instance its own ``inverse()`` is called so that + the child's ``map_items`` setting is respected. For all other invertible + transforms, ``apply_transform`` is used with ``lazy=False``. + + Args: + t: The invertible transform to invert. + data: Data to be inverted. + map_items: Whether to map over list/tuple items (forwarded to + ``apply_transform`` for non-``Compose`` transforms). + unpack_items: Whether to unpack data as parameters. + log_stats: Logger name or boolean for logging. + + Returns: + The inverted data. + """ + if isinstance(t, Compose): + return t.inverse(data) + return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats) + + def execute_compose( data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], transforms: Sequence[Any], @@ -315,7 +340,12 @@ def get_index_of_first(self, predicate): return None def flatten(self): - """Return a Composition with a simple list of transforms, as opposed to any nested Compositions. + """Return a Composition with a flattened list of transforms. + + Nested ``Compose`` objects that share the same ``map_items`` setting as + the parent are inlined. Nested ``Compose`` objects with a *different* + ``map_items`` value are kept as-is so their item-mapping behaviour is + preserved at runtime and during inversion. e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()` will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. @@ -323,12 +353,19 @@ def flatten(self): """ new_transforms = [] for t in self.transforms: - if type(t) is Compose: # nopep8 + if type(t) is Compose and t.map_items == self.map_items: new_transforms += t.flatten().transforms else: new_transforms.append(t) - return Compose(new_transforms) + return Compose( + new_transforms, + map_items=self.map_items, + unpack_items=self.unpack_items, + log_stats=self.log_stats, + lazy=self._lazy, + overrides=self.overrides, + ) def __len__(self): """Return number of transformations.""" @@ -365,9 +402,7 @@ def inverse(self, data): ) # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform( - t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats - ) + data = _inverse_one(t, data, self.map_items, self.unpack_items, self.log_stats) return data @staticmethod @@ -622,9 +657,7 @@ def inverse(self, data): # loop backwards over transforms for o in reversed(applied_order): if isinstance(self.transforms[o], InvertibleTransform): - data = apply_transform( - self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats - ) + data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats) return data @@ -789,8 +822,6 @@ def inverse(self, data): # loop backwards over transforms for o in reversed(applied_order): if isinstance(self.transforms[o], InvertibleTransform): - data = apply_transform( - self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats - ) + data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats) return data diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1eedc7c333..40f95d47d6 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -143,10 +143,13 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait): - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] + # If the transform is a Compose with its own map_items, let it handle list/tuple + # expansion internally so that nested Compose map_items settings are respected. + if not isinstance(transform, transforms.compose.Compose): + return [ + apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + for item in data + ] return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index 96c6d4606f..0727fb5633 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -775,6 +775,165 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline): self.assertEqual(expected, actual) +class TestNestedComposeMapItems(unittest.TestCase): + """Tests for nested Compose respecting child map_items (issues #7932, #7565).""" + + def test_child_map_items_false_receives_list(self): + """Parent map_items=True, child map_items=False: child receives list as-is.""" + + def split(x): + return [x + 1, x + 2] + + def sum_list(items): + return sum(items) + + # The child Compose(map_items=False) should receive the list from split() + # and pass it as-is to sum_list, rather than the parent expanding the list. + pipeline = mt.Compose([split, mt.Compose([sum_list], map_items=False)]) + result = pipeline(10) + self.assertEqual(result, 23) # (10+1) + (10+2) = 23 + + def test_inverse_respects_child_map_items(self): + """Inverse path should delegate to child Compose.inverse directly.""" + pipeline = mt.Compose([mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False)]) + data = torch.randn(1, 4, 4) + result = pipeline(data) + restored = pipeline.inverse(result) + torch.testing.assert_close(data, restored) + + def test_parent_no_map_child_map(self): + """Parent map_items=False, child map_items=True: child maps over items.""" + + def double(x): + return x * 2 + + # Parent treats the list as a single value; child maps double() over each item. + pipeline = mt.Compose([mt.Compose([double], map_items=True)], map_items=False) + result = pipeline([1, 2, 3]) + self.assertEqual(result, [2, 4, 6]) + + def test_flatten_preserves_different_map_items(self): + """flatten() should not merge a child Compose with different map_items.""" + + def noop(x): + return x + + parent = mt.Compose([noop, mt.Compose([noop, noop], map_items=False), noop]) + flat = parent.flatten() + # The inner Compose(map_items=False) should NOT be flattened + self.assertEqual(len(flat.transforms), 3) + self.assertIsInstance(flat.transforms[1], mt.Compose) + + def test_multiple_children_with_mixed_map_items(self): + """Multiple internal Composes with different map_items should be handled correctly.""" + + def add_one(items): + if isinstance(items, list): + return [x + 1 for x in items] + return items + 1 + + def multiply_two(items): + if isinstance(items, list): + return [x * 2 for x in items] + return items * 2 + + # Parent with map_items=False processes the entire input as one unit + # Child 1 (map_items=True) will map over each item in what it receives + # Child 2 (map_items=False) will process the entire thing + pipeline = mt.Compose( + [mt.Compose([add_one], map_items=True), mt.Compose([multiply_two], map_items=False)], map_items=False + ) + + # Input [1, 2, 3] + # First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4] + # Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8] + result = pipeline([1, 2, 3]) + self.assertEqual(result, [4, 6, 8]) + + def test_flatten_with_multiple_children_preserves_both(self): + """flatten() should preserve child with different map_items but flatten child with same.""" + + def noop(x): + return x + + parent = mt.Compose( + [ + noop, + mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened + mt.Compose([noop, noop], map_items=False), # Different, will be preserved + noop, + ] + ) + flat = parent.flatten() + # First nested Compose(map_items=True) will be flattened into parent + # Second nested Compose(map_items=False) will be preserved + # Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms + self.assertEqual(len(flat.transforms), 5) + # Check that the preserved one is at the correct position + self.assertIsInstance(flat.transforms[3], mt.Compose) + self.assertEqual(flat.transforms[3].map_items, False) + + def test_three_level_nesting_respects_different_map_items(self): + """Three-level nesting with different map_items at each level.""" + + def add_one(x): + return x + 1 + + # Level 1 (outermost): map_items=True (default) + # Level 2: map_items=False + # Level 3: map_items=True (same as level 2, so will be flattened into level 2) + innermost = mt.Compose([add_one], map_items=True) + middle = mt.Compose([add_one, innermost], map_items=False) + outer = mt.Compose([middle]) + + # Test with a simple value + # outer has map_items=True (default), middle has map_items=False + # So middle should be preserved and receive the input as-is + result = outer(5) + # outer(5) -> maps to middle -> middle(5) with map_items=False + # middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True + # innermost(6) -> add_one(6) = 7 + self.assertEqual(result, 7) + + def test_inverse_with_multiple_children_different_map_items(self): + """Inverse should work correctly with multiple children having different map_items.""" + pipeline = mt.Compose( + [mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False), mt.Compose([mt.Flip(0)], map_items=True)] + ) + data = torch.randn(2, 4, 4) + result = pipeline(data) + restored = pipeline.inverse(result) + torch.testing.assert_close(data, restored) + + def test_flatten_with_mixed_same_and_different_map_items(self): + """flatten() should merge children with same map_items but preserve those with different.""" + + def noop(x): + return x + + # Parent has map_items=True (default) + # Child 1 has map_items=True (same as parent) -> should be flattened + # Child 2 has map_items=False (different from parent) -> should NOT be flattened + parent = mt.Compose( + [ + noop, + mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened + mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved + noop, + ] + ) + flat = parent.flatten() + # After flatten: + # - noop (preserved) + # - 2 noops from first Compose (flattened because map_items=True matches parent) + # - Compose([noop, noop], map_items=False) (preserved because different) + # - noop (preserved) + # Total: 5 transforms + self.assertEqual(len(flat.transforms), 5) + self.assertIsInstance(flat.transforms[3], mt.Compose) + self.assertEqual(flat.transforms[3].map_items, False) + + class TestComposeCallableInput(unittest.TestCase): def test_value_error_when_not_sequence(self):