Enable batch inference for image slices #1298

Open
tomathosauce wants to merge 1 commit into obss:main from tomathosauce:batch_implementation1

@tomathosauce

This PR is my attempt at implementing batch inference over all slices of an image. I decided to work on this because I tried the existing pull requests addressing this feature, but unfortunately they did not work in my case.

My current implementation is still somewhat hacky and is limited to Ultralytics models, but I have included tests demonstrating that it works as intended. I plan to keep refining it as time permits, and I hope others in the community find it useful.

@golden452

golden452 commented Mar 2, 2026

Thanks a lot for your work!

I can't get it working. Could you help look into this issue?

Failed tests:

=========================== short test summary info ============================
FAILED tests/test_predict.py::test_get_prediction_automodel_yolo11 - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_prediction_category_remapping - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_get_prediction_yolo11 - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_ultralytics_yolo11n_prediction - ZeroDivisionError: float division by zero
FAILED tests/test_predict.py::test_video_prediction - ZeroDivisionError: float division by zero
FAILED tests/test_huggingface_model.py::test_get_prediction_huggingface - assert 0 == 10
 +  where 0 = len([])
FAILED tests/test_huggingface_model.py::test_get_prediction_automodel_huggingface - assert 0 == 10
 +  where 0 = len([])
FAILED tests/test_huggingface_model.py::test_get_sliced_prediction_huggingface - NotImplementedError
FAILED tests/test_torchvision.py::TestTorchVisionDetectionModel::test_get_prediction_torchvision - assert 0 == 7
 +  where 0 = len([])
FAILED tests/test_torchvision.py::TestTorchVisionDetectionModel::test_get_sliced_prediction_torchvision - NotImplementedError
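
The `ValueError` in the summary above reports 1 prediction against a `shift_amount_list` of length 2. One plausible reading, not confirmed anywhere in this thread, is that a single `[x, y]` shift such as `[0, 0]` is being counted as two separate shift entries. A minimal sketch of a normalization step that would avoid this is shown below. The helper name `normalize_shift_amounts` and its signature are hypothetical and are not part of SAHI or of this PR.

```python
from numbers import Integral
from typing import List, Optional, Sequence


def normalize_shift_amounts(
    shift_amount: Optional[Sequence],
    num_images: int,
) -> List[List[int]]:
    """Return one [x, y] shift per image.

    Hypothetical helper for illustration only; not part of SAHI.
    """
    if shift_amount is None:
        # Assumption: a missing shift means "no shift" for every image.
        return [[0, 0] for _ in range(num_images)]

    # A flat pair such as [0, 0] is a single shift, not two shifts.
    is_flat_pair = len(shift_amount) == 2 and all(
        isinstance(value, Integral) for value in shift_amount
    )
    if is_flat_pair:
        shifts = [list(shift_amount)]
    else:
        shifts = [list(pair) for pair in shift_amount]

    if len(shifts) != num_images:
        raise ValueError(
            f"Number of predictions ({num_images}) and "
            f"shift_amount_list ({len(shifts)}) do not match."
        )
    return shifts
```

With this sketch, a non-sliced call that passes one image and `[0, 0]` yields `[[0, 0]]`, while a sliced call that passes one pair per slice is returned unchanged.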

Errors I got (with both the yolo11 and yolo26 models):

---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
Cell In[6], line 3
      1 for image_path in images_to_process:
      2     # get batch sliced prediction
----> 3     prediction_result1 = get_sliced_prediction(
      4         image=image_path,
      5         detection_model=yolo26_detection_model,
      6         slice_height=SLICE_SIZE,
      7         slice_width=SLICE_SIZE,
      8         overlap_height_ratio=OVERLAP,
      9         overlap_width_ratio=OVERLAP,
     10         perform_standard_pred=False,
     11         postprocess_type='GREEDYNMM',
     12         postprocess_match_threshold='0.1',
     13         postprocess_match_metric='IOS',
     14         postprocess_class_agnostic=True,
     15         num_batch=4
     16     )

File ~/Desktop/tomato_sahi/sahi_batch/sahi/predict.py:340, in get_sliced_prediction(image, detection_model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio, perform_standard_pred, postprocess_type, postprocess_match_metric, postprocess_match_threshold, postprocess_class_agnostic, verbose, merge_buffer_length, auto_slice_resolution, slice_export_prefix, slice_dir, exclude_classes_by_name, exclude_classes_by_id, progress_bar, progress_callback, num_batch)
    337     shift_amount_list.append(slice_image_result.starting_pixels[idx])
    339 # perform batch prediction
--> 340 prediction_result = get_prediction(
    341     image=image_list,
    342     detection_model=detection_model,
    343     shift_amount=shift_amount_list,
    344     full_shape=[
    345         slice_image_result.original_image_height,
    346         slice_image_result.original_image_width,
    347     ],
    348     exclude_classes_by_name=exclude_classes_by_name,
    349     exclude_classes_by_id=exclude_classes_by_id,
    350 )
    352 if isinstance(prediction_result, list):
    353     for prediction in prediction_result:

File ~/Desktop/tomato_sahi/sahi_batch/sahi/predict.py:152, in get_prediction(image, detection_model, shift_amount, full_shape, postprocess, verbose, exclude_classes_by_name, exclude_classes_by_id)
    149     object_prediction_dict[shift_amount_index].append(obj_preds)
    151 time_end = time.perf_counter() - time_start
--> 152 durations_in_seconds["postprocess"] = time_end / len(object_prediction_dict)
    154 if verbose == 1:
    154 if verbose == 1:
    155     print(
    156         "Prediction performed in",
    157         durations_in_seconds["prediction"],
    158         "seconds.",
    159     )

ZeroDivisionError: float division by zero
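
The traceback above fails at `time_end / len(object_prediction_dict)`, which raises `ZeroDivisionError` whenever `object_prediction_dict` is empty. A minimal sketch of a guarded average follows. The helper `average_duration` is hypothetical and not taken from the PR, and it only removes the crash: an empty dictionary probably means no predictions were grouped at all, which would still need to be investigated separately.

```python
import time


def average_duration(total_seconds: float, num_items: int) -> float:
    """Return the mean duration per item, or 0.0 when there are no items.

    Hypothetical helper for illustration only; not part of SAHI.
    """
    if num_items == 0:
        return 0.0
    return total_seconds / num_items


# Usage mirroring the names seen in the traceback.
time_start = time.perf_counter()
object_prediction_dict = {}  # empty, as in the failing run
durations_in_seconds = {}
time_end = time.perf_counter() - time_start
durations_in_seconds["postprocess"] = average_duration(
    time_end, len(object_prediction_dict)
)
```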


2 participants