@@ -146,6 +146,7 @@ pub mod action {
146146 pub enum ExecutorAction {
147147 Train ( TrainAction ) ,
148148 Aggregate ( AggregateAction ) ,
149+ Gymnasium ( GymnasiumAction ) ,
149150 }
150151
151152 /// Actions targeted at training workers.
@@ -194,6 +195,17 @@ pub mod action {
194195 BroadcastUpdate { target : Reference } ,
195196 Terminate ,
196197 }
198+
199+ /// Actions targeted at RL data generation executors.
200+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
201+ #[ serde( tag = "kind" , rename_all = "kebab-case" ) ]
202+ pub enum GymnasiumAction {
203+ Idle { timeout : SystemTime } ,
204+ Generate { source : Reference } ,
205+ Send { target : Reference } ,
206+ Update { } ,
207+ Terminate ,
208+ }
197209}
198210
199211// Protocol: Scheduler requests available workers
@@ -385,10 +397,18 @@ impl Fetch {
385397 } )
386398 }
387399
388- pub fn scheduler ( peer_id : PeerId , daset : String ) -> Self {
400+ pub fn data_peers ( peer_ids : Vec < PeerId > , resource : DataSlice ) -> Self {
401+ Self ( Reference :: Peers {
402+ peers : peer_ids,
403+ strategy : SelectionStrategy :: One ,
404+ resource : Some ( resource) ,
405+ } )
406+ }
407+
408+ pub fn scheduler ( peer_id : PeerId , dataset : String ) -> Self {
389409 Self ( Reference :: Scheduler {
390410 peer : peer_id,
391- dataset : daset ,
411+ dataset,
392412 } )
393413 }
394414}
@@ -586,6 +606,21 @@ pub struct AggregateExecutorConfig {
586606 pub optimizer : Nesterov ,
587607}
588608
609+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
610+ pub struct GymnasiumExecutorConfig {
611+ // TODO: Add support for additional optimizeres when needed.
612+ pub model : Model ,
613+ pub environment : String ,
614+ }
615+
616+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
617+ pub struct RlTrainerExecutorConfig {
618+ // TODO: Add support for additional optimizeres when needed.
619+ pub model : Model ,
620+ pub data : Fetch ,
621+ pub batch_size : u32 ,
622+ }
623+
589624#[ derive( Clone , Debug , Serialize , Deserialize , PartialEq , Eq , Hash ) ]
590625pub struct TrainExecutorDescriptor {
591626 name : String ,
@@ -642,11 +677,69 @@ impl From<AggregateExecutorDescriptor> for ExecutorDescriptor {
642677 }
643678}
644679
680+ #[ derive( Clone , Debug , Serialize , Deserialize , PartialEq , Eq , Hash ) ]
681+ pub struct GymnasiumExecutorDescriptor {
682+ name : String ,
683+ }
684+
685+ impl GymnasiumExecutorDescriptor {
686+ pub fn new ( name : impl Into < String > ) -> Self {
687+ Self { name : name. into ( ) }
688+ }
689+
690+ pub fn name ( & self ) -> & str {
691+ & self . name
692+ }
693+
694+ pub fn into_executor ( self , config : GymnasiumExecutorConfig ) -> GymnasiumExecutor {
695+ GymnasiumExecutor {
696+ descriptor : self ,
697+ config,
698+ }
699+ }
700+ }
701+
702+ impl From < GymnasiumExecutorDescriptor > for ExecutorDescriptor {
703+ fn from ( descriptor : GymnasiumExecutorDescriptor ) -> Self {
704+ Self :: Gymnasium ( descriptor. clone ( ) )
705+ }
706+ }
707+
708+ #[ derive( Clone , Debug , Serialize , Deserialize , PartialEq , Eq , Hash ) ]
709+ pub struct RlTrainerExecutorDescriptor {
710+ name : String ,
711+ }
712+
713+ impl RlTrainerExecutorDescriptor {
714+ pub fn new ( name : impl Into < String > ) -> Self {
715+ Self { name : name. into ( ) }
716+ }
717+
718+ pub fn name ( & self ) -> & str {
719+ & self . name
720+ }
721+
722+ pub fn into_executor ( self , config : RlTrainerExecutorConfig ) -> RlTrainerExecutor {
723+ RlTrainerExecutor {
724+ descriptor : self ,
725+ config,
726+ }
727+ }
728+ }
729+
730+ impl From < RlTrainerExecutorDescriptor > for ExecutorDescriptor {
731+ fn from ( descriptor : RlTrainerExecutorDescriptor ) -> Self {
732+ Self :: RlTrainer ( descriptor. clone ( ) )
733+ }
734+ }
735+
645736#[ derive( Clone , Debug , Serialize , Deserialize , PartialEq , Eq , Hash ) ]
646737#[ serde( tag = "class" , rename_all = "kebab-case" ) ]
647738pub enum ExecutorDescriptor {
648739 Train ( TrainExecutorDescriptor ) ,
649740 Aggregate ( AggregateExecutorDescriptor ) ,
741+ Gymnasium ( GymnasiumExecutorDescriptor ) ,
742+ RlTrainer ( RlTrainerExecutorDescriptor ) ,
650743}
651744
652745#[ derive( Clone , Debug , Serialize , Deserialize ) ]
@@ -693,12 +786,58 @@ impl From<AggregateExecutor> for Executor {
693786 }
694787}
695788
789+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
790+ pub struct GymnasiumExecutor {
791+ descriptor : GymnasiumExecutorDescriptor ,
792+ config : GymnasiumExecutorConfig ,
793+ }
794+
795+ impl GymnasiumExecutor {
796+ pub fn descriptor ( & self ) -> & GymnasiumExecutorDescriptor {
797+ & self . descriptor
798+ }
799+
800+ pub fn config ( & self ) -> & GymnasiumExecutorConfig {
801+ & self . config
802+ }
803+ }
804+
805+ impl From < GymnasiumExecutor > for Executor {
806+ fn from ( executor : GymnasiumExecutor ) -> Self {
807+ Self :: Gymnasium ( executor)
808+ }
809+ }
810+
811+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
812+ pub struct RlTrainerExecutor {
813+ descriptor : RlTrainerExecutorDescriptor ,
814+ config : RlTrainerExecutorConfig ,
815+ }
816+
817+ impl RlTrainerExecutor {
818+ pub fn descriptor ( & self ) -> & RlTrainerExecutorDescriptor {
819+ & self . descriptor
820+ }
821+
822+ pub fn config ( & self ) -> & RlTrainerExecutorConfig {
823+ & self . config
824+ }
825+ }
826+
827+ impl From < RlTrainerExecutor > for Executor {
828+ fn from ( executor : RlTrainerExecutor ) -> Self {
829+ Self :: RlTrainer ( executor)
830+ }
831+ }
832+
696833#[ allow( clippy:: large_enum_variant) ]
697834#[ derive( Clone , Debug , Serialize , Deserialize ) ]
698835#[ serde( tag = "class" , rename_all = "kebab-case" ) ]
699836pub enum Executor {
700837 Train ( TrainExecutor ) ,
701838 Aggregate ( AggregateExecutor ) ,
839+ Gymnasium ( GymnasiumExecutor ) ,
840+ RlTrainer ( RlTrainerExecutor ) ,
702841}
703842
704843// NOTE: This is not only to convert an `Executor` into an `ExecutorDescriptor` enum but also
@@ -710,6 +849,12 @@ impl From<&Executor> for ExecutorDescriptor {
710849 Executor :: Aggregate ( AggregateExecutor { descriptor, .. } ) => {
711850 Self :: Aggregate ( descriptor. clone ( ) )
712851 }
852+ Executor :: Gymnasium ( GymnasiumExecutor { descriptor, .. } ) => {
853+ Self :: Gymnasium ( descriptor. clone ( ) )
854+ }
855+ Executor :: RlTrainer ( RlTrainerExecutor { descriptor, .. } ) => {
856+ Self :: RlTrainer ( descriptor. clone ( ) )
857+ }
713858 }
714859 }
715860}
0 commit comments