Skip to content

Commit e0cdb00

Browse files
committed
WIP
1 parent b6dbc7a commit e0cdb00

File tree

13 files changed

+2451
-396
lines changed

13 files changed

+2451
-396
lines changed

Cargo.lock

Lines changed: 144 additions & 135 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/messages/src/lib.rs

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ pub mod action {
146146
pub enum ExecutorAction {
147147
Train(TrainAction),
148148
Aggregate(AggregateAction),
149+
Gymnasium(GymnasiumAction),
149150
}
150151

151152
/// Actions targeted at training workers.
@@ -194,6 +195,17 @@ pub mod action {
194195
BroadcastUpdate { target: Reference },
195196
Terminate,
196197
}
198+
199+
/// Actions targeted at RL data generation executors.
///
/// Serialized with an internal `kind` tag; variant names are written in kebab-case.
200+
#[derive(Clone, Debug, Serialize, Deserialize)]
201+
#[serde(tag = "kind", rename_all = "kebab-case")]
202+
pub enum GymnasiumAction {
203+
/// Carries a `timeout` expressed as a `SystemTime`.
Idle { timeout: SystemTime },
204+
/// Carries a `source` reference.
// NOTE(review): presumably where the executor reads its inputs from — confirm.
Generate { source: Reference },
205+
/// Carries a `target` reference.
// NOTE(review): presumably where the generated data is delivered — confirm.
Send { target: Reference },
206+
// NOTE(review): struct-style variant with no fields; presumably a placeholder for
// fields that are still to be added — confirm.
Update {},
207+
Terminate,
208+
}
197209
}
198210

199211
// Protocol: Scheduler requests available workers
@@ -385,10 +397,18 @@ impl Fetch {
385397
})
386398
}
387399

388-
pub fn scheduler(peer_id: PeerId, daset: String) -> Self {
400+
/// Builds a `Fetch` wrapping a `Reference::Peers` over the given `peer_ids`.
///
/// The selection strategy is fixed to `SelectionStrategy::One`, and `resource` is always
/// passed on as `Some(resource)`.
pub fn data_peers(peer_ids: Vec<PeerId>, resource: DataSlice) -> Self {
401+
Self(Reference::Peers {
402+
peers: peer_ids,
403+
strategy: SelectionStrategy::One,
404+
resource: Some(resource),
405+
})
406+
}
407+
408+
pub fn scheduler(peer_id: PeerId, dataset: String) -> Self {
389409
Self(Reference::Scheduler {
390410
peer: peer_id,
391-
dataset: daset,
411+
dataset,
392412
})
393413
}
394414
}
@@ -586,6 +606,21 @@ pub struct AggregateExecutorConfig {
586606
pub optimizer: Nesterov,
587607
}
588608

609+
/// Configuration paired with a `GymnasiumExecutorDescriptor` to form a `GymnasiumExecutor`:
/// a `Model` plus an `environment` string.
#[derive(Clone, Debug, Serialize, Deserialize)]
610+
pub struct GymnasiumExecutorConfig {
611+
// TODO: Add support for additional optimizers when needed.
// NOTE(review): this struct has no optimizer field — confirm the TODO above still applies here.
612+
pub model: Model,
613+
pub environment: String,
614+
}
615+
616+
/// Configuration paired with an `RlTrainerExecutorDescriptor` to form an `RlTrainerExecutor`:
/// a `Model`, a `Fetch` describing the data, and a `batch_size`.
#[derive(Clone, Debug, Serialize, Deserialize)]
617+
pub struct RlTrainerExecutorConfig {
618+
// TODO: Add support for additional optimizers when needed.
// NOTE(review): this struct has no optimizer field — confirm the TODO above still applies here.
619+
pub model: Model,
620+
pub data: Fetch,
621+
pub batch_size: u32,
622+
}
623+
589624
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
590625
pub struct TrainExecutorDescriptor {
591626
name: String,
@@ -642,11 +677,69 @@ impl From<AggregateExecutorDescriptor> for ExecutorDescriptor {
642677
}
643678
}
644679

