diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc index 6eae3c4..40301da 100644 --- a/src/onnxruntime.cc +++ b/src/onnxruntime.cc @@ -712,6 +712,11 @@ ModelState::LoadModel( values.push_back(value); } } + + // assign correct GPU to EP + keys.push_back(std::string("device_id")); + values.push_back(std::to_string(instance_group_device_id)); + std::vector c_keys, c_values; if (!keys.empty() && !values.empty()) { for (size_t i = 0; i < keys.size(); ++i) {