Skip to content

Commit 3e42631

Browse files
authored
Print lm loss in the last cm rank process (#21)
* Remove wrong logging in training.py
* Set lm loss logging process
* Update gitmodules commit index
1 parent b87d831 commit 3e42631

File tree

5 files changed

+29
-11
lines changed

5 files changed

+29
-11
lines changed

csrc/external/DeepSpeed

csrc/external/spdlog

Submodule spdlog updated 249 files

external/Megatron-DeepSpeed

megatron/core/parallel_state.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,9 @@ def get_data_parallel_src_rank():
821821

822822
def get_pipeline_model_parallel_first_rank():
823823
"""Return the global rank of the first process in the pipeline for the
824-
current tensor parallel group"""
824+
current pipeline model parallel group
825+
NOTE (SpiralPipe) Returns `pp rank` of the first `cm rank` process
826+
"""
825827
if _SPIRAL_CROSS_MAPPING:
826828
assert _SPIRAL_CROSS_MAPPING_LIST is not None
827829
return _SPIRAL_CROSS_MAPPING_LIST[0]
@@ -833,7 +835,9 @@ def get_pipeline_model_parallel_first_rank():
833835

834836
def get_pipeline_model_parallel_last_rank():
835837
"""Return the global rank of the last process in the pipeline for the
836-
current tensor parallel group"""
838+
current pipeline model parallel group
839+
NOTE (SpiralPipe) Returns `pp rank` of the last `cm rank` process
840+
"""
837841
if _SPIRAL_CROSS_MAPPING:
838842
assert _SPIRAL_CROSS_MAPPING_LIST is not None
839843
return _SPIRAL_CROSS_MAPPING_LIST[-1]
@@ -844,7 +848,9 @@ def get_pipeline_model_parallel_last_rank():
844848
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
845849

846850
def get_pipeline_model_parallel_next_rank():
847-
"""Return the global rank that follows the caller in the pipeline"""
851+
"""Return the global rank that follows the caller in the pipeline
852+
NOTE (SpiralPipe) Returns `pp rank` of the next `cm rank` process
853+
"""
848854
rank_in_pipeline = get_pipeline_model_parallel_rank()
849855
world_size = get_pipeline_model_parallel_world_size()
850856
if _SPIRAL_CROSS_MAPPING:
@@ -857,7 +863,9 @@ def get_pipeline_model_parallel_next_rank():
857863

858864

859865
def get_pipeline_model_parallel_prev_rank():
860-
"""Return the global rank that precedes the caller in the pipeline"""
866+
"""Return the global rank that precedes the caller in the pipeline
867+
NOTE (SpiralPipe) Returns `pp rank` of the previous `cm rank` process
868+
"""
861869
rank_in_pipeline = get_pipeline_model_parallel_rank()
862870
world_size = get_pipeline_model_parallel_world_size()
863871
if _SPIRAL_CROSS_MAPPING:

megatron/training.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,9 +1079,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
10791079
iteration,
10801080
)
10811081

1082-
if iteration == 1:
1083-
timers("interval-time").elapsed(barrier=True)
1084-
10851082
if iteration % args.log_interval == 0:
10861083
elapsed_time = timers("interval-time").elapsed(barrier=True)
10871084

@@ -1160,7 +1157,20 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
11601157
total_loss_dict[advanced_iters_key] = 0
11611158
total_loss_dict[skipped_iters_key] = 0
11621159
total_loss_dict[nan_iters_key] = 0
1163-
print_rank_last(log_string)
1160+
# TODO (SpiralPipe) Only the pp rank that applies the `loss_func` saves "lm loss" in the `total_loss_dict`
1161+
# So, we can log loss in two ranks: the rank with last backward stage & the rank with last forward stage
1162+
# As the rank with the last forward stage does not actually need to compute the loss (in a future optimization),
1163+
# we should later print the log string in the rank with last backward stage. Note that this can handle the
1164+
# optimization case where the forward pass ends at the middle of the pipeline ranks (i.e., the last forward
1165+
# stage is not mapped to the last pipeline rank) and does not compute the loss, while the last pipeline stage
1166+
# computes the loss after recomputation.
1167+
if args.spiral:
1168+
# NOTE (SpiralPipe) Currently, the last pipeline rank computes the loss.
1169+
# NOTE (SpiralPipe) We must consider the effect of cross-mapping also.
1170+
if mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1:
1171+
print(log_string, flush=True)
1172+
else:
1173+
print_rank_last(log_string)
11641174
if report_memory_flag and learning_rate > 0.:
11651175
# Report memory after optimizer state has been initialized.
11661176
report_memory('(after {} iterations)'.format(iteration))

0 commit comments

Comments
 (0)