Skip to content

Commit 0792ff0

Browse files
authored
Only fail when 10 consecutive socket errors occur (#1928)
1 parent fd0d63b commit 0792ff0

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mlx/distributed/ring/ring.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class SocketThread {
199199
}
200200

201201
void worker() {
202+
int error_count = 0;
202203
bool delete_recv = false;
203204
bool delete_send = false;
204205
while (true) {
@@ -235,10 +236,11 @@ class SocketThread {
235236
task.buffer = static_cast<char*>(task.buffer) + r;
236237
task.size -= r;
237238
delete_recv = task.size == 0;
239+
error_count = 0;
238240
} else if (errno != EAGAIN) {
241+
error_count++;
239242
log_info(
240243
true, "Receiving from socket", fd_, "failed with errno", errno);
241-
return;
242244
}
243245
}
244246
if (!sends_.empty()) {
@@ -248,11 +250,17 @@ class SocketThread {
248250
task.buffer = static_cast<char*>(task.buffer) + r;
249251
task.size -= r;
250252
delete_send = task.size == 0;
253+
error_count = 0;
251254
} else if (errno != EAGAIN) {
255+
error_count++;
252256
log_info(true, "Sending to socket", fd_, "failed with errno", errno);
253-
return;
254257
}
255258
}
259+
260+
if (error_count >= 10) {
261+
log_info(true, "Too many send/recv errors. Aborting...");
262+
return;
263+
}
256264
}
257265
}
258266

python/mlx/distributed_run.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,12 @@ def normalize(path):
112112
break
113113
if not ring:
114114
break
115-
rings.append(normalize(concretize(ring, used_ports)))
115+
try:
116+
rings.append(normalize(concretize(ring, used_ports)))
117+
except RuntimeError:
118+
if len(rings) > 0:
119+
return rings
120+
raise
116121

117122
return rings
118123

0 commit comments

Comments
 (0)