11from abc import ABC , abstractmethod
22import os
3- import torch
4- from detectron2 .config import get_cfg
5- from detectron2 .engine import DefaultPredictor
6- from ..elements import *
7- from fvcore .common .file_io import PathManager
3+ import importlib
4+
85from PIL import Image
96import numpy as np
7+ import torch
8+ from fvcore .common .file_io import PathManager
9+ #TODO: Update to iopath in the next major release
10+
11+ from ..elements import *
1012
1113__all__ = ["Detectron2LayoutModel" ]
1214
@@ -16,6 +18,41 @@ class BaseLayoutModel(ABC):
1618 def detect (self ):
1719 pass
1820
21+ # Add lazy loading mechanisms for layout models, refer to
22+ # layoutparser.ocr.BaseOCRAgent
23+ # TODO: Build a metaclass for lazy module loader
24+ @property
25+ @abstractmethod
26+ def DEPENDENCIES (self ):
27+ """DEPENDENCIES lists all necessary dependencies for the class."""
28+ pass
29+
30+ @property
31+ @abstractmethod
32+ def MODULES (self ):
33+ """MODULES instructs how to import these necessary libraries."""
34+ pass
35+
36+ @classmethod
37+ def _import_module (cls ):
38+ for m in cls .MODULES :
39+ if importlib .util .find_spec (m ["module_path" ]):
40+ setattr (
41+ cls , m ["import_name" ], importlib .import_module (m ["module_path" ])
42+ )
43+ else :
44+ raise ModuleNotFoundError (
45+ f"\n "
46+ f"\n Please install the following libraries to support the class { cls .__name__ } :"
47+ f"\n pip install { ' ' .join (cls .DEPENDENCIES )} "
48+ f"\n "
49+ )
50+
51+ def __new__ (cls , * args , ** kwargs ):
52+
53+ cls ._import_module ()
54+ return super ().__new__ (cls )
55+
1956
2057class Detectron2LayoutModel (BaseLayoutModel ):
2158 """Create a Detectron2-based Layout Detection Model
@@ -45,9 +82,18 @@ class Detectron2LayoutModel(BaseLayoutModel):
4582
4683 """
4784
85+ DEPENDENCIES = ["detectron2" ]
86+ MODULES = [
87+ {
88+ "import_name" : "_engine" ,
89+ "module_path" : "detectron2.engine" ,
90+ },
91+ {"import_name" : "_config" , "module_path" : "detectron2.config" },
92+ ]
93+
4894 def __init__ (self , config_path , model_path = None , label_map = None , extra_config = []):
4995
50- cfg = get_cfg ()
96+ cfg = self . _config . get_cfg ()
5197 config_path = PathManager .get_local_path (config_path )
5298 cfg .merge_from_file (config_path )
5399 cfg .merge_from_list (extra_config )
@@ -83,7 +129,7 @@ def gather_output(self, outputs):
83129 return layout
84130
85131 def _create_model (self ):
86- self .model = DefaultPredictor (self .cfg )
132+ self .model = self . _engine . DefaultPredictor (self .cfg )
87133
88134 def detect (self , image ):
89135 """Detect the layout of a given image.
0 commit comments