@@ -821,7 +821,9 @@ def get_data_parallel_src_rank():
821821
822822def get_pipeline_model_parallel_first_rank ():
823823 """Return the global rank of the first process in the pipeline for the
824- current tensor parallel group"""
824+ current pipeline model parallel group
825+ NOTE (SpiralPipe) Returns `pp rank` of the first `cm rank` process
826+ """
825827 if _SPIRAL_CROSS_MAPPING :
826828 assert _SPIRAL_CROSS_MAPPING_LIST is not None
827829 return _SPIRAL_CROSS_MAPPING_LIST [0 ]
@@ -833,7 +835,9 @@ def get_pipeline_model_parallel_first_rank():
833835
834836def get_pipeline_model_parallel_last_rank ():
835837 """Return the global rank of the last process in the pipeline for the
836- current tensor parallel group"""
838+ current tensor parallel group
839+ NOTE (SpiralPipe) Returns `pp rank` of the last `cm rank` process
840+ """
837841 if _SPIRAL_CROSS_MAPPING :
838842 assert _SPIRAL_CROSS_MAPPING_LIST is not None
839843 return _SPIRAL_CROSS_MAPPING_LIST [- 1 ]
@@ -844,7 +848,9 @@ def get_pipeline_model_parallel_last_rank():
844848 return _PIPELINE_GLOBAL_RANKS [last_rank_local ]
845849
846850def get_pipeline_model_parallel_next_rank ():
847- """Return the global rank that follows the caller in the pipeline"""
851+ """Return the global rank that follows the caller in the pipeline
852+ NOTE (SpiralPipe) Returns `pp rank` of the next `cm rank` process
853+ """
848854 rank_in_pipeline = get_pipeline_model_parallel_rank ()
849855 world_size = get_pipeline_model_parallel_world_size ()
850856 if _SPIRAL_CROSS_MAPPING :
@@ -857,7 +863,9 @@ def get_pipeline_model_parallel_next_rank():
857863
858864
859865def get_pipeline_model_parallel_prev_rank ():
860- """Return the global rank that preceeds the caller in the pipeline"""
866+ """Return the global rank that preceeds the caller in the pipeline
867+ NOTE (SpiralPipe) Returns `pp rank` of the previous `cm rank` process
868+ """
861869 rank_in_pipeline = get_pipeline_model_parallel_rank ()
862870 world_size = get_pipeline_model_parallel_world_size ()
863871 if _SPIRAL_CROSS_MAPPING :
0 commit comments