Skip to content

Commit e0c509e

Browse files
TingquanGaoBobholamovic
authored andcommitted
use model cache files when network is unavailable (#4676)
1 parent 0af6510 commit e0c509e

File tree

2 files changed

+35
-35
lines changed

2 files changed

+35
-35
lines changed

paddlex/inference/models/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ def create_predictor(
7070

7171
if need_local_model(genai_config):
7272
if model_dir is None:
73-
assert (
74-
model_name in official_models
75-
), f"The model ({model_name}) is not supported! Please using directory of local model files or model name supported by PaddleX!"
7673
model_dir = official_models[model_name]
7774
else:
7875
assert Path(model_dir).exists(), f"{model_dir} is not exists!"

paddlex/inference/utils/official_models.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"ResNet152",
4646
"ResNet152_vd",
4747
"ResNet200_vd",
48-
"PaddleOCR-VL-0.9B",
48+
"PaddleOCR-VL",
4949
"PP-LCNet_x0_25",
5050
"PP-LCNet_x0_25_textline_ori",
5151
"PP-LCNet_x0_35",
@@ -345,7 +345,7 @@
345345
"en_PP-OCRv5_mobile_rec",
346346
"th_PP-OCRv5_mobile_rec",
347347
"el_PP-OCRv5_mobile_rec",
348-
"PaddleOCR-VL-0.9B",
348+
"PaddleOCR-VL",
349349
"PicoDet_layout_1x",
350350
"PicoDet_layout_1x_table",
351351
"PicoDet-L_layout_17cls",
@@ -419,27 +419,15 @@ def get_model(self, model_name):
419419
assert (
420420
model_name in self.model_list
421421
), f"The model {model_name} is not supported on hosting {self.__class__.__name__}!"
422-
if model_name == "PaddleOCR-VL-0.9B":
423-
model_name = "PaddleOCR-VL"
424422

425423
model_dir = self._save_dir / f"{model_name}"
426-
if os.path.exists(model_dir):
427-
logging.info(
428-
f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
429-
)
430-
else:
431-
logging.info(
432-
f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
433-
)
434-
self._download(model_name, model_dir)
435-
logging.debug(
436-
f"`{model_name}` model files has been download from model source: `{self.alias}`!"
437-
)
438-
439-
if model_name == "PaddleOCR-VL":
440-
vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
441-
if vl_model_dir.exists() and vl_model_dir.is_dir():
442-
return vl_model_dir
424+
logging.info(
425+
f"Using official model ({model_name}), the model files will be automatically downloaded and saved in `{model_dir}`."
426+
)
427+
self._download(model_name, model_dir)
428+
logging.debug(
429+
f"`{model_name}` model files has been download from model source: `{self.alias}`!"
430+
)
443431

444432
return model_dir
445433

@@ -573,21 +561,33 @@ def _build_hosters(self):
573561
hosters.append(hoster_cls(self._save_dir))
574562
if len(hosters) == 0:
575563
logging.warning(
576-
f"""No model hoster is available! Please check your network connection to one of the following model hosts:
577-
HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}),
578-
ModelScope ({_ModelScopeModelHoster.healthcheck_url}),
579-
AIStudio ({_AIStudioModelHoster.healthcheck_url}), or
580-
BOS ({_BosModelHoster.healthcheck_url}).
581-
Otherwise, only local models can be used."""
564+
f"No model hoster is available! Please check your network connection to one of the following model hosts: HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}), ModelScope ({_ModelScopeModelHoster.healthcheck_url}), AIStudio ({_AIStudioModelHoster.healthcheck_url}), or BOS ({_BosModelHoster.healthcheck_url}). Otherwise, only local models can be used."
582565
)
583566
return hosters
584567

585568
def _get_model_local_path(self, model_name):
586-
if len(self._hosters) == 0:
587-
msg = "No available model hosting platforms detected. Please check your network connection."
588-
logging.error(msg)
589-
raise Exception(msg)
590-
return self._download_from_hoster(self._hosters, model_name)
569+
if model_name == "PaddleOCR-VL-0.9B":
570+
model_name = "PaddleOCR-VL"
571+
572+
model_dir = self._save_dir / f"{model_name}"
573+
if os.path.exists(model_dir):
574+
logging.info(
575+
f"Model files already exist. Using cached files. To redownload, please delete the directory manually: `{model_dir}`."
576+
)
577+
else:
578+
if len(self._hosters) == 0:
579+
msg = "No available model hosting platforms detected. Please check your network connection."
580+
logging.error(msg)
581+
raise Exception(msg)
582+
583+
model_dir = self._download_from_hoster(self._hosters, model_name)
584+
585+
if model_name == "PaddleOCR-VL":
586+
vl_model_dir = model_dir / "PaddleOCR-VL-0.9B"
587+
if vl_model_dir.exists() and vl_model_dir.is_dir():
588+
return vl_model_dir
589+
590+
return model_dir
591591

592592
def _download_from_hoster(self, hosters, model_name):
593593
for idx, hoster in enumerate(hosters):
@@ -605,6 +605,9 @@ def _download_from_hoster(self, hosters, model_name):
605605
f"Encountering exception when download model from {hoster.alias}: \n{e}, will try to download from other model sources: `{hosters[idx + 1].alias}`."
606606
)
607607
return self._download_from_hoster(hosters[idx + 1 :], model_name)
608+
raise Exception(
609+
f"No model source is available for model `{model_name}`! Please check model name and network, or use local model files!"
610+
)
608611

609612
def __contains__(self, model_name):
610613
return model_name in self.model_list

0 commit comments

Comments
 (0)