Skip to content

Commit 07bd24c

Browse files
committed
Support lazy loading for detectron2
1 parent 28f789c commit 07bd24c

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"torch",
2828
"torchvision",
2929
"pycocotools",
30+
"fvcore",
3031
],
3132
extras_require={
3233
"ocr": [

src/layoutparser/models/layoutmodel.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from abc import ABC, abstractmethod
22
import os
3-
import torch
4-
from detectron2.config import get_cfg
5-
from detectron2.engine import DefaultPredictor
6-
from ..elements import *
7-
from fvcore.common.file_io import PathManager
3+
import importlib
4+
85
from PIL import Image
96
import numpy as np
7+
import torch
8+
from fvcore.common.file_io import PathManager
9+
#TODO: Update to iopath in the next major release
10+
11+
from ..elements import *
1012

1113
__all__ = ["Detectron2LayoutModel"]
1214

@@ -16,6 +18,41 @@ class BaseLayoutModel(ABC):
1618
def detect(self):
1719
pass
1820

21+
# Add lazy loading mechanisms for layout models, refer to
22+
# layoutparser.ocr.BaseOCRAgent
23+
# TODO: Build a metaclass for lazy module loader
24+
@property
25+
@abstractmethod
26+
def DEPENDENCIES(self):
27+
"""DEPENDENCIES lists all necessary dependencies for the class."""
28+
pass
29+
30+
@property
31+
@abstractmethod
32+
def MODULES(self):
33+
"""MODULES instructs how to import these necessary libraries."""
34+
pass
35+
36+
@classmethod
37+
def _import_module(cls):
38+
for m in cls.MODULES:
39+
if importlib.util.find_spec(m["module_path"]):
40+
setattr(
41+
cls, m["import_name"], importlib.import_module(m["module_path"])
42+
)
43+
else:
44+
raise ModuleNotFoundError(
45+
f"\n "
46+
f"\nPlease install the following libraries to support the class {cls.__name__}:"
47+
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
48+
f"\n "
49+
)
50+
51+
def __new__(cls, *args, **kwargs):
52+
53+
cls._import_module()
54+
return super().__new__(cls)
55+
1956

2057
class Detectron2LayoutModel(BaseLayoutModel):
2158
"""Create a Detectron2-based Layout Detection Model
@@ -45,9 +82,18 @@ class Detectron2LayoutModel(BaseLayoutModel):
4582
4683
"""
4784

85+
DEPENDENCIES = ["detectron2"]
86+
MODULES = [
87+
{
88+
"import_name": "_engine",
89+
"module_path": "detectron2.engine",
90+
},
91+
{"import_name": "_config", "module_path": "detectron2.config"},
92+
]
93+
4894
def __init__(self, config_path, model_path=None, label_map=None, extra_config=[]):
4995

50-
cfg = get_cfg()
96+
cfg = self._config.get_cfg()
5197
config_path = PathManager.get_local_path(config_path)
5298
cfg.merge_from_file(config_path)
5399
cfg.merge_from_list(extra_config)
@@ -83,7 +129,7 @@ def gather_output(self, outputs):
83129
return layout
84130

85131
def _create_model(self):
86-
self.model = DefaultPredictor(self.cfg)
132+
self.model = self._engine.DefaultPredictor(self.cfg)
87133

88134
def detect(self, image):
89135
"""Detect the layout of a given image.

0 commit comments

Comments
 (0)