Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions paddlex/inference/utils/official_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,34 @@ def is_available(cls):
return False


class _LocalModelHoster(_BaseModelHoster):
model_list = OCR_MODELS
alias = "local"
healthcheck_url = None

def get_model(self, model_name):
assert (
model_name in self.model_list
), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
model_dir = self._save_dir / f"{model_name}"
if os.path.exists(model_dir):
logging.info(
f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
)
else:
logging.warning(
f"local model ({model_name}) not found in {self._save_dir}. "
)
logging.info(
f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
)
raise FileNotFoundError
return model_dir

def _download(self):
pass


class _BosModelHoster(_BaseModelHoster):
model_list = ALL_MODELS
alias = "bos"
Expand Down Expand Up @@ -551,11 +579,15 @@ def _build_hosters(self):
return hosters

def _get_model_local_path(self, model_name):
if len(self._hosters) == 0:
msg = "No available model hosting platforms detected. Please check your network connection."
logging.error(msg)
raise Exception(msg)
return self._download_from_hoster(self._hosters, model_name)
try:
local_hoster = _LocalModelHoster(self._save_dir)
return local_hoster.get_model(model_name)
except FileNotFoundError:
if len(self._hosters) == 0:
msg = "No available models detected. Please check your network connection."
logging.error(msg)
raise Exception(msg)
return self._download_from_hoster(self._hosters, model_name)

def _download_from_hoster(self, hosters, model_name):
for idx, hoster in enumerate(hosters):
Expand Down