@@ -109,32 +109,30 @@ def create_placement_groups(args):
109109 }
110110
111111
112- def allocate_train_group (args , num_nodes , num_gpus_per_node , pg , wandb_run_id ):
112+ def allocate_train_group (args , num_nodes , num_gpus_per_node , pg ):
113113 return RayTrainGroup (
114114 args = args ,
115115 num_nodes = num_nodes ,
116116 num_gpus_per_node = num_gpus_per_node ,
117117 pg = pg ,
118- wandb_run_id = wandb_run_id ,
118+ wandb_run_id = args . wandb_run_id ,
119119 num_gpus_per_actor = 0.4 ,
120120 )
121121
122122
123- def create_training_models (args , pgs , rollout_manager , wandb_run_id ):
123+ def create_training_models (args , pgs , rollout_manager ):
124124 actor_model = allocate_train_group (
125125 args = args ,
126126 num_nodes = args .actor_num_nodes ,
127127 num_gpus_per_node = args .actor_num_gpus_per_node ,
128128 pg = pgs ["actor" ],
129- wandb_run_id = wandb_run_id ,
130129 )
131130 if args .use_critic :
132131 critic_model = allocate_train_group (
133132 args = args ,
134133 num_nodes = args .critic_num_nodes ,
135134 num_gpus_per_node = args .critic_num_gpus_per_node ,
136135 pg = pgs ["critic" ],
137- wandb_run_id = wandb_run_id ,
138136 )
139137 critic_init_handle = critic_model .async_init (args , role = "critic" , with_ref = False )
140138 else :
@@ -159,11 +157,11 @@ def create_training_models(args, pgs, rollout_manager, wandb_run_id):
159157 return actor_model , critic_model
160158
161159
162- def create_rollout_manager (args , pg , wandb_run_id ):
160+ def create_rollout_manager (args , pg ):
163161 rollout_manager = RolloutManager .options (
164162 num_cpus = 1 ,
165163 num_gpus = 0 ,
166- ).remote (args , pg , wandb_run_id = wandb_run_id )
164+ ).remote (args , pg , wandb_run_id = args . wandb_run_id )
167165
168166 # calculate num_rollout from num_epoch
169167 num_rollout_per_epoch = None
0 commit comments