diff --git a/paddlex/inference/utils/official_models.py b/paddlex/inference/utils/official_models.py
index d0081cd1f..c38c2638c 100644
--- a/paddlex/inference/utils/official_models.py
+++ b/paddlex/inference/utils/official_models.py
@@ -434,6 +434,34 @@ def is_available(cls):
         return False
 
 
+class _LocalModelHoster(_BaseModelHoster):
+    model_list = OCR_MODELS
+    alias = "local"
+    healthcheck_url = None
+
+    def get_model(self, model_name):
+        assert (
+            model_name in self.model_list
+        ), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
+        model_dir = self._save_dir / f"{model_name}"
+        if os.path.exists(model_dir):
+            logging.info(
+                f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
+            )
+        else:
+            logging.warning(
+                f"local model ({model_name}) not found in {self._save_dir}. "
+            )
+            logging.info(
+                f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
+            )
+            raise FileNotFoundError
+        return model_dir
+
+    def _download(self):
+        pass
+
+
 class _BosModelHoster(_BaseModelHoster):
     model_list = ALL_MODELS
     alias = "bos"
@@ -551,11 +579,15 @@ def _build_hosters(self):
         return hosters
 
     def _get_model_local_path(self, model_name):
-        if len(self._hosters) == 0:
-            msg = "No available model hosting platforms detected. Please check your network connection."
-            logging.error(msg)
-            raise Exception(msg)
-        return self._download_from_hoster(self._hosters, model_name)
+        try:
+            local_hoster = _LocalModelHoster(self._save_dir)
+            return local_hoster.get_model(model_name)
+        except FileNotFoundError:
+            if len(self._hosters) == 0:
+                msg = "No available models detected. Please check your network connection."
+                logging.error(msg)
+                raise Exception(msg)
+            return self._download_from_hoster(self._hosters, model_name)
 
     def _download_from_hoster(self, hosters, model_name):
         for idx, hoster in enumerate(hosters):