Skip to content

Commit 73721f6

Browse files
committed
support to bypass model source availablity check by PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK
1 parent b4304fb commit 73721f6

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

docs/FAQ.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ A:可以:
5454
2. 设置全局预训练模型缓存路径,例如:`paddlex.pretrain_dir='/usrname/paddlex'`,已下载模型将不会重复下载。
5555

5656

57+
## <b>Q:每次导入`paddlex`都会卡住一会,为什么?</b>
58+
59+
1. 因为每次启动,`paddlex`会默认自动测试模型托管平台的网络联通性(包括huggingface、aistudio、modelscope),以确定后续自动下载模型时选择哪个平台;
60+
2. 如果确定使用本地模型,不需要测试检查,可以设置环境变量`PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK=1`来禁用;
61+
5762

5863
## <b>Q:当我在使用PaddleX的过程中遇到问题,应该怎样反馈呢?</b>
5964

paddlex/inference/utils/official_models.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ...utils import logging
3232
from ...utils.cache import CACHE_DIR
3333
from ...utils.download import download_and_extract
34-
from ...utils.flags import MODEL_SOURCE
34+
from ...utils.flags import DISABLE_MODEL_SOURCE_CHECK, MODEL_SOURCE
3535

3636
ALL_MODELS = [
3737
"ResNet18",
@@ -541,18 +541,35 @@ def _clone(local_dir):
541541
class _ModelManager:
542542
model_list = ALL_MODELS
543543
_save_dir = Path(CACHE_DIR) / "official_models"
544+
hoster_candidates = [
545+
_HuggingFaceModelHoster,
546+
_AIStudioModelHoster,
547+
_ModelScopeModelHoster,
548+
_BosModelHoster,
549+
]
544550

545551
def __init__(self) -> None:
546552
self._hosters = self._build_hosters()
547553

548554
def _build_hosters(self):
555+
556+
if DISABLE_MODEL_SOURCE_CHECK:
557+
logging.warning(
558+
f"Connectivity check to the model hoster has been skipped because `DISABLE_MODEL_SOURCE_CHECK` is enabled."
559+
)
560+
hosters = []
561+
for hoster_cls in self.hoster_candidates:
562+
if hoster_cls.alias == MODEL_SOURCE:
563+
hosters.insert(0, hoster_cls(self._save_dir))
564+
else:
565+
hosters.append(hoster_cls(self._save_dir))
566+
return hosters
567+
568+
logging.warning(
569+
f"Checking connectivity to the model hosters, this may take a while. To bypass this check, set `DISABLE_MODEL_SOURCE_CHECK` to `True`."
570+
)
549571
hosters = []
550-
for hoster_cls in [
551-
_HuggingFaceModelHoster,
552-
_AIStudioModelHoster,
553-
_ModelScopeModelHoster,
554-
_BosModelHoster,
555-
]:
572+
for hoster_cls in self.hoster_candidates:
556573
if hoster_cls.alias == MODEL_SOURCE:
557574
if hoster_cls.is_available():
558575
hosters.insert(0, hoster_cls(self._save_dir))
@@ -561,7 +578,7 @@ def _build_hosters(self):
561578
hosters.append(hoster_cls(self._save_dir))
562579
if len(hosters) == 0:
563580
logging.warning(
564-
f"No model hoster is available! Please check your network connection to one of the following model hosts: HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}), ModelScope ({_ModelScopeModelHoster.healthcheck_url}), AIStudio ({_AIStudioModelHoster.healthcheck_url}), or BOS ({_BosModelHoster.healthcheck_url}). Otherwise, only local models can be used."
581+
f"No model hoster is available! Please check your network connection to one of the following model hoster: HuggingFace ({_HuggingFaceModelHoster.healthcheck_url}), ModelScope ({_ModelScopeModelHoster.healthcheck_url}), AIStudio ({_AIStudioModelHoster.healthcheck_url}), or BOS ({_BosModelHoster.healthcheck_url}). Otherwise, only local models can be used."
565582
)
566583
return hosters
567584

paddlex/utils/flags.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ def get_flag_from_env_var(name, default, format_func=str):
6666
)
6767

6868
MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface").lower()
69+
DISABLE_MODEL_SOURCE_CHECK = os.environ.get(
70+
"PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK", False
71+
)
6972

7073

7174
# Inference Benchmark

0 commit comments

Comments
 (0)