Skip to content

Commit 1ef53c3

Browse files
committed
Merge branch 'llama3' into 'main'
Use TP-CP group for fp8 amax reduction. See merge request ADLR/megatron-lm!1747
2 parents 99f4c82 + 2e42909 commit 1ef53c3

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

megatron/core/parallel_state.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -887,14 +887,14 @@ def get_amax_reduction_group(with_context_parallel=False):
887887
"""Get the FP8 amax reduction group the caller rank belongs to."""
888888
if with_context_parallel:
889889
assert (
890-
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
890+
_TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
891891
), 'FP8 amax reduction group is not initialized'
892-
return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
892+
return _TENSOR_AND_CONTEXT_PARALLEL_GROUP
893893
else:
894894
assert (
895-
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
895+
_TENSOR_MODEL_PARALLEL_GROUP is not None
896896
), 'FP8 amax reduction group is not initialized'
897-
return _TENSOR_AND_DATA_PARALLEL_GROUP
897+
return _TENSOR_MODEL_PARALLEL_GROUP
898898

899899

900900
def get_tensor_and_data_parallel_group(with_context_parallel=False):

tests/unit_tests/test_parallel_state.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ def test_different_initialize_order_unconsistency(src_tp_pp, ep_size):
218218
assert dp_g != torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False))
219219
assert pp_g != torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group())
220220
assert cp_g == torch.distributed.get_process_group_ranks(ps.get_context_parallel_group())
221-
assert amax_g != torch.distributed.get_process_group_ranks(ps.get_amax_reduction_group(False))
221+
assert amax_g == torch.distributed.get_process_group_ranks(ps.get_amax_reduction_group(False))
222222
assert mp_g != torch.distributed.get_process_group_ranks(ps.get_model_parallel_group())
223223

224224
Utils.destroy_model_parallel()

0 commit comments

Comments
 (0)