diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 236d3cc4c5..e15c4e2045 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -47,7 +47,7 @@
 def execute_compose(
     data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
     transforms: Sequence[Any],
-    map_items: bool = True,
+    map_items: bool | int = True,
     unpack_items: bool = False,
     start: int = 0,
     end: int | None = None,
@@ -65,8 +65,13 @@ def execute_compose(
     Args:
         data: a tensor-like object to be transformed
         transforms: a sequence of transforms to be carried out
-        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
-            defaults to `True`.
+        map_items: controls whether to apply the transform to each item in `data`. If `data` is a list or tuple,
+            it behaves as follows:
+            - Defaults to `True`, which is equivalent to `map_items=1`, meaning the transform will be applied
+              to the first level of items in `data`.
+            - If an integer is provided, it specifies the maximum level of nesting to which the transform is
+              applied recursively. This allows multi-sample transforms to be applied after other multi-sample
+              transforms while controlling how deep the mapping goes.
         unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
             defaults to `False`.
         start: the index of the first transform to be executed. If not set, this defaults to 0
@@ -205,8 +210,13 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
 
     Args:
         transforms: sequence of callables.
-        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
-            defaults to `True`.
+        map_items: controls whether to apply the transform to each item in `data`. If `data` is a list or tuple,
+            it behaves as follows:
+            - Defaults to `True`, which is equivalent to `map_items=1`, meaning the transform will be applied
+              to the first level of items in `data`.
+            - If an integer is provided, it specifies the maximum level of nesting to which the transform is
+              applied recursively. This allows multi-sample transforms to be applied after other multi-sample
+              transforms while controlling how deep the mapping goes.
         unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
             defaults to `False`.
         log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -227,7 +237,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
     def __init__(
         self,
         transforms: Sequence[Callable] | Callable | None = None,
-        map_items: bool = True,
+        map_items: bool | int = True,
         unpack_items: bool = False,
         log_stats: bool | str = False,
         lazy: bool | None = False,
@@ -238,9 +248,9 @@ def __init__(
         if transforms is None:
            transforms = []
 
-        if not isinstance(map_items, bool):
+        if not isinstance(map_items, (bool, int)):
             raise ValueError(
-                f"Argument 'map_items' should be boolean. Got {type(map_items)}."
+                f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
                 "Check brackets when passing a sequence of callables."
             )
 
@@ -391,8 +401,13 @@ class OneOf(Compose):
         transforms: sequence of callables.
         weights: probabilities corresponding to each callable in transforms.
             Probabilities are normalized to sum to one.
-        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
-            defaults to `True`.
+        map_items: controls whether to apply the transform to each item in `data`. If `data` is a list or tuple,
+            it behaves as follows:
+            - Defaults to `True`, which is equivalent to `map_items=1`, meaning the transform will be applied
+              to the first level of items in `data`.
+            - If an integer is provided, it specifies the maximum level of nesting to which the transform is
+              applied recursively. This allows multi-sample transforms to be applied after other multi-sample
+              transforms while controlling how deep the mapping goes.
         unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
             defaults to `False`.
         log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -414,7 +429,7 @@ def __init__(
         self,
         transforms: Sequence[Callable] | Callable | None = None,
         weights: Sequence[float] | float | None = None,
-        map_items: bool = True,
+        map_items: bool | int = True,
         unpack_items: bool = False,
         log_stats: bool | str = False,
         lazy: bool | None = False,
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 15c2499a73..1a365b8d8e 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -101,12 +101,12 @@ def _apply_transform(
 def apply_transform(
     transform: Callable[..., ReturnType],
     data: Any,
-    map_items: bool = True,
+    map_items: bool | int = True,
     unpack_items: bool = False,
     log_stats: bool | str = False,
     lazy: bool | None = None,
     overrides: dict | None = None,
-) -> list[ReturnType] | ReturnType:
+) -> list[Any] | ReturnType:
     """
     Transform `data` with `transform`.
 
@@ -117,8 +117,13 @@ def apply_transform(
     Args:
         transform: a callable to be used to transform `data`.
         data: an object to be transformed.
-        map_items: whether to apply transform to each item in `data`,
-            if `data` is a list or tuple. Defaults to True.
+        map_items: controls whether to apply the transform to each item in `data`. If `data` is a list or tuple,
+            it behaves as follows:
+            - Defaults to `True`, which is equivalent to `map_items=1`, meaning the transform will be applied
+              to the first level of items in `data`.
+            - If an integer is provided, it specifies the maximum level of nesting to which the transform is
+              applied recursively. This allows multi-sample transforms to be applied after other multi-sample
+              transforms while controlling how deep the mapping goes.
         unpack_items: whether to unpack parameters using `*`. Defaults to False.
         log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
             disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
@@ -136,8 +141,12 @@ def apply_transform(
         Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
     """
     try:
-        if isinstance(data, (list, tuple)) and map_items:
-            return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
+        map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
+        if isinstance(data, (list, tuple)) and map_items_ > 0:
+            return [
+                apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
+                for item in data
+            ]
         return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
     except Exception as e:
         # if in debug mode, don't swallow exception so that the breakpoint
diff --git a/tests/test_compose.py b/tests/test_compose.py
index 3c53ac4a22..e6727c976f 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -141,6 +141,20 @@ def b(i, i2):
         self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
         self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)
 
+    def test_list_non_dict_compose_with_unpack_map_2(self):
+
+        def a(i, i2):
+            return i + "a", i2 + "a2"
+
+        def b(i, i2):
+            return i + "b", i2 + "b2"
+
+        transforms = [a, b, a, b]
+        data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
+        expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
+        self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
+        self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)
+
     def test_list_dict_compose_no_map(self):
         def a(d):  # transform to handle dict data