Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ repos:
pass_filenames: false
language: system
files: '\.py$'
- id: poetry
name: Poetry check
entry: poetry lock --check
pass_filenames: false
language: system
# - id: poetry
# name: Poetry check
# entry: poetry lock --check
# pass_filenames: false
# language: system
- id: system
name: MyPy
entry: poetry run mypy docling_ibm_models
Expand Down
41 changes: 20 additions & 21 deletions docling_ibm_models/layoutmodel/layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from transformers import DFineForObjectDetection, RTDetrImageProcessor

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,24 +44,23 @@ def __init__(
"""
# Initialize classes map:
self._classes_map = {
0: "background",
1: "Caption",
2: "Footnote",
3: "Formula",
4: "List-item",
5: "Page-footer",
6: "Page-header",
7: "Picture",
8: "Section-header",
9: "Table",
10: "Text",
11: "Title",
12: "Document Index",
13: "Code",
14: "Checkbox-Selected",
15: "Checkbox-Unselected",
16: "Form",
17: "Key-Value Region",
0: "Caption",
1: "Footnote",
2: "Formula",
3: "List-item",
4: "Page-footer",
5: "Page-header",
6: "Picture",
7: "Section-header",
8: "Table",
9: "Text",
10: "Title",
11: "Document Index",
12: "Code",
13: "Checkbox-Selected",
14: "Checkbox-Unselected",
15: "Form",
16: "Key-Value Region",
}

# Blacklisted classes
Expand All @@ -87,7 +86,7 @@ def __init__(
processor_config = os.path.join(artifact_path, "preprocessor_config.json")
model_config = os.path.join(artifact_path, "config.json")
self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
self._model = RTDetrForObjectDetection.from_pretrained(
self._model = DFineForObjectDetection.from_pretrained(
artifact_path, config=model_config
).to(self._device)
self._model.eval()
Expand Down Expand Up @@ -155,7 +154,7 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
):
score = float(score.item())

label_id = int(label_id.item()) + 1 # Advance the label_id
label_id = int(label_id.item())
label_str = self._classes_map[label_id]

# Filter out blacklisted classes
Expand Down
Loading
Loading