diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..d9a16d53e7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -143,10 +143,19 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] + res: list[Any] = [] + for item in data: + res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + # Only extend if we're at the leaf level (map_items_ == 1) and the transform + # actually returned a list (not preserving nested structure) + if isinstance(res_item, list) and map_items_ == 1: + if not isinstance(item, (list, tuple)): + res.extend(res_item) + else: + res.append(res_item) + else: + res.append(res_item) + return res return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint @@ -482,8 +491,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError( - f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" - " and allow_missing_keys==False." + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False." ) def first_key(self, data: dict[Hashable, Any]):