Skip to content

Commit 8c781d9

Browse files
committed
Prototype RL scheduler
1 parent e0cdb00 commit 8c781d9

File tree

7 files changed

+1069
-29
lines changed

7 files changed

+1069
-29
lines changed

Cargo.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/messages/src/lib.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ pub mod action {
8989
pub enum ExecutorStatus {
9090
Train(TrainStatus),
9191
Aggregate(AggregateStatus),
92+
Gymnasium(GymnasiumStatus),
9293
}
9394

9495
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -127,6 +128,17 @@ pub mod action {
127128
Error(AggregateError),
128129
}
129130

131+
/// Status reported by a gymnasium executor; carried by `ExecutorStatus::Gymnasium`.
///
/// Serialized as an internally tagged enum: the variant name is stored under the
/// `"state"` key in kebab-case (`joined`, `idle`, `generated-data`, `sent-data`,
/// `received-agent-state`, `error`).
///
/// NOTE(review): the variant names suggest a progression that mirrors the
/// `GymnasiumAction` variants (`Idle`, `Generate`, `Send`, `Update`), but the
/// exact transitions are decided by the scheduler and are not visible here —
/// confirm against the RL scheduler before relying on any ordering.
#[derive(Clone, Debug, Serialize, Deserialize)]
132+
#[serde(tag = "state", rename_all = "kebab-case")]
133+
pub enum GymnasiumStatus {
134+
Joined,
135+
Idle,
136+
GeneratedData,
137+
SentData,
138+
ReceivedAgentState,
139+
// The failure details live in the wrapped `GymnasiumError`.
Error(GymnasiumError),
140+
}
141+
130142
#[derive(Clone, Debug, Serialize, Deserialize)]
131143
#[serde(tag = "type", rename_all = "kebab-case")]
132144
pub enum TrainError {
@@ -141,6 +153,13 @@ pub mod action {
141153
Other { message: String },
142154
}
143155

156+
/// Error payload carried by `GymnasiumStatus::Error`.
///
/// Serialized as an internally tagged enum: the variant name is stored under the
/// `"type"` key in kebab-case (`connection`, `other`), alongside the variant's
/// `message` field.
#[derive(Clone, Debug, Serialize, Deserialize)]
157+
#[serde(tag = "type", rename_all = "kebab-case")]
158+
pub enum GymnasiumError {
159+
// Presumably a networking/connectivity failure — confirm with the code that constructs it.
Connection { message: String },
160+
// Catch-all for failures that have no dedicated variant.
Other { message: String },
161+
}
162+
144163
#[derive(Clone, Debug, Serialize, Deserialize)]
145164
#[serde(tag = "executor", content = "action", rename_all = "kebab-case")]
146165
pub enum ExecutorAction {
@@ -201,7 +220,7 @@ pub mod action {
201220
#[serde(tag = "kind", rename_all = "kebab-case")]
202221
pub enum GymnasiumAction {
203222
Idle { timeout: SystemTime },
204-
Generate { source: Reference },
223+
Generate {},
205224
Send { target: Reference },
206225
Update {},
207226
Terminate,

crates/scheduler/src/bin/hypha-scheduler.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ use hypha_scheduler::{
2828
network::Network,
2929
pool::{Pool, PoolConfig, PoolWithAggregateInfo, PoolWithTrainInfo},
3030
scheduler_config::{Job as SchedulerJob, MetricsConfig},
31-
scheduling::{batch_scheduler::BatchScheduler, data_scheduler::DataScheduler},
31+
scheduling::{
32+
batch_scheduler::BatchScheduler, data_scheduler::DataScheduler, rl_scheduler::RLScheduler,
33+
},
3234
simulation::BasicSimulation,
3335
statistics::RunningMean,
3436
task::Task,
@@ -700,12 +702,13 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
700702
let network = network.clone();
701703
let rl_config = rl_config.clone();
702704
let worker_spec = trainer_worker_spec.clone();
705+
let gymnasium_worker_handle = gymnasium_worker_handle.clone();
703706

704707
tokio::spawn(trainer_worker_pool.for_each_concurrent(None, move |worker| {
705708
let network = network.clone();
706709
let rl_config = rl_config.clone();
707710
let worker_spec = worker_spec.clone();
708-
let gymnasium_worker_pool = gymnasium_worker_handle.clone();
711+
let gymnasium_worker_handle = gymnasium_worker_handle.clone();
709712

710713
async move {
711714
match worker {
@@ -722,7 +725,7 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
722725
.into_executor(RlTrainerExecutorConfig {
723726
model: rl_config.model.clone().into(),
724727
data: Fetch::data_peers(
725-
gymnasium_worker_pool.members().iter().map(|worker| worker.peer_id).collect(),
728+
gymnasium_worker_handle.members().iter().map(|worker| worker.peer_id).collect(),
726729
DataSlice {dataset: "foo".to_string(), hash: 0},
727730
),
728731
batch_size,
@@ -752,9 +755,10 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
752755
}))
753756
};
754757

755-
let (metrics_rx, batch_scheduler_handle) =
756-
BatchScheduler::run::<RunningMean, BasicSimulation>(
758+
let (metrics_rx, rl_scheduler_handle) =
759+
RLScheduler::run::<RunningMean, BasicSimulation>(
757760
network.clone(),
761+
gymnasium_worker_handle.clone(),
758762
trainer_worker_handle.clone(),
759763
parameter_handle.clone(),
760764
job_id,
@@ -763,7 +767,7 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
763767
rl_config.rounds.avg_samples_between_updates,
764768
rl_config.rounds.update_rounds,
765769
rl_config.model_destination.clone(),
766-
batch_sizer.clone(),
770+
batch_sizer,
767771
rl_config.rounds.multi_batch_size,
768772
token.clone(),
769773
)
@@ -780,10 +784,10 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
780784
});
781785

782786
let abort_future = Box::pin(async move {
783-
if !batch_scheduler_handle.is_finished() {
784-
batch_scheduler_handle.abort();
787+
if !rl_scheduler_handle.is_finished() {
788+
rl_scheduler_handle.abort();
785789
}
786-
let _ = batch_scheduler_handle.await;
790+
let _ = rl_scheduler_handle.await;
787791

788792
if !parameter_dispatcher.is_finished() {
789793
parameter_dispatcher.abort();

crates/scheduler/src/scheduling/batch_scheduler.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ where
131131
T: RuntimeStatistic + 'static,
132132
S: Simulation + Send + Sync + 'static,
133133
{
134-
let _ = std::marker::PhantomData::<S>;
135134
let (peer_id, action::ActionRequest { job_id, status }) = request;
136135
tracing::debug!(%peer_id, ?status, %job_id, "Received action request");
137136

@@ -853,6 +852,11 @@ where
853852

854853
AggregateStatus::Terminated => ExecutorAction::Aggregate(AggregateAction::Terminate),
855854
},
855+
_ => {
856+
return Err(BatchSchedulerError::NetworkError(
857+
RequestResponseError::Other("unexpected request".to_string()),
858+
));
859+
}
856860
};
857861

858862
tracing::debug!(%peer_id, %job_id, response = ?next_action, "Sending action response");
@@ -886,7 +890,6 @@ impl BatchScheduler {
886890
T: RuntimeStatistic + 'static,
887891
S: Simulation + Send + Sync + 'static,
888892
{
889-
let _ = std::marker::PhantomData::<S>;
890893
let (tx, rx) = mpsc::channel(100);
891894
let start = std::time::Instant::now();
892895
let push_destination = Arc::new(push_destination);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod batch_scheduler;
22
pub mod data_scheduler;
3+
pub mod rl_scheduler;

0 commit comments

Comments
 (0)