Skip to content

Commit e9eab52

Browse files
authored
Nccl timeout (#2673)
* print the error & delete nccl group * timeout for nccl binding * typo * revert error * fixed a typo
1 parent 36ca62d commit e9eab52

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

mlx/distributed/nccl/nccl.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
namespace mlx::core::distributed::nccl {
2323

24+
// Can be tuned with MLX_NCCL_TIMEOUT
25+
constexpr int nccl_timeout = 300000; // miliseconds
26+
2427
#define CHECK_CUDA(cmd) \
2528
do { \
2629
cudaError_t e = cmd; \
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
181184
close(sock);
182185

183186
} else {
184-
// Here just wanted to make show that rank 0 has enough time to bind
185-
// so we will retry to connect until max attempts
187+
// Here we want to make sure that rank 0 has enough time to bind
188+
// so we will retry to connect until elapsed time exceeds nccl_timeout
189+
// this is particularity important for multinode setup
186190

187191
int sock = socket(AF_INET, SOCK_STREAM, 0);
188192
if (sock < 0) {
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
200204
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
201205
serv.sin_port = htons(port);
202206

203-
const int max_retries = 30;
204-
int attempt = 0;
207+
const int timeout_ms = env::nccl_timeout(nccl_timeout);
205208
bool connected = false;
206209

207-
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
208-
for (attempt = 0; attempt < max_retries; ++attempt) {
210+
const char* dbg = std::getenv("NCCL_DEBUG");
211+
bool do_log = (dbg && std::string(dbg) == "INFO");
212+
213+
auto start = std::chrono::steady_clock::now();
214+
int attempt = 0;
215+
216+
while (true) {
217+
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
218+
std::chrono::steady_clock::now() - start)
219+
.count();
220+
if (elapsed_ms > timeout_ms)
221+
break;
209222
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
210223
0) {
211224
connected = true;
212225
if (do_log) {
213-
std::cout << "[Rank " << rank
214-
<< "] Connected successfully on attempt " << attempt + 1
215-
<< std::endl;
226+
std::cout << "[Rank " << rank << "] Connected successfully after "
227+
<< elapsed_ms << " miliseconds" << std::endl;
216228
break;
217229
}
218230
}
219231
if (errno != ECONNREFUSED) {
220232
break;
221233
}
234+
++attempt;
222235
std::this_thread::sleep_for(std::chrono::milliseconds(500));
223236
}
224237

225238
if (!connected) {
226239
std::ostringstream msg;
227-
msg << "[Rank " << rank << "] connect() failed after " << attempt
228-
<< " retries: " << strerror(errno);
240+
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
241+
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
229242
close(sock);
230243
throw std::runtime_error(msg.str());
231244
}
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
256269

257270
~NCCLGroup() {
258271
ncclCommDestroy(comm_);
259-
ncclGroupEnd();
260272
initialized_ = false;
261273
}
262274

mlx/utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ inline bool enable_tf32() {
165165
return enable_tf32_;
166166
}
167167

168+
inline int nccl_timeout(int default_value) {
169+
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
170+
return nccl_timeout;
171+
}
172+
168173
} // namespace env
169174

170175
} // namespace mlx::core

0 commit comments

Comments
 (0)