@@ -631,7 +631,7 @@ async def trainer_mode(self):
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="update_actor", logger=logger)
-    @DistProfiler.annotate(color="red")
+    @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
         assert self._is_actor
         if self._is_offload_param:
@@ -674,7 +674,7 @@ def update_actor(self, data: DataProto):
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
     @GPUMemoryLogger(role="generate_sequences", logger=logger)
-    @DistProfiler.annotate(color="red")
+    @DistProfiler.annotate(color="red", role="rollout_generate")
     def generate_sequences(self, prompts: DataProto):
         assert self._is_rollout
         prompts = prompts.to(get_device_name())
@@ -724,7 +724,7 @@ def generate_sequences(self, prompts: DataProto):
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
-    @DistProfiler.annotate(color="olive")
+    @DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref
         if self._ref_is_offload_param:
@@ -746,7 +746,7 @@ def compute_ref_log_prob(self, data: DataProto):
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="compute_log_prob", logger=logger)
-    @DistProfiler.annotate(color="blue")
+    @DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
     def compute_log_prob(self, data: DataProto):
         assert self._is_actor
         if self._is_offload_param:
@@ -1079,7 +1079,7 @@ def init_model(self):
         )
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
-    @DistProfiler.annotate(color="cyan")
+    @DistProfiler.annotate(color="cyan", role="compute_values")
     def compute_values(self, data: DataProto):
         micro_batch_size = self.config.ppo_micro_batch_size_per_gpu
         data.meta_info["micro_batch_size"] = micro_batch_size
@@ -1096,7 +1096,7 @@ def compute_values(self, data: DataProto):
         return output
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
-    @DistProfiler.annotate(color="pink")
+    @DistProfiler.annotate(color="pink", role="update_critic")
     def update_critic(self, data: DataProto):
         data = data.to(get_device_id())
 
@@ -1313,7 +1313,7 @@ def init_model(self):
     # TODO: reward model use itself tokenizer instead of sft tokenizer
     # the input_ids, responses, attention_mask and position_ids may be different!
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward"))
-    @DistProfiler.annotate(color="brown")
+    @DistProfiler.annotate(color="brown", role="compute_rm_score")
     def compute_rm_score(self, data: DataProto):
         data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu
         data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
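The functional change in every hunk above is the same: each `DistProfiler.annotate` decorator now passes an explicit `role` string alongside the existing `color`, so profiling ranges can be labeled by the worker role that executed them. As a minimal illustrative sketch only (a toy stand-in, not verl's actual `DistProfiler` implementation), a role-aware `annotate` decorator could look like this:

```python
# Illustrative sketch, NOT verl's DistProfiler: shows how an annotate()
# decorator could record a display color plus an optional role label and
# surround the wrapped method call with start/end markers.
import functools


class ToyProfiler:
    """Hypothetical stand-in for a distributed profiler."""

    @staticmethod
    def annotate(color: str, role: str | None = None):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # A real profiler would open an NVTX/trace range here;
                # we just tag the call with its color and role label.
                label = role or fn.__name__
                print(f"[profile] start {label} (color={color})")
                try:
                    return fn(*args, **kwargs)
                finally:
                    print(f"[profile] end {label}")

            return wrapper

        return decorator


class Worker:
    @ToyProfiler.annotate(color="red", role="actor_update")
    def update_actor(self, data):
        return {"loss": 0.0}


Worker().update_actor(data=None)  # prints start/end markers tagged "actor_update"
```

Passing `role` explicitly, rather than relying on the function name, keeps the profile labels stable even if methods are renamed or shared across worker classes.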