Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions docling_ibm_models/layoutmodel/layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
from collections.abc import Iterable
from typing import Set, Union
from typing import Set, Union, List

import numpy as np
import torch
Expand Down Expand Up @@ -133,8 +133,8 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
page_img = Image.fromarray(orig_img).convert("RGB")
else:
raise TypeError("Not supported input image format")

resize = {"height": self._image_size, "width": self._image_size}

inputs = self._image_processor(
images=page_img,
return_tensors="pt",
Expand Down Expand Up @@ -175,3 +175,90 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
"label": label_str,
"confidence": score,
}


@torch.inference_mode()
def predict_batch(
    self, orig_img: List[Union[Image.Image, np.ndarray]]
) -> Iterable[Iterable[dict]]:
    """
    Predict bounding boxes for a batch of page images.
    The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
    [left, top, right, bottom]

    Parameter
    ---------
    orig_img: List of images to be predicted, each one a PIL Image object or a numpy array.
        PIL images and numpy arrays may be mixed within one batch. An empty batch yields
        nothing.

    Yield
    -----
    One iterable per page (in input order) of bounding boxes, each a dict with the keys:
    "label", "confidence", "l", "t", "r", "b"

    Raises
    ------
    TypeError when any input image is not supported. Because this method is a generator,
    the error surfaces when iteration starts, not when the method is called.
    """
    # Convert image format; validate every element, not just the first one,
    # so an unsupported item fails with the documented TypeError.
    page_imgs = []
    for img in orig_img:
        if isinstance(img, Image.Image):
            page_imgs.append(img.convert("RGB"))
        elif isinstance(img, np.ndarray):
            page_imgs.append(Image.fromarray(img).convert("RGB"))
        else:
            raise TypeError("Not supported input image format")

    # Nothing to predict for an empty batch; skip the processor and model calls
    if not page_imgs:
        return

    resize = {"height": self._image_size, "width": self._image_size}
    inputs = self._image_processor(
        images=page_imgs,
        return_tensors="pt",
        size=resize,
    ).to(self._device)

    # PIL reports size as (width, height); reverse it to (height, width)
    target_sizes = torch.tensor([img.size[::-1] for img in page_imgs])

    outputs = self._model(**inputs)

    results = self._image_processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=self._threshold,
    )

    # Pair every per-page result with the page it was computed from
    for img, result in zip(page_imgs, results):
        w, h = img.size
        yield self.postprocess_result(result, w, h)

def postprocess_result(self, result: dict, w: int, h: int) -> Iterable[dict]:
    """
    Postprocess the result of the layout prediction.

    Parameters
    ----------
    result: The result of the layout prediction for one image, with the keys
        "scores", "labels" and "boxes". Each value is assumed to be a tensor
        (or array) exposing ``tolist()`` - TODO confirm against the image processor.
    w: The width of the image.
    h: The height of the image.

    Yields
    ------
    Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b".
    The coordinates are floats clamped to the image area: "l"/"r" to [0, w] and
    "t"/"b" to [0, h]. Detections whose label is in the blacklist are skipped.
    """
    # Convert each tensor to plain Python values once, instead of calling
    # `.item()` on every individual scalar inside the loop.
    scores = result["scores"].tolist()
    label_ids = result["labels"].tolist()
    boxes = result["boxes"].tolist()

    # Use float bounds so clamped and unclamped coordinates share one type
    max_x, max_y = float(w), float(h)

    for score, label_id, box in zip(scores, label_ids, boxes):
        # Advance the label_id: the class map is looked up one past the model's id
        label_str = self._classes_map[int(label_id) + 1]

        # Filter out blacklisted classes
        if label_str in self._black_classes:
            continue

        yield {
            "l": min(max_x, max(0.0, box[0])),
            "t": min(max_y, max(0.0, box[1])),
            "r": min(max_x, max(0.0, box[2])),
            "b": min(max_y, max(0.0, box[3])),
            "label": label_str,
            "confidence": float(score),
        }