11import os
22import socket
33import time
4+ from argparse import Namespace
45from contextlib import nullcontext
56from pathlib import Path
7+ from typing import Dict , Optional , Tuple , Union
68
79import ray
810import torch
911import torch .distributed as dist
12+ from ray .actor import ActorHandle
1013
1114if torch .version .hip :
1215 from vllm .device_allocator .cumem import CuMemAllocator
2023from slime .utils .data import process_rollout_data
2124from slime .utils .distributed_utils import get_gloo_group , init_process_group
2225from slime .utils .memory_utils import clear_memory , print_memory
26+ from slime .utils .ray_utils import Box
2327from slime .utils .timer import Timer , timer
2428from slime .utils .wandb_utils import init_wandb_secondary
2529
2630from .checkpoint import load_checkpoint
2731from .cp_utils import slice_log_prob_with_cp
28- from .data import get_data_iterator , log_perf_data , log_rollout_data , sync_actor_critic_data
32+ from .data import DataIterator , get_data_iterator , log_perf_data , log_rollout_data , sync_actor_critic_data
2933from .initialize import init , is_megatron_main_rank
3034from .loss import compute_advantages_and_returns , get_log_probs_and_entropy , get_values
3135from .model import forward_only , initialize_model_and_optimizer , save , train
3236from .update_weight_utils import UpdateWeightFromDistributed , UpdateWeightFromTensor , named_parameters
3337
3438
3539class MegatronTrainRayActor (TrainRayActor ):
36- def init (self , args , role , wandb_run_id , with_ref = False ):
40+ def init (
41+ self ,
42+ args : Namespace ,
43+ role : str ,
44+ wandb_run_id : str ,
45+ with_ref : bool = False ,
46+ ) -> Optional [int ]:
3747 super ().init (args , role , wandb_run_id , with_ref )
3848
3949 init (args )
@@ -128,29 +138,29 @@ def init(self, args, role, wandb_run_id, with_ref=False):
128138 return start_rollout_id
129139
@torch.no_grad()
def update_cpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
    """Snapshot the model's named parameters into ``params_dict`` on the CPU.

    For every ``(name, param)`` pair produced by ``named_parameters`` a
    CPU tensor is kept under ``name`` in ``params_dict``. Missing entries
    are allocated on first use as pinned-memory tensors shaped like the
    parameter; existing entries are overwritten in place.

    Args:
        params_dict: Mapping from parameter name to its CPU copy. It is
            mutated in place.
    """
    cpu_device = torch.device("cpu")
    for key, weight in named_parameters(self.args, self.model):
        if key not in params_dict:
            # First time this parameter is seen: allocate its pinned host buffer.
            params_dict[key] = torch.empty_like(weight, device=cpu_device, pin_memory=True)
        host_copy = params_dict[key]
        host_copy.copy_(weight.detach(), non_blocking=True)
    # The copies above were issued with non_blocking=True, so wait for the
    # device before returning.
    torch.cuda.synchronize()
137147
@torch.no_grad()
def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
    """Overwrite the model's named parameters with tensors from ``params_dict``.

    Every name produced by ``named_parameters`` must already be present
    in ``params_dict``; a missing name trips an assertion.

    Args:
        params_dict: Mapping from parameter name to the tensor whose
            contents are copied into the matching model parameter.
    """
    for key, weight in named_parameters(self.args, self.model):
        assert key in params_dict
        stored = params_dict[key]
        weight.copy_(stored, non_blocking=True)
    # The copies above were issued with non_blocking=True, so wait for the
    # device before returning.
    torch.cuda.synchronize()
