Skip to content

Commit 366df44

Browse files
authored
Fix using a non-tuple sequence for multidimensional indexing that will result in an error in PyTorch 2.9 (#8732)
Fix using a non-tuple sequence for multidimensional indexing that will result in an error in PyTorch 2.9. ### Description The `SlidingWindowInferer` complains that from PyTorch 2.9, using non-tuple sequences to slice into Tensors will result in an error (or, worse, misbehave). This patch targets all places in `monai/inferers/utils.py` where indexing may happen with lists or other types of sequences. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Aaron Ponti <aaron@aaronponti.ch>
1 parent 9ddd5e6 commit 366df44

File tree

2 files changed

+108
-8
lines changed

2 files changed

+108
-8
lines changed

monai/inferers/utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,19 @@ def sliding_window_inference(
243243
for idx in slice_range
244244
]
245245
if sw_batch_size > 1:
246-
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
246+
win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device)
247247
if condition is not None:
248-
win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
248+
win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(
249+
sw_device
250+
)
249251
kwargs["condition"] = win_condition
250252
else:
251-
win_data = inputs[unravel_slice[0]].to(sw_device)
253+
s0 = unravel_slice[0]
254+
s0_idx = ensure_tuple(s0)
255+
256+
win_data = inputs[s0_idx].to(sw_device)
252257
if condition is not None:
253-
win_condition = condition[unravel_slice[0]].to(sw_device)
258+
win_condition = condition[s0_idx].to(sw_device)
254259
kwargs["condition"] = win_condition
255260

256261
if with_coord:
@@ -277,7 +282,7 @@ def sliding_window_inference(
277282
offset = s[buffer_dim + 2].start - c_start
278283
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
279284
s[0] = slice(0, 1)
280-
sw_device_buffer[0][s] += p * w_t
285+
sw_device_buffer[0][ensure_tuple(s)] += p * w_t
281286
b_i += len(unravel_slice)
282287
if b_i < b_slices[b_s][0]:
283288
continue
@@ -308,10 +313,11 @@ def sliding_window_inference(
308313
o_slice[buffer_dim + 2] = slice(c_start, c_end)
309314
img_b = b_s // n_per_batch # image batch index
310315
o_slice[0] = slice(img_b, img_b + 1)
316+
o_slice_idx = ensure_tuple(o_slice)
311317
if non_blocking:
312-
output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
318+
output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)
313319
else:
314-
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
320+
output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device)
315321
else:
316322
sw_device_buffer[ss] *= w_t
317323
sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
@@ -387,7 +393,7 @@ def _compute_coords(coords, z_scale, out, patch):
387393
idx_zm[axis] = slice(
388394
int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])
389395
)
390-
out[idx_zm] += p
396+
out[ensure_tuple(idx_zm)] += p
391397

392398

393399
def _get_scan_interval(

tests/inferers/test_sliding_window_inference.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from monai.data.utils import list_data_collate
2222
from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference
23+
from monai.inferers.utils import _compute_coords
2324
from monai.utils import optional_import
2425
from tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick
2526

@@ -704,6 +705,99 @@ def compute_dict(data, condition):
704705
for rr, _ in zip(result_dict, expected_dict):
705706
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
706707

708+
@parameterized.expand([(1,), (4,)])
709+
def test_conditioned_branches_and_buffered_parity(self, sw_batch_size):
710+
"""Validate conditioned parity between buffered and non-buffered flows.
711+
712+
Args:
713+
sw_batch_size (int): Sliding-window batch size.
714+
715+
Returns:
716+
None.
717+
718+
Raises:
719+
AssertionError: If device, conditioning alignment, or output parity checks fail.
720+
"""
721+
inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8)
722+
condition = inputs + 100.0
723+
roi_shape = (4, 4)
724+
725+
def compute(data, condition):
726+
"""Compute output for a conditioned patch.
727+
728+
Args:
729+
data (torch.Tensor): Input patch tensor.
730+
condition (torch.Tensor): Conditioning patch tensor aligned to ``data``.
731+
732+
Returns:
733+
torch.Tensor: Element-wise ``data + condition``.
734+
735+
Raises:
736+
AssertionError: If device placement or conditioning alignment checks fail.
737+
"""
738+
self.assertEqual(data.device.type, "cpu")
739+
self.assertEqual(condition.device.type, "cpu")
740+
torch.testing.assert_close(condition - data, torch.full_like(data, 100.0))
741+
return data + condition
742+
743+
# Non-buffered flow.
744+
result_non_buffered = sliding_window_inference(
745+
inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode="constant", condition=condition
746+
)
747+
# Buffered flow; should match the non-buffered output.
748+
result_buffered = sliding_window_inference(
749+
inputs,
750+
roi_shape,
751+
sw_batch_size,
752+
compute,
753+
overlap=0.5,
754+
mode="constant",
755+
condition=condition,
756+
buffer_steps=2,
757+
buffer_dim=0,
758+
)
759+
760+
expected = inputs + condition
761+
torch.testing.assert_close(result_non_buffered, expected)
762+
torch.testing.assert_close(result_buffered, expected)
763+
torch.testing.assert_close(result_buffered, result_non_buffered)
764+
765+
766+
class TestSlidingWindowUtils(unittest.TestCase):
767+
"""Tests for low-level sliding-window utility helpers.
768+
769+
Args:
770+
None.
771+
772+
Returns:
773+
None.
774+
775+
Raises:
776+
None.
777+
"""
778+
779+
def test_compute_coords_accepts_list_indices(self):
780+
"""Ensure ``_compute_coords`` handles list-based index containers.
781+
782+
Args:
783+
None.
784+
785+
Returns:
786+
None.
787+
788+
Raises:
789+
AssertionError: If computed output placement differs from expected placement.
790+
"""
791+
out = torch.zeros((1, 1, 12, 12), dtype=torch.float)
792+
patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4)
793+
coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]]
794+
795+
_compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch)
796+
797+
expected = torch.zeros_like(out)
798+
expected[0, 0, 2:6, 4:8] = patch[0, 0]
799+
torch.testing.assert_close(out, expected)
800+
707801

708802
if __name__ == "__main__":
709803
unittest.main()

0 commit comments

Comments (0)