2121
2222namespace mlx ::core::distributed::nccl {
2323
24+ // Can be tuned with MLX_NCCL_TIMEOUT
25+ constexpr int nccl_timeout = 300000 ; // milliseconds
26+
2427#define CHECK_CUDA (cmd ) \
2528 do { \
2629 cudaError_t e = cmd; \
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
181184 close (sock);
182185
183186 } else {
184- // Here just wanted to make show that rank 0 has enough time to bind
185- // so we will retry to connect until max attempts
187+ // Here we want to make sure that rank 0 has enough time to bind,
188+ // so we retry connecting until the elapsed time exceeds nccl_timeout.
189+ // This is particularly important for multi-node setups.
186190
187191 int sock = socket (AF_INET, SOCK_STREAM, 0 );
188192 if (sock < 0 ) {
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
200204 memcpy (&serv.sin_addr , he->h_addr_list [0 ], he->h_length );
201205 serv.sin_port = htons (port);
202206
203- const int max_retries = 30 ;
204- int attempt = 0 ;
207+ const int timeout_ms = env::nccl_timeout (nccl_timeout);
205208 bool connected = false ;
206209
207- bool do_log = std::getenv (" NCCL_DEBUG" ) == " INFO" ;
208- for (attempt = 0 ; attempt < max_retries; ++attempt) {
210+ const char * dbg = std::getenv (" NCCL_DEBUG" );
211+ bool do_log = (dbg && std::string (dbg) == " INFO" );
212+
213+ auto start = std::chrono::steady_clock::now ();
214+ int attempt = 0 ;
215+
216+ while (true ) {
217+ auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
218+ std::chrono::steady_clock::now () - start)
219+ .count ();
220+ if (elapsed_ms > timeout_ms)
221+ break ;
209222 if (connect (sock, reinterpret_cast <sockaddr*>(&serv), sizeof (serv)) ==
210223 0 ) {
211224 connected = true ;
212225 if (do_log) {
213- std::cout << " [Rank " << rank
214- << " ] Connected successfully on attempt " << attempt + 1
215- << std::endl;
226+ std::cout << " [Rank " << rank << " ] Connected successfully after "
227+ << elapsed_ms << " miliseconds" << std::endl;
216228 break ;
217229 }
218230 }
219231 if (errno != ECONNREFUSED) {
220232 break ;
221233 }
234+ ++attempt;
222235 std::this_thread::sleep_for (std::chrono::milliseconds (500 ));
223236 }
224237
225238 if (!connected) {
226239 std::ostringstream msg;
227- msg << " [Rank " << rank << " ] connect() failed after " << attempt
228- << " retries: " << strerror (errno);
240+ msg << " [Rank " << rank << " ] connect() failed after " << timeout_ms
241+ << " milliseconds and " << attempt << " retries: " << strerror (errno);
229242 close (sock);
230243 throw std::runtime_error (msg.str ());
231244 }
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
256269
257270 ~NCCLGroup () {
258271 ncclCommDestroy (comm_);
259- ncclGroupEnd ();
260272 initialized_ = false ;
261273 }
262274
0 commit comments