|
2 | 2 | #include <vector> |
3 | 3 | #include <string> |
4 | 4 | #include <sstream> |
| 5 | +#include <utility> |
5 | 6 | #include <unistd.h> |
6 | 7 |
|
7 | 8 | #include <nvml.h> |
@@ -82,44 +83,52 @@ vector<int> readCudaVisibleDevices() { |
82 | 83 | return res; |
83 | 84 | } |
84 | 85 |
|
85 | | -vector<int> getAvailableDevices() { |
| 86 | +pair<vector<int>, bool> getAvailableDevices() { |
86 | 87 | vector<int> visibleDevices = readCudaVisibleDevices(); |
87 | 88 |
|
| 89 | + // cout << ">>>>> visibleDevices.empty()" << visibleDevices.empty() << endl; |
| 90 | + |
88 | 91 | if (visibleDevices.empty()) |
89 | | - return getAllPhysicallyAvailableDevices(); |
| 92 | + return make_pair(getAllPhysicallyAvailableDevices(), false); |
90 | 93 |
|
91 | | - return visibleDevices; |
| 94 | + return make_pair(visibleDevices, true); |
92 | 95 | } |
93 | 96 |
|
94 | 97 | extern "C" |
95 | | -int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsg) { |
| 98 | +int occupyDevices(int requestedDevicesCount, int * occupiedDevicesIdxs, char * errorMsgOut) { |
96 | 99 | try { |
97 | | - vector<int> availableDevcices = getAvailableDevices(); |
| 100 | + auto availableDevcicesPair = getAvailableDevices(); |
| 101 | + auto availableDevices = availableDevcicesPair.first; |
| 102 | + bool cudaVisibleDevciesSetProperly = availableDevcicesPair.second; |
98 | 103 |
|
99 | | - if ((int)availableDevcices.size() < requestedDevicesCount) { |
| 104 | + if ((int)availableDevices.size() < requestedDevicesCount) { |
100 | 105 | string msg = "There are not as many free devices as requested. Requested devices count: " |
101 | | - + to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevcices.size()) + "."; |
| 106 | + + to_string(requestedDevicesCount) + ". Available devices count: " + to_string(availableDevices.size()) + "."; |
102 | 107 |
|
103 | | - memcpy(errorMsg, msg.c_str(), msg.length()); |
| 108 | + memcpy(errorMsgOut, msg.c_str(), msg.length()); |
104 | 109 |
|
105 | 110 | return -1; |
106 | 111 | } |
107 | 112 |
|
108 | 113 | int nextDeviceIdx = 0; |
109 | 114 | for (int i = 0; i < requestedDevicesCount; i++) { |
110 | | - int deviceIdx = availableDevcices[i]; |
111 | | - gpuErrchk( cudaSetDevice(i), deviceIdx, errorMsg ); |
| 115 | + int deviceIdx = -1; |
| 116 | + if (cudaVisibleDevciesSetProperly) |
| 117 | + deviceIdx = i; |
| 118 | + else |
| 119 | + deviceIdx = availableDevices[i]; |
| 120 | + gpuErrchk( cudaSetDevice(deviceIdx), deviceIdx, errorMsgOut ); |
112 | 121 |
|
113 | 122 | //call some API functions to really occupy device (I'm lazy to look for more elegant way to do it) |
114 | 123 | char * ddata; |
115 | | - gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsg ); |
116 | | - gpuErrchk( cudaFree(ddata), deviceIdx, errorMsg ); |
| 124 | + gpuErrchk( cudaMalloc(&ddata, 1), deviceIdx, errorMsgOut ); |
| 125 | + gpuErrchk( cudaFree(ddata), deviceIdx, errorMsgOut ); |
117 | 126 |
|
118 | | - occupiedDevicesIdxs[nextDeviceIdx++] = i; |
| 127 | + occupiedDevicesIdxs[nextDeviceIdx++] = deviceIdx; |
119 | 128 | } |
120 | 129 | } catch (const std::exception& e) { |
121 | 130 | auto msg = string(e.what()); |
122 | | - memcpy(errorMsg, msg.c_str(), msg.length()); |
| 131 | + memcpy(errorMsgOut, msg.c_str(), msg.length()); |
123 | 132 | return -1; |
124 | 133 | } |
125 | 134 |
|
|
0 commit comments