680+
/// Descriptor of a gymnasium executor; it holds nothing but the executor's name.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
681+
pub struct GymnasiumExecutorDescriptor {
682+
name: String,
683+
}
684+
685+
impl GymnasiumExecutorDescriptor {
686+
/// Creates a descriptor from anything convertible into a `String`.
pub fn new(name: impl Into<String>) -> Self {
687+
Self { name: name.into() }
688+
}
689+
690+
/// Returns the name as a borrowed string slice.
pub fn name(&self) -> &str {
691+
&self.name
692+
}
693+
694+
/// Consumes the descriptor and combines it with `config` into a `GymnasiumExecutor`.
pub fn into_executor(self, config: GymnasiumExecutorConfig) -> GymnasiumExecutor {
695+
GymnasiumExecutor {
696+
descriptor: self,
697+
config,
698+
}
699+
}
700+
}
701+
702+
/// Wraps a `GymnasiumExecutorDescriptor` in the `Gymnasium` variant of `ExecutorDescriptor`.
impl From<GymnasiumExecutorDescriptor> for ExecutorDescriptor {
703+
fn from(descriptor: GymnasiumExecutorDescriptor) -> Self {
704+
// `descriptor` is received by value, so it is moved into the variant; cloning it
// first would only allocate a copy of the name and then drop the original.
Self::Gymnasium(descriptor)
705+
}
706+
}
707+
708+
/// Descriptor of an RL trainer executor; it holds nothing but the executor's name.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
709+
pub struct RlTrainerExecutorDescriptor {
710+
name: String,
711+
}
712+
713+
impl RlTrainerExecutorDescriptor {
714+
/// Creates a descriptor from anything convertible into a `String`.
pub fn new(name: impl Into<String>) -> Self {
715+
Self { name: name.into() }
716+
}
717+
718+
/// Returns the name as a borrowed string slice.
pub fn name(&self) -> &str {
719+
&self.name
720+
}
721+
722+
/// Consumes the descriptor and combines it with `config` into an `RlTrainerExecutor`.
pub fn into_executor(self, config: RlTrainerExecutorConfig) -> RlTrainerExecutor {
723+
RlTrainerExecutor {
724+
descriptor: self,
725+
config,
726+
}
727+
}
728+
}
729+
730+
/// Wraps an `RlTrainerExecutorDescriptor` in the `RlTrainer` variant of `ExecutorDescriptor`.
impl From<RlTrainerExecutorDescriptor> for ExecutorDescriptor {
731+
fn from(descriptor: RlTrainerExecutorDescriptor) -> Self {
732+
// `descriptor` is received by value, so it is moved into the variant; cloning it
// first would only allocate a copy of the name and then drop the original.
Self::RlTrainer(descriptor)
733+
}
734+
}
735+
645736
/// Descriptor of any supported executor class, one variant per concrete descriptor type.
///
/// Serialized with an internal `class` tag; variant names are written in kebab-case
/// (so `RlTrainer` becomes `rl-trainer`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
646737
#[serde(tag = "class", rename_all = "kebab-case")]
647738
pub enum ExecutorDescriptor {
648739
Train(TrainExecutorDescriptor),
649740
Aggregate(AggregateExecutorDescriptor),
741+
Gymnasium(GymnasiumExecutorDescriptor),
742+
RlTrainer(RlTrainerExecutorDescriptor),
650743
}
651744

652745
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -693,12 +786,58 @@ impl From<AggregateExecutor> for Executor {
693786
}
694787
}
695788

789+
/// A gymnasium executor: its descriptor together with its configuration.
///
/// Both fields are private and exposed read-only through the accessors below.
#[derive(Clone, Debug, Serialize, Deserialize)]
790+
pub struct GymnasiumExecutor {
791+
descriptor: GymnasiumExecutorDescriptor,
792+
config: GymnasiumExecutorConfig,
793+
}
794+
795+
impl GymnasiumExecutor {
796+
/// Returns a reference to the executor's descriptor.
pub fn descriptor(&self) -> &GymnasiumExecutorDescriptor {
797+
&self.descriptor
798+
}
799+
800+
/// Returns a reference to the executor's configuration.
pub fn config(&self) -> &GymnasiumExecutorConfig {
801+
&self.config
802+
}
803+
}
804+
805+
/// Wraps a `GymnasiumExecutor` in the `Gymnasium` variant of `Executor`.
impl From<GymnasiumExecutor> for Executor {
806+
fn from(executor: GymnasiumExecutor) -> Self {
807+
Self::Gymnasium(executor)
808+
}
809+
}
810+
811+
/// An RL trainer executor: its descriptor together with its configuration.
///
/// Both fields are private and exposed read-only through the accessors below.
#[derive(Clone, Debug, Serialize, Deserialize)]
812+
pub struct RlTrainerExecutor {
813+
descriptor: RlTrainerExecutorDescriptor,
814+
config: RlTrainerExecutorConfig,
815+
}
816+
817+
impl RlTrainerExecutor {
818+
/// Returns a reference to the executor's descriptor.
pub fn descriptor(&self) -> &RlTrainerExecutorDescriptor {
819+
&self.descriptor
820+
}
821+
822+
/// Returns a reference to the executor's configuration.
pub fn config(&self) -> &RlTrainerExecutorConfig {
823+
&self.config
824+
}
825+
}
826+
827+
/// Wraps an `RlTrainerExecutor` in the `RlTrainer` variant of `Executor`.
impl From<RlTrainerExecutor> for Executor {
828+
fn from(executor: RlTrainerExecutor) -> Self {
829+
Self::RlTrainer(executor)
830+
}
831+
}
832+
696833
/// Any supported executor (descriptor plus configuration), one variant per executor class.
///
/// Serialized with an internal `class` tag; variant names are written in kebab-case
/// (so `RlTrainer` becomes `rl-trainer`).
// NOTE(review): the lint below is silenced without a stated reason; presumably the variants
// differ enough in size to trigger it — confirm that boxing the large variant is not preferable.
#[allow(clippy::large_enum_variant)]
697834
#[derive(Clone, Debug, Serialize, Deserialize)]
698835
#[serde(tag = "class", rename_all = "kebab-case")]
699836
pub enum Executor {
700837
Train(TrainExecutor),
701838
Aggregate(AggregateExecutor),
839+
Gymnasium(GymnasiumExecutor),
840+
RlTrainer(RlTrainerExecutor),
702841
}
703842

704843
// NOTE: This is not only to convert an `Executor` into an `ExecutorDescriptor` enum but also
@@ -710,6 +849,12 @@ impl From<&Executor> for ExecutorDescriptor {
710849
Executor::Aggregate(AggregateExecutor { descriptor, .. }) => {
711850
Self::Aggregate(descriptor.clone())
712851
}
852+
Executor::Gymnasium(GymnasiumExecutor { descriptor, .. }) => {
853+
Self::Gymnasium(descriptor.clone())
854+
}
855+
Executor::RlTrainer(RlTrainerExecutor { descriptor, .. }) => {
856+
Self::RlTrainer(descriptor.clone())
857+
}
713858
}
714859
}
715860
}

0 commit comments

Comments
 (0)