Skip to content

added list extend to MultiSampleTrait #8531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,19 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
res: list[Any] = []
for item in data:
res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
# Only extend if we're at the leaf level (map_items_ == 1) and the transform
# actually returned a list (not preserving nested structure)
if isinstance(res_item, list) and map_items_ == 1:
if not isinstance(item, (list, tuple)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Within the containing block res_item will always be a list (line 151) and so this if statement's condition is always false. I'm not sure the logic is what you're expecting it to be here?

res.extend(res_item)
else:
res.append(res_item)
else:
res.append(res_item)
return res
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down Expand Up @@ -482,8 +491,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable
yield (key,) + tuple(_ex_iters) if extra_iterables else key
elif not self.allow_missing_keys:
raise KeyError(
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
" and allow_missing_keys==False."
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False."
)

def first_key(self, data: dict[Hashable, Any]):
Expand Down
Loading