Skip to content

Commit 8132a30

Browse files
committed
fix: multi_batch assignment
Assign a single batch in the first round and after an update. This stabilizes the scheduling and prevents over-assignment. Once the scheduler is refactored, this can be further optimized.
1 parent 6852674 commit 8132a30

File tree

1 file changed

+11
-22
lines changed

1 file changed

+11
-22
lines changed

crates/scheduler/src/scheduling/batch_scheduler.rs

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,9 @@ where
227227
.map(|w| (batch_sizer)(&w.resources))
228228
.collect();
229229

230-
let (should_update, projected_target) = if update_target <= count {
231-
(true, count)
232-
} else if !snapshot.is_empty()
233-
&& batch_sizes.iter().all(|&b| b > 0)
234-
&& stats.iter().all(|&s| s > 0 && s < u64::MAX)
235-
{
230+
let (should_update, projected_target, batches) = if update_target <= count {
231+
(true, count, 0)
232+
} else if !snapshot.is_empty() && stats.iter().all(|&s| s > 0 && s < u64::MAX) {
236233
let (time, cnt, projection, capped) = S::project(
237234
&progress,
238235
&batch_sizes,
@@ -255,9 +252,10 @@ where
255252
&& peer_position < projection.len()
256253
&& projection[peer_position] == 0,
257254
count.saturating_add(cnt.unsigned_abs()),
255+
projection[peer_position],
258256
)
259257
} else {
260-
(false, count)
258+
(false, count, 1)
261259
};
262260

263261
// Check if peer has applied update or sent update
@@ -290,9 +288,7 @@ where
290288
timeout: short_idle,
291289
})
292290
} else if !should_update {
293-
ExecutorAction::Train(TrainAction::ExecuteBatch {
294-
batches: multi_batch_size,
295-
})
291+
ExecutorAction::Train(TrainAction::ExecuteBatch { batches })
296292
} else if parameter_servers.is_empty() {
297293
// NOTE: If we need to send an update but there are no parameter servers,
298294
// we must wait (idle) until one becomes available.
@@ -425,10 +421,7 @@ where
425421

426422
let (should_update, projected_target, batches) = if update_target <= count {
427423
(true, count, 0)
428-
} else if !snapshot.is_empty()
429-
&& batch_sizes.iter().all(|&b| b > 0)
430-
&& stats.iter().all(|&s| s > 0 && s < u64::MAX)
431-
{
424+
} else if !snapshot.is_empty() && stats.iter().all(|&s| s > 0 && s < u64::MAX) {
432425
let (time, cnt, projection, capped) = S::project(
433426
&progress,
434427
&batch_sizes,
@@ -455,7 +448,7 @@ where
455448
projection[peer_position],
456449
)
457450
} else {
458-
(false, count, multi_batch_size)
451+
(false, count, 1)
459452
};
460453

461454
if !should_update {
@@ -596,11 +589,7 @@ where
596589
},
597590
})
598591
} else {
599-
// We can either move through idle or expect that the parameters are tuned
600-
// s.t., its okay to execute a multi batch in the first round.
601-
ExecutorAction::Train(TrainAction::ExecuteBatch {
602-
batches: multi_batch_size,
603-
})
592+
ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 })
604593
}
605594
}
606595
}
@@ -1153,7 +1142,7 @@ mod batch_scheduler_tests {
11531142
.unwrap();
11541143

11551144
match resp.next {
1156-
ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 3 }) => {}
1145+
ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) => {}
11571146
other => panic!("Unexpected response: {:?}", other),
11581147
}
11591148
}
@@ -1224,7 +1213,7 @@ mod batch_scheduler_tests {
12241213
.unwrap();
12251214

12261215
match resp.next {
1227-
ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 3 }) => {}
1216+
ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) => {}
12281217
other => panic!("Unexpected response: {:?}", other),
12291218
}
12301219
}

0 commit comments

Comments (0)