144154
145155 @timer
146- def sleep (self , tags ) :
156+ def sleep (self , tags : Union [ str , Tuple [ str , ...]]) -> None :
147157 assert self .args .offload
148158 assert "model" in tags
149159 if isinstance (tags , str ):
150160 tags = (tags ,)
151161
152162 clear_memory ()
153- print_memory (f "before offload model" )
163+ print_memory ("before offload model" )
154164 if hasattr (mpu , "destroy_process_groups" ):
155165 mpu .destroy_process_groups ()
156166
@@ -160,10 +170,10 @@ def sleep(self, tags):
160170 allocator = CuMemAllocator .get_instance ()
161171 allocator .sleep (offload_tags = tags )
162172
163- print_memory (f "after offload model" )
173+ print_memory ("after offload model" )
164174
165175 @timer
166- def wake_up (self , tags ) :
176+ def wake_up (self , tags : Union [ str , Tuple [ str , ...]]) -> None :
167177 assert self .args .offload
168178
169179 # there are weird times when sglang is not offloaded immediately, so we wait here.
@@ -189,7 +199,9 @@ def wake_up(self, tags):
189199 mpu .reload_process_groups ()
190200 print_memory ("after wake_up model" )
191201
192- def _get_rollout_data (self , rollout_data_ref ):
202+ def _get_rollout_data (
203+ self , rollout_data_ref : Box
204+ ) -> Dict [str , list [torch .Tensor ] | list [int ] | list [float ] | list [str ]]:
193205 # Fetch data through ray on CPU, not sure if this will be performance bottleneck.
194206 # Both first pp stage and the last pp stage will receive the data.
195207 rollout_data = process_rollout_data (
@@ -221,11 +233,11 @@ def _get_rollout_data(self, rollout_data_ref):
221233
222234 def compute_log_prob (
223235 self ,
224- model_tag ,
225- data_iterator ,
226- num_microbatches ,
227- store_prefix = "" ,
228- ):
236+ model_tag : str ,
237+ data_iterator : list [ DataIterator ] ,
238+ num_microbatches : list [ int ] ,
239+ store_prefix : str = "" ,
240+ ) -> Dict [ str , list [ torch . Tensor ]] :
229241 self .update_gpu_params_dict (self .weights [model_tag ])
230242
231243 with timer (f"{ store_prefix } log_probs" ):
@@ -238,7 +250,7 @@ def compute_log_prob(
238250 store_prefix = store_prefix ,
239251 )
240252
241- def train (self , rollout_id , rollout_data_ref ) :
253+ def train (self , rollout_id : int , rollout_data_ref : Box ) -> None :
242254 Timer ().end ("train_wait" )
243255
244256 if self .args .offload :
@@ -256,7 +268,9 @@ def train(self, rollout_id, rollout_data_ref):
256268 else :
257269 return self .train_actor (rollout_id , rollout_data )
258270
259- def train_critic (self , rollout_id , rollout_data ):
271+ def train_critic (
272+ self , rollout_id : int , rollout_data : Dict [str , list [torch .Tensor ] | list [int ] | list [float ] | list [str ]]
273+ ) -> None :
260274 # Create data iterator for log_probs and train.
261275 data_iterator , num_microbatches = get_data_iterator (self .args , self .model , rollout_data )
262276 rollout_data .update (
@@ -285,7 +299,9 @@ def train_critic(self, rollout_id, rollout_data):
285299 )
286300 Timer ().start ("train_wait" )
287301
288- def train_actor (self , rollout_id , rollout_data ):
302+ def train_actor (
303+ self , rollout_id : int , rollout_data : Dict [str , list [torch .Tensor ] | list [int ] | list [float ] | list [str ]]
304+ ) -> None :
289305 # Create data iterator for log_probs and train.
290306 data_iterator , num_microbatches = get_data_iterator (self .args , self .model , rollout_data )
291307
@@ -386,14 +402,14 @@ def train_actor(self, rollout_id, rollout_data):
386402 log_perf_data (rollout_id , self .args )
387403 Timer ().start ("train_wait" )
388404
def save_model(self, iteration: int) -> None:
    """Hand the model, optimizer and scheduler to ``save`` for ``iteration``.

    Nothing is saved when ``self.args.debug_rollout_only`` is truthy.

    Args:
        iteration: Iteration number forwarded to ``save``.
    """
    if not self.args.debug_rollout_only:
        save(iteration, self.model, self.optimizer, self.opt_param_scheduler)
394410
395411 @timer
396- def update_weights (self ):
412+ def update_weights (self ) -> None :
397413 if self .args .debug_train_only or self .args .debug_rollout_only :
398414 return
399415
@@ -424,7 +440,7 @@ def update_weights(self):
424440 if self .args .offload and hasattr (mpu , "destroy_process_groups" ):
425441 mpu .destroy_process_groups ()
426442
427- def load_other_checkpoint (self , model_tag , path ) :
443+ def load_other_checkpoint (self , model_tag : str , path : str ) -> None :
428444 old_args = self .args .load , self .args .no_load_optim , self .args .no_load_rng , self .args .finetune
429445 self .args .load = path
430446 self .args .no_load_optim = True
@@ -450,7 +466,12 @@ def load_other_checkpoint(self, model_tag, path):
450466 self .weights [model_tag ] = {}
451467 self .update_cpu_params_dict (self .weights [model_tag ])
452468
453- def connect_actor_critic (self , actor_handle = None , master_address = None , master_port = None ):
469+ def connect_actor_critic (
470+ self ,
471+ actor_handle : Optional [ActorHandle ] = None ,
472+ master_address : Optional [str ] = None ,
473+ master_port : Optional [int ] = None ,
474+ ) -> None :
454475 if self .role == "actor" :
455476 master_address = ray .util .get_node_ip_address ()
456477 with socket .socket () as sock :
0 commit comments