diff --git a/python-package/insightface/model_zoo/scrfd.py b/python-package/insightface/model_zoo/scrfd.py index 674db4bba..6558ec24f 100644 --- a/python-package/insightface/model_zoo/scrfd.py +++ b/python-package/insightface/model_zoo/scrfd.py @@ -79,7 +79,10 @@ def __init__(self, model_file=None, session=None): if self.session is None: assert self.model_file is not None assert osp.exists(self.model_file) - self.session = onnxruntime.InferenceSession(self.model_file, None) + if (onnxruntime.get_device() == "GPU"): + self.session = onnxruntime.InferenceSession(self.model_file, None, providers=['CUDAExecutionProvider']) + if (onnxruntime.get_device() == "CPU"): + self.session = onnxruntime.InferenceSession(self.model_file, None, providers=['CPUExecutionProvider']) self.center_cache = {} self.nms_thresh = 0.4 self.det_thresh = 0.5