Skip to content

Commit b83e558

Browse files
committed
cat set env ONNX_PROVIDER
1 parent d6f2f5e commit b83e558

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

onnx/test/onnx/onnx_load.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import onnxruntime
44
from misc.config import ONNX_FP
55
from os.path import join
6+
import os
67
from onnxruntime import get_all_providers
78

89
session = onnxruntime.SessionOptions()
@@ -12,10 +13,12 @@
1213

1314
def onnx_load(kind):
1415
fp = join(ONNX_FP, f'{kind}.onnx')
15-
16+
# 可以在 FlagAI/onnx/.env 中设置环境变量 ONNX_PROVIDER 避免 UserWarning Specified provider xxx is not in available
17+
provider = os.getenv('ONNX_PROVIDER')
18+
providers = [provider] if provider else get_all_providers()
1619
sess = onnxruntime.InferenceSession(fp,
1720
sess_options=session,
18-
providers=get_all_providers())
21+
providers=providers)
1922
return sess
2023

2124

0 commit comments

Comments
 (0)