@@ -724,7 +724,7 @@ async def trainer_mode(self):
724724
725725 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "actor" ))
726726 @GPUMemoryLogger (role = "update_actor" , logger = logger )
727- @DistProfiler .annotate (color = "red" )
727+ @DistProfiler .annotate (color = "red" , role = "actor_update" )
728728 def update_actor (self , data : DataProto ):
729729 assert self ._is_actor
730730 if self ._is_offload_param :
@@ -767,7 +767,7 @@ def update_actor(self, data: DataProto):
767767
768768 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "rollout" ))
769769 @GPUMemoryLogger (role = "generate_sequences" , logger = logger )
770- @DistProfiler .annotate (color = "red" )
770+ @DistProfiler .annotate (color = "red" , role = "rollout_generate" )
771771 def generate_sequences (self , prompts : DataProto ):
772772 assert self ._is_rollout
773773 prompts = prompts .to (get_device_name ())
@@ -817,7 +817,7 @@ def generate_sequences(self, prompts: DataProto):
817817
818818 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "actor" ))
819819 @GPUMemoryLogger (role = "compute_ref_log_prob" , logger = logger )
820- @DistProfiler .annotate (color = "olive" )
820+ @DistProfiler .annotate (color = "olive" , role = "ref_compute_log_prob" )
821821 def compute_ref_log_prob (self , data : DataProto ):
822822 assert self ._is_ref
823823 if self ._ref_is_offload_param :
@@ -839,7 +839,7 @@ def compute_ref_log_prob(self, data: DataProto):
839839
840840 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "actor" ))
841841 @GPUMemoryLogger (role = "compute_log_prob" , logger = logger )
842- @DistProfiler .annotate (color = "blue" )
842+ @DistProfiler .annotate (color = "blue" , role = "actor_compute_log_prob" )
843843 def compute_log_prob (self , data : DataProto ):
844844 assert self ._is_actor
845845 if self ._is_offload_param :
@@ -1207,7 +1207,7 @@ def init_model(self):
12071207 )
12081208
12091209 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "critic" ))
1210- @DistProfiler .annotate (color = "cyan" )
1210+ @DistProfiler .annotate (color = "cyan" , role = "compute_values" )
12111211 def compute_values (self , data : DataProto ):
12121212 micro_batch_size = self .config .ppo_micro_batch_size_per_gpu
12131213 data .meta_info ["micro_batch_size" ] = micro_batch_size
@@ -1224,7 +1224,7 @@ def compute_values(self, data: DataProto):
12241224 return output
12251225
12261226 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "critic" ))
1227- @DistProfiler .annotate (color = "pink" )
1227+ @DistProfiler .annotate (color = "pink" , role = "update_critic" )
12281228 def update_critic (self , data : DataProto ):
12291229 data = data .to (get_device_id ())
12301230
@@ -1448,7 +1448,7 @@ def init_model(self):
14481448 # TODO: reward model use itself tokenizer instead of sft tokenizer
14491449 # the input_ids, responses, attention_mask and position_ids may be different!
14501450 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "reward" ))
1451- @DistProfiler .annotate (color = "brown" )
1451+ @DistProfiler .annotate (color = "brown" , role = "compute_rm_score" )
14521452 def compute_rm_score (self , data : DataProto ):
14531453 data .meta_info ["micro_batch_size" ] = self .config .micro_batch_size_per_gpu
14541454 data .meta_info ["max_token_len" ] = self .config .forward_max_token_len_per_gpu
0 commit comments