From d90c1da3d7890d6784c30be8d8eca84e5754685a Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Thu, 16 Jan 2025 15:16:22 +0800 Subject: [PATCH 01/32] enable fsdp training and support huggingface models with ckpt in or out --- configs/1.8B_MoE16_sft.py | 3 +- configs/57B_qwen2_MoE.py | 3 +- configs/7B_MoE4_sft.py | 3 +- configs/7B_baichuan2.py | 1 - configs/7B_gemma.py | 1 - configs/7B_internlm2.py | 1 - configs/7B_isp_sft.py | 1 - configs/7B_llama2.py | 1 - configs/7B_qwen2.py | 1 - configs/7B_sft.py | 1 - configs/8x22B_mixtral.py | 3 +- configs/8x7B_mixtral.py | 3 +- configs/_base_/models/internlm2_1B.py | 1 - configs/_base_/models/internlm2_20B.py | 1 - configs/_base_/models/internlm2_7B.py | 1 - configs/_base_/models/internlm_20B.py | 1 - configs/_base_/models/internlm_7B.py | 1 - doc/code-docs/source/initialize.rst | 2 +- doc/code-docs/source/training.rst | 2 +- doc/en/train_performance.md | 2 +- doc/train_performance.md | 2 +- doc/usage.md | 2 - generate.py | 5 +- internlm/checkpoint/checkpoint_manager.py | 5 +- internlm/checkpoint/components.py | 194 ++++++++----- internlm/checkpoint/utils.py | 100 +++++-- internlm/core/context/__init__.py | 2 - internlm/core/context/parallel_context.py | 16 +- .../core/context/process_group_initializer.py | 70 ----- internlm/core/parallel/shard.py | 3 +- internlm/core/trainer_builder.py | 12 +- internlm/data/utils.py | 5 +- internlm/initialize/launch.py | 52 ++-- internlm/model/builder.py | 85 +++++- internlm/model/ops/fused_rmsnorm.py | 247 ++++++++++++++++ internlm/solver/activation_checkpoint.py | 13 + internlm/solver/optimizer/fsdp_optimizer.py | 83 ++++-- internlm/train/__init__.py | 4 +- internlm/train/pipeline.py | 237 +++++++++++++--- internlm/utils/gputest.py | 2 - internlm/utils/lazy.py | 265 ++++++++++++++++++ internlm/utils/parallel.py | 12 + internlm/utils/timeout.py | 2 +- tests/test_core/test_pipeline.py | 2 +- tests/test_data/test_batch_sampler.py | 2 +- tests/test_infer/test_generate.py | 7 +- tests/test_infer/test_trainer_generate.py | 6 +- tests/test_model/test_model_internlm.py | 2 +- .../test_forward_output_no_fa.py | 6 +- tests/test_training/test_load_ckpt_loss.py | 8 +- tests/test_training/test_loss.py | 10 +- tests/test_training/test_no_fa_train_temp.py | 10 +- tests/test_training/test_norm_weight.py | 10 +- .../test_swap_nb_loss_and_gradnorm.py | 6 +- tests/test_training/train_CI.py | 6 +- tests/test_utils/common_fixture.py | 4 +- tools/load_internlm2_model.py | 7 +- train.py | 2 +- 58 files changed, 1170 insertions(+), 369 deletions(-) create mode 100644 internlm/model/ops/fused_rmsnorm.py create mode 100644 internlm/utils/lazy.py diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py index f85302778..eca10b045 100644 --- a/configs/1.8B_MoE16_sft.py +++ b/configs/1.8B_MoE16_sft.py @@ -170,7 +170,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -197,7 +196,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. 
""" parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py index abfb0a5b8..27f63cc1d 100644 --- a/configs/57B_qwen2_MoE.py +++ b/configs/57B_qwen2_MoE.py @@ -175,7 +175,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -202,7 +201,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 8d8acc406..74ebbcbb6 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -182,7 +182,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -217,7 +216,7 @@ 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py index eaa26a867..9957d6819 100644 --- a/configs/7B_baichuan2.py +++ b/configs/7B_baichuan2.py @@ -165,7 +165,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py index aff448232..643bcbdbf 100644 --- a/configs/7B_gemma.py +++ b/configs/7B_gemma.py @@ -172,7 +172,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. 
fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 7a670171c..3c7bb9f4f 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -174,7 +174,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 95049036d..e7dd47b04 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -187,7 +187,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py index b0a173c8d..7783abaf7 100644 --- a/configs/7B_llama2.py +++ b/configs/7B_llama2.py @@ -164,7 +164,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py index 09b536ccc..3622e12f1 100644 --- a/configs/7B_qwen2.py +++ b/configs/7B_qwen2.py @@ -172,7 +172,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 4799b5f35..27847a5e8 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -174,7 +174,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. 
fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py index debd423b0..f1f1b6e60 100644 --- a/configs/8x22B_mixtral.py +++ b/configs/8x22B_mixtral.py @@ -176,7 +176,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -203,7 +202,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py index 322342ea6..6db43f9c6 100644 --- a/configs/8x7B_mixtral.py +++ b/configs/8x7B_mixtral.py @@ -176,7 +176,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -203,7 +202,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index 5d050da92..cc3f186ad 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -51,7 +51,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index 3b297e51f..dc461c0da 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -48,7 +48,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. 
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index 37b99294b..cbdb03cb1 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -48,7 +48,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm_20B.py b/configs/_base_/models/internlm_20B.py index b7f7d8a59..26f4ff7f8 100644 --- a/configs/_base_/models/internlm_20B.py +++ b/configs/_base_/models/internlm_20B.py @@ -43,7 +43,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm_7B.py b/configs/_base_/models/internlm_7B.py index e666c02ee..8dde6e4e4 100644 --- a/configs/_base_/models/internlm_7B.py +++ b/configs/_base_/models/internlm_7B.py @@ -43,7 +43,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index bcfe67d1a..721eec006 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -43,7 +43,7 @@ InternEvo 使用 `argparse `_ 模型初始化 ------------------------- -.. autofunction:: internlm.train.initialize_model +.. autofunction:: internlm.train.initialize_model_and_parallel_communicator InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下: diff --git a/doc/code-docs/source/training.rst b/doc/code-docs/source/training.rst index a0b4c2288..f43bfe4af 100644 --- a/doc/code-docs/source/training.rst +++ b/doc/code-docs/source/training.rst @@ -27,7 +27,7 @@ - 初始化模型 .. 
code-block:: python - model = initialize_model() + model = initialize_model_and_parallel_communicator() 详细介绍请参考: `模型初始化 `_ diff --git a/doc/en/train_performance.md b/doc/en/train_performance.md index d6b572f7b..ea998f06e 100644 --- a/doc/en/train_performance.md +++ b/doc/en/train_performance.md @@ -121,7 +121,7 @@ model = dict( ) parallel = dict( - zero1=dict(size=8, fsdp=False), + zero1=dict(size=8), tensor=1, pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=False, diff --git a/doc/train_performance.md b/doc/train_performance.md index 98364a753..891fa2f33 100644 --- a/doc/train_performance.md +++ b/doc/train_performance.md @@ -117,7 +117,7 @@ model = dict( ) parallel = dict( - zero1=dict(size=8, fsdp=False), + zero1=dict(size=8), tensor=1, pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=False, diff --git a/doc/usage.md b/doc/usage.md index 67ae1edf5..7c28d6d3e 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -268,7 +268,6 @@ zero1 parallel (dict): * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -432,7 +431,6 @@ parallel = dict( - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 - 2. fsdp: 布尔值,启用/禁用torch的完全分片数据并行,默认为False。 - tensor(字典): 1. size: 整数,张量并行的大小。 2. 
mode: 字符串,张量并行模式,应该是 ['mtp', 'msp', 'fsp', 'isp'] 中的一个, diff --git a/generate.py b/generate.py index 4ae760299..48efa8b3f 100644 --- a/generate.py +++ b/generate.py @@ -21,7 +21,7 @@ from internlm.initialize import initialize_distributed_env from internlm.monitor import initialize_monitor_manager from internlm.monitor.monitor import monitor_manager as mm -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator from internlm.utils.common import ( enable_pytorch_expandable_segments, launch_time, @@ -106,8 +106,7 @@ def main(): raise e # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() model = model.model state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True) diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 2f7f5d4ed..1f9bf6a6c 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -23,6 +23,7 @@ from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import is_using_fsdp, is_using_hf from internlm.utils.storage_manager import ( get_storage_manager, init_storage_manager, @@ -271,7 +272,7 @@ def __init__( self.storage_manager = get_storage_manager() self.snapshot_counter = -1 - if hasattr(model, "model"): + if hasattr(model, "model") and not is_using_fsdp(): model = model.model self.model = model @@ -575,6 +576,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""): f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) + elif is_using_fsdp() and is_using_hf() and not self.auto_resume: + pass else: load_path = self.load_ckpt_info["path"] load_content = self.load_ckpt_info["content"] diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index eee92c9c5..d96bb65c5 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -4,7 +4,6 @@ from collections import defaultdict import torch -from torch.distributed._shard.api import load_with_process_group from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -13,16 +12,26 @@ from internlm.model.moe import MoE from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp +from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from .utils import ( - get_model_topology, - get_non_moe_state_dict, - get_shard_state_dict, - load_shard_state_dict, -) +from .utils import get_model_topology, get_non_moe_state_dict + +try: + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + set_model_state_dict, + ) + + DCP_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + DCP_SUPPORTED = False + +RESUME_HF_FORMAT = True logger = get_logger(__file__) internlm_accelerator = 
get_accelerator() @@ -99,6 +108,34 @@ def try_save_moe_checkpoint(folder, model, expert_mp_rank, pp_rank): moe_layer_id += 1 +def load_fsdp_model_checkpoint(folder, model): + if DCP_SUPPORTED: + assert folder.startswith("local:"), "Currently we only support DCP load and save locally." + local_folder = folder[6:] + + if is_using_hf() and RESUME_HF_FORMAT: + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=os.path.join(local_folder, "hf"), use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + else: + state_dict = get_model_state_dict(model=model) + state_dict = {key: state_dict[key].clone().detach() for key in state_dict} + dcp.load(state_dict=state_dict, checkpoint_id=local_folder) + set_model_state_dict(model=model, model_state_dict=state_dict) + + del state_dict + internlm_accelerator.empty_cache() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + def load_model_checkpoint(folder, model): """ There should be weights with names similar to the following under the folder. @@ -109,43 +146,31 @@ def load_model_checkpoint(folder, model): - folder - model_wp{wp_rank}_pp{pp_rank}.pt - If fsdp is activated, the saved weight is named: - - folder - - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt - If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. """ + if is_using_fsdp(): + return load_fsdp_model_checkpoint(folder, model) + tp_size = gpc.get_world_size(ParallelMode.TENSOR) wp_size = gpc.get_world_size(ParallelMode.WEIGHT) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) - dp_size = gpc.get_world_size(ParallelMode.DATA) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - dp_rank = gpc.get_local_rank(ParallelMode.DATA) fns = get_fns(folder) - # avoid ckpt misuse between FSDP and no-FSDP _start_with = "model_w" if is_using_isp() else "model_t" - test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop() - assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or ( - "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp - ), "FSDP model wants to load no-FSDP ckpts or reverse" - max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0 + max_pp, max_wp, max_tp = 0, 0, 0 for fn in fns: if fn.startswith(_start_with) and not fn.endswith(".md5"): segements = os.path.splitext(fn)[0].split("_") if is_using_isp(): max_pp = max(max_pp, int(segements[-1][2:])) max_wp = max(max_wp, int(segements[-2][2:])) - elif gpc.config.parallel.zero1.fsdp: - max_zo = max(max_zo, int(segements[-1][2:])) - max_pp = max(max_pp, int(segements[-2][2:])) - max_tp = max(max_tp, int(segements[-3][2:])) else: max_pp = max(max_pp, int(segements[-1][2:])) max_tp = max(max_tp, int(segements[-2][2:])) @@ -160,23 +185,13 @@ def load_model_checkpoint(folder, model): assert ( tp_size == max_tp + 1 ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" - if gpc.config.parallel.zero1.fsdp: - assert ( - dp_size == max_zo + 1 - ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards" - if is_using_isp(): should_load_name = 
f"model_wp{wp_rank}_pp{pp_rank}.pt" - elif gpc.config.parallel.zero1.fsdp: - should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt" else: should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, should_load_name) - # for FSDP shards loading, we need to set process group - with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)): - states = llm_load(fp, map_location=get_current_device()) - + states = llm_load(fp, map_location=get_current_device()) """ # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in # gate.weight. The conversion will also be done when doing forward. so we can just comment it out. this make @@ -193,10 +208,7 @@ def load_model_checkpoint(folder, model): expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank) - if gpc.config.parallel.zero1.fsdp: - missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False) - else: - missing_k, unexpected_keys = model.load_state_dict(states, strict=False) + missing_k, unexpected_keys = model.load_state_dict(states, strict=False) if len(missing_k) != 0: logger.warning(f"Warning: missing keys {missing_k}") if len(unexpected_keys) != 0: @@ -207,6 +219,40 @@ def load_model_checkpoint(folder, model): internlm_accelerator.empty_cache() +def save_fsdp_model_checkpoint(folder, model): + def remove_model_prefix(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + new_key = key.replace("model.", "", 1) + new_state_dict[new_key] = state_dict[key].clone().detach() + return new_state_dict + + if DCP_SUPPORTED: + assert folder.startswith("local:"), "Currently we only support DCP load and save locally." + local_folder = folder[6:] + + if is_using_hf() and RESUME_HF_FORMAT: + state_dict = remove_model_prefix( + get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)) + ) + if state_dict: + hf = gpc.config.hf + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + with torch.device("meta"): + mod_to_save = mod(cfg(**hf.cfg_extra_kwargs)) + mod_to_save.load_state_dict(state_dict, strict=True, assign=True) + mod_to_save.save_pretrained(save_directory=os.path.join(local_folder, "hf"), safe_serialization=True) + else: + dcp.save(get_model_state_dict(model=model), checkpoint_id=local_folder) + + torch.distributed.barrier() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + def save_model_checkpoint(folder, model): """ Save the model according to the relationship between tp and dp. The principle is that the data of each tp @@ -218,10 +264,6 @@ def save_model_checkpoint(folder, model): - folder - model_wp{wp_rank}_pp{pp_rank}.pt - If fsdp is activated, the saved weight is named: - - folder - - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt - If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. 
Args: @@ -229,10 +271,10 @@ def save_model_checkpoint(folder, model): model: The model to be saved """ - if gpc.config.parallel.zero1.fsdp: - states = get_shard_state_dict(model) - else: - states = model.state_dict() + if is_using_fsdp(): + return save_fsdp_model_checkpoint(folder, model) + + states = model.state_dict() # get non-expert parameters states = get_non_moe_state_dict(states) @@ -268,21 +310,15 @@ def save_model_checkpoint(folder, model): else: # for tensor parallel mode with mtp/msp/fsp for i in range(tp_size): - if gpc.config.parallel.zero1.fsdp: - for j in range(dp_size): - should_save_rank_pair.add((i, j)) - else: - should_save_rank_pair.add((i, i % dp_size)) + should_save_rank_pair.add((i, i % dp_size)) if (tp_rank, dp_rank) in should_save_rank_pair: - f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else "" - fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt" + fn = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, fn) llm_save(fp, saved_obj=states) - if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size: - topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" - topo_fp = os.path.join(folder, topo_fn) - llm_save(topo_fp, saved_obj=topo) + topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" + topo_fp = os.path.join(folder, topo_fn) + llm_save(topo_fp, saved_obj=topo) # try to save expert parameter to separate files if model have moe layer expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA) @@ -310,19 +346,25 @@ def load_optimizer_checkpoint(folder, optim): fns = get_fns(folder) max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0 + max_fsdp = 0 for fn in fns: if fn.startswith("optimizer_") and not fn.endswith(".md5"): - if is_using_isp(): - _, wp, pp, zero = os.path.splitext(fn)[0].split("_") - max_zero = max(max_zero, int(zero[2:])) - max_wp = max(max_wp, int(wp[2:])) - max_pp = max(max_pp, int(pp[2:])) + if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + if is_using_isp(): + _, wp, pp, zero = os.path.splitext(fn)[0].split("_") + max_zero = max(max_zero, int(zero[2:])) + max_wp = max(max_wp, int(wp[2:])) + max_pp = max(max_pp, int(pp[2:])) + else: + _, tp, pp, zero = os.path.splitext(fn)[0].split("_") + max_zero = max(max_zero, int(zero[2:])) + max_tp = max(max_tp, int(tp[2:])) + max_pp = max(max_pp, int(pp[2:])) else: - _, tp, pp, zero = os.path.splitext(fn)[0].split("_") - max_zero = max(max_zero, int(zero[2:])) - max_tp = max(max_tp, int(tp[2:])) - max_pp = max(max_pp, int(pp[2:])) + _, fsdp = os.path.splitext(fn)[0].split("_") + max_fsdp = max(max_fsdp, int(fsdp[4:])) + fsdp_size = gpc.get_world_size(ParallelMode.GLOBAL) zero_size = gpc.get_world_size(ParallelMode.ZERO1) tp_size = gpc.get_world_size(ParallelMode.TENSOR) wp_size = gpc.get_world_size(ParallelMode.WEIGHT) @@ -343,14 +385,24 @@ def load_optimizer_checkpoint(folder, optim): wp_size == max_wp + 1 ), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism" + if not isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + assert ( + fsdp_size == max_fsdp + 1 + ), f"The optimizer states are save for {max_fsdp+1} parallelism, while current has {fsdp_size} fsdp parallelism" + + fsdp_rank = gpc.get_local_rank(ParallelMode.GLOBAL) zero_rank = gpc.get_local_rank(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - if is_using_isp(): - fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + 
+ if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + if is_using_isp(): + fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + else: + fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" else: - fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + fp = f"optimizer_fsdp{fsdp_rank}.pt" states = llm_load(os.path.join(folder, fp), map_location=get_current_device()) @@ -392,6 +444,7 @@ def save_optimizer_checkpoint(optim, state_path): """ # TODO sanity check for optimizer type + fsdp_rank = gpc.get_local_rank(ParallelMode.GLOBAL) zero_rank = gpc.get_local_rank(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) @@ -416,6 +469,7 @@ def save_optimizer_checkpoint(optim, state_path): fp_meta = os.path.join(state_path, optim.rank_unique_id) llm_save(fp_meta, params_per_rank_id_dict) else: + fp = f"optimizer_fsdp{fsdp_rank}.pt" llm_save(os.path.join(state_path, fp), states) diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py index a63ddb948..cd8fae4bf 100644 --- a/internlm/checkpoint/utils.py +++ b/internlm/checkpoint/utils.py @@ -1,31 +1,21 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import itertools + +import numpy as np +import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import split_data_for_sequence_parallel +from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp logger = get_logger(__file__) -def get_shard_state_dict(shard_model): - """ - Only used for FSDP module saving. - It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter - (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu. - 'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu - - """ - - # FSDP model can only save with sharded shape SHARDED_STATE_DICT when set use_orig_params=True - with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): - shard_states = shard_model.state_dict() - - return shard_states - - def get_non_moe_state_dict(full_state_dict): """ Get the state dict of the non-moe layers @@ -37,18 +27,6 @@ def get_non_moe_state_dict(full_state_dict): return full_state_dict -def load_shard_state_dict(shard_model, shard_state, **kwargs): - """ - Only used for FSDP module loading. - - """ - - with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): - missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs) - - return (missing_k, unexpected_keys) - - def get_model_topology(model): """ Returns: @@ -75,3 +53,67 @@ def process_load_info(load_info): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") return load_content_str, load_ckpt_folder, load_content + + +def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP: + """ + Initialize Fully Sharded Data Parallel (FSDP) for the model. + This function is needed to properly initialize FSDP when resuming from a checkpoint. + It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. 
+ + References: + https://github.com/pytorch/pytorch/issues/113496 + https://github.com/huggingface/transformers/pull/34032 + https://github.com/huggingface/transformers/issues/31892 + + Args: + model: The model to initialize with FSDP. + device: The device to run the model on. + + Returns: + The initialized FSDP model. + """ + model.train() + with torch.no_grad(): + # generate dummy packed sequence + seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz + input_ids = [1] * seq_len + label = input_ids[1:] + [-100] + cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len)) + + input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) + label = torch.tensor(label, device=device).unsqueeze(0) + indexes = torch.tensor( + list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])), + device=device, + ).unsqueeze(0) + cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0) + + data = { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": indexes, + "max_seqlen": seq_len, + } + + data_fns = [] + + # default data process function + if gpc.config.data.use_packed_dataset: + data_fns.append(packed_data_normalizer) + else: + data_fns.append(unpack_data) + + # support sequence parallel for isp + if is_using_isp(): + data_fns.append(split_data_for_sequence_parallel) + + # generate dummy_input + _data, _label = data, label + for fn in data_fns: + _data, _label = fn(_data, _label) + dummy_input = _data + + # run a forward pass with dummy_input to initialize FSDP + _ = model(**dummy_input) + return model diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py index ae5f6a25f..b2fc95cc9 100644 --- a/internlm/core/context/__init__.py +++ b/internlm/core/context/__init__.py @@ -15,7 +15,6 @@ Initializer_Pipeline, Initializer_Tensor, Initializer_Zero1, - Initializer_Zero3_dp, ParallelMode, ProcessGroupInitializer, ) @@ -46,7 +45,6 @@ "Initializer_Data", "Initializer_Zero1", "Initializer_Nettest", - "Initializer_Zero3_dp", "ProcessGroupInitializer", "seed", "set_mode", diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index f4751f59a..5278426ed 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -483,16 +483,6 @@ def check_sanity(self): assert self.zero1_parallel_size > 0 - # check for fsdp: - # if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights - # because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank - # pytorch vision: 1.13.1+cu117 - if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False): - logger.warning( - f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, " - "will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value" - ) - def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: ele = config[key] @@ -518,7 +508,7 @@ def init_parallel_groups(self): if parallel_config is not None: # set default value for parallel size if "zero1" not in parallel_config: - parallel_config._add_item("zero1", dict(size=-1, fsdp=False)) + parallel_config._add_item("zero1", dict(size=-1)) if "pipeline" not in parallel_config: parallel_config._add_item("pipeline", dict(size=1, interleaved_overlap=False)) if "tensor" not in parallel_config: @@ -657,9 
+647,7 @@ def init_parallel_groups(self): # process groups for parallelism. enable_moe = self.config.model.get("num_experts", 1) > 1 tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp") - is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False) - parallel_strategy = "fsdp" if is_fsdp else tp_mode - group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe) + group_configs = generate_parallel_group_configs(tp_mode, parallel_sizes, enable_moe) group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False) # process group for network test. diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 1e8057383..014bdbfdd 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -42,11 +42,6 @@ class ParallelMode(Enum): # runntime network test NETTEST = "nettest" - # zero3-dp parallel - # if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size - # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp - ZERO3_DP = "zero3_dp" - # expert parallel EXPERT = "expert" @@ -274,7 +269,6 @@ def create_single_process_group( ISP_SP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE] ISP_WP_GROUP_ORDER = [ParallelMode.WEIGHT, ParallelMode.WEIGHT_DATA, ParallelMode.PIPELINE] ISP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_WEIGHT, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE] -FSDP_ORDER = [ParallelMode.DATA] # TODO: should we support moe for fsdp? SUBGROUP_SPEC = { "mtp": { @@ -283,9 +277,6 @@ def create_single_process_group( "isp": { ParallelMode.WEIGHT_DATA: [ParallelMode.ZERO1], }, # TODO: WEIGHT_ZERO1 - "fsdp": { - ParallelMode.DATA: [ParallelMode.ZERO3_DP, ParallelMode.ZERO1], - }, } @@ -321,8 +312,6 @@ def _recurse_generater(order: List[ParallelMode]): group_configs.append(("isp-wp", _recurse_generater(ISP_WP_GROUP_ORDER))) if enable_moe: group_configs.append(("isp-moe", _recurse_generater(ISP_MOE_GROUP_ORDER))) - elif parallel_strategy == "fsdp": - group_configs.append(("fsdp", _recurse_generater(FSDP_ORDER))) else: # 3d parallel: mtp, msp, fsp group_configs.append(("3d", _recurse_generater(MTP_GROUP_ORDER))) if enable_moe: @@ -1118,65 +1107,6 @@ def init_dist_group(self, use_cpu: bool = False): return groups -class Initializer_Zero3_dp(ProcessGroupInitializer): - """A ProcessGroupInitializer for data parallelism. - - Args: - rank (int): The rank of current process. - world_size (int): Size of whole communication world. - data_parallel_size (int): Size of data parallel. - pipeline_parallel_size (int): Size of pipeline parallel. - tensor_parallel_size (int): Size of tensor parallel. - zero1_parallel_size (int): Size of zero1 parallel. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert self.data_parallel_size % self.zero1_parallel_size == 0 - - # the only difference between this initializer and DP_initializer - # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding - # eg: when zero=4 and dp=8 - # no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group - # fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually - - self.data_parallel_size //= self.zero1_parallel_size - self.rank_num_per_dp_group = self.world_size // self.data_parallel_size - - assert self.world_size % self.data_parallel_size == 0 - - def init_dist_group(self, use_cpu: bool = False): - """Initialize data parallel groups, and assign local_ranks and groups to each gpu. - - Returns: - Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): - A Data parallelism's information tuple. - """ - local_rank = None - ranks_in_group = None - process_group = None - cpu_group = None - group_world_size = None - mode = ParallelMode.ZERO3_DP - - for i in range(self.rank_num_per_dp_group): - ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)] - group = dist.new_group(ranks) - if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group - else: - group_cpu = None - - if self.rank in ranks: - local_rank = ranks.index(self.rank) - group_world_size = len(ranks) - process_group = group - cpu_group = group_cpu - ranks_in_group = ranks - - return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode - - class Initializer_Weight(ProcessGroupInitializer): """A ProcessGroupInitializer for model weight parallelism. diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 4dc6a1f5b..308fbb897 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -11,6 +11,7 @@ from internlm.core.context import global_context as gpc from internlm.core.parallel.comm.utils import _gather, _split from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_hf from internlm.utils.utils import TensorParallelMode logger = get_logger(__file__) @@ -33,7 +34,7 @@ def _split_data_for_sequence_parallel(data, label): data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim) # NOTICE: For compatibility where the shape of position_ids is [batch, seqlen, ...] 
- if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): _position_ids_seq_dim = 1 data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_position_ids_seq_dim) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 71c30d00d..4c18fd326 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -24,11 +24,9 @@ get_scheduler_hooks, initialize_llm_profile, initialize_optimizer, - initialize_parallel_communicator, inject_model, load_new_batch, record_current_batch_training_metrics, - set_param_unique_tracking_name, ) from internlm.utils.common import ( BatchSkipper, @@ -101,11 +99,8 @@ def __init__( # load config_lines config_lines = self._read_config(kwargs["config"]) - # set tracking name for parameters - set_param_unique_tracking_name(model) - - # inject model for amp and parallel training - model = inject_model(model) + # inject model for amp, parallel setting, parameter syncing and others + model, isp_communicator = inject_model(model) # check cuda env check_cuda_env() @@ -116,9 +111,6 @@ def __init__( # initialize loss function criterion = self._initialize_criterion() - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) - # initialize cpu offload manager for selective checkpoint initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 119f00f61..352273c79 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -7,6 +7,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode +from internlm.utils.parallel import is_using_hf def get_dataset_type_ids_map(path): @@ -64,7 +65,7 @@ def unpack_data(data, label): data["indexes"] = data["indexes"][0] # If model has inject_info and data_helper is enabled, we provide position_ids - if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] @@ -81,7 +82,7 @@ def packed_data_normalizer(data, label): data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() # If model has inject_info and data_helper is enabled, we provide position_ids, cu_seqlens, max_seqlen - if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens") gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e1cb2f0d2..f9e5b6ff0 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -15,7 +15,9 @@ from internlm.core.context.process_group_initializer import ParallelMode from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group 
+from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_hf from internlm.utils.timeout import llm_timeout from internlm.utils.utils import DataType, ModelType, TensorParallelMode @@ -64,6 +66,28 @@ def get_default_parser(): return parser +def inject_hf_config_before_launch(hf: dict): + # get HuggingFace model config + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + model_config = cfg(**hf.cfg_extra_kwargs) + # inject HuggingFace model config into InternTrain as much as we know + if hasattr(model_config, "vocab_size"): + gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size + if hasattr(model_config, "num_hidden_layers"): + gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers + if hasattr(model_config, "num_attention_heads"): + gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads + if hasattr(model_config, "num_key_value_heads"): + gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads + if hasattr(model_config, "hidden_size"): + gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size + if hasattr(model_config, "intermediate_size"): + gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size + if hasattr(model_config, "num_experts"): + gpc.config.model.num_experts = model_config.num_experts + + def args_sanity_check(): assert gpc.config is not None, "config is not load!" @@ -76,6 +100,11 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + # inject HuggingFace model config into IntrainTrain + if is_using_hf(): + inject_hf_config_before_launch(gpc.config.hf) + gpc.config.model_type = "hf" + if gpc.config.model_type == "InternLM3_M": # TODO: need check for isp overlap num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers @@ -88,11 +117,11 @@ def args_sanity_check(): # procssing the parallel config in gpc if "zero1" not in gpc.config.parallel: - gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False)) + gpc.config.parallel._add_item("zero1", dict(size=-1)) if isinstance(gpc.config.parallel.zero1, int): zero1_size = gpc.config.parallel.zero1 - gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False)) + gpc.config.parallel._add_item("zero1", dict(size=zero1_size)) if "pipeline" not in gpc.config.parallel: gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B")) @@ -131,19 +160,6 @@ def args_sanity_check(): if gpc.config.parallel.pipeline["mode"] == "ZBV": gpc.v_shape = True - # check fsdp config - if "fsdp" not in gpc.config.parallel.zero1: - gpc.config.parallel.zero1._add_item("fsdp", False) - - assert not ( - gpc.config.parallel.zero1.fsdp and pp > 1 - ), "FSDP is not supportted when pipeline size > 1, please set pipeline size to 1 or disabled FSDP" - - if gpc.config.parallel.zero1.fsdp: - assert ( - torch.__version__ >= "2.0.1" - ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}" - # processing the data config in gpc data = gpc.config.data @@ -401,11 +417,6 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: 
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name - if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: - assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" - assert ( - torch.__version__ >= "2.1.0" - ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" assert ( gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0 @@ -563,7 +574,6 @@ def args_sanity_check(): # moe not support overlap and zero1.5 for now if gpc.config.model.get("num_experts", 1) > 1: - assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1" assert ( not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param ), "not support overlap and moe at the same time" diff --git a/internlm/model/builder.py b/internlm/model/builder.py index d6c3b20f1..e8d3f11b9 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -1,19 +1,34 @@ from typing import List, Union +import torch from torch import nn from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper from internlm.model.base_model import BaseModel +from internlm.model.modules.linear import ( + ParallelLinearWithCommExt, + ScaleColumnParallelLinear, +) from internlm.model.registry import model_initializer from internlm.utils.common import get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp logger = get_logger(__file__) -def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: +def create_model() -> Union[nn.Module, List[nn.Module]]: + if is_using_hf(): + model = create_model_hf(hf=gpc.config.hf) + else: + model = create_model_builtin(model_type=gpc.config.model_type) + return model + + +def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: kwargs = dict(gpc.config.model) @@ -44,3 +59,71 @@ def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: logger.warning(f"To load/save huggingface ckpt, built-in model should inherited from {BaseModel.__name__}") return model + + +def create_model_hf(hf: dict) -> nn.Module: + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + + assert is_using_fsdp(), "Curently HF models can only train with FSDP." + + fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + if fsdp_init_method == "meta": + with torch.device("meta"): + model = mod(cfg(**hf.cfg_extra_kwargs)) + elif fsdp_init_method == "cuda": + # TODO: does HuggingFace models support directly initialized on cuda? 
+ model = mod(cfg(**hf.cfg_extra_kwargs)).to(get_current_device()) + elif fsdp_init_method == "cpu": + model = mod(cfg(**hf.cfg_extra_kwargs)) + else: + raise ValueError(f"Unsupported fsdp init_method: {fsdp_init_method}") + + def traverse(module): + for name, child in module.named_children(): + if ( + isinstance(child, nn.Linear) + and not isinstance(child, ParallelLinearWithCommExt) + and child.weight.shape == (gpc.config.VOCAB_SIZE, gpc.config.HIDDEN_SIZE) + ): + child_new = ScaleColumnParallelLinear( + in_features=child.in_features, + out_features=child.out_features, + bias=child.bias is not None, + device=child.weight.device, + dtype=child.weight.dtype, + ) + setattr(module, name, child_new) + else: + traverse(child) + + # Do hack: lm_head or output layer should be replaced with ScaleColumnParallelLinear, + # to get ISP fwd gather / bwd split work normally. + if is_using_isp(): + # traverse model might be slower than replacement module by name directly + if getattr(model, "lm_head", None) is not None: + lm_head = model.lm_head + lm_head_new = ScaleColumnParallelLinear( + in_features=lm_head.in_features, + out_features=lm_head.out_features, + bias=lm_head.bias is not None, + device=lm_head.weight.device, + dtype=lm_head.weight.dtype, + ) + setattr(model, "lm_head", lm_head_new) + elif getattr(model, "output", None) is not None: + output = model.output + output_new = ScaleColumnParallelLinear( + in_features=output.in_features, + out_features=output.out_features, + bias=output.bias is not None, + device=output.weight.device, + dtype=output.weight.dtype, + ) + setattr(model, "output", output_new) + else: + traverse(model) + + return model diff --git a/internlm/model/ops/fused_rmsnorm.py b/internlm/model/ops/fused_rmsnorm.py new file mode 100644 index 000000000..bbf9e5d97 --- /dev/null +++ b/internlm/model/ops/fused_rmsnorm.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# flake8: noqa=E731 +# pylint: disable=C3001 + +import math +from functools import partial + +import torch +import triton +import triton.language as tl +from torch.distributed._tensor import Partial, Replicate, Shard +from torch.distributed._tensor.experimental import local_map + +# FusedRMSNorm in Triton + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows # pylint: disable=W0613 + N, # num cols + block_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, # pylint: disable=W0613 + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, +): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class TritonFusedRMSNorm(torch.autograd.Function): + """ + Triton based Fused RMSNorm + """ + + @partial( + local_map, + out_placements=[Shard(1)], + in_placements=(None, [Shard(1)], [Replicate()], None), + ) + @staticmethod + def forward(ctx, x, weight, eps): + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = 
torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @partial( + local_map, + out_placements=([Shard(1)], [Partial()], None), + in_placements=(None, [Shard(1)]), + ) + @staticmethod + def backward(ctx, dy): + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +# expose fusedRMSNorm as a function +def fused_rms_norm_fn( + x, + weight, + eps=1e-6, +): + return TritonFusedRMSNorm.apply( + x, + weight, + eps, + ) diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 5aedd9a3d..2b5c9e4ed 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -4,6 +4,9 @@ import weakref import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) from torch.utils.checkpoint import check_backward_validity, detach_variable from internlm.accelerator import get_accelerator @@ -273,3 +276,13 @@ def inner_unpack(packed): arg = arg.to(device="cpu") return output + + +def apply_ac_to_transformer_block(module: torch.nn.Module, checkpoint): + ac_freq = round(1 / checkpoint) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 5676608fa..94cc411c6 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -1,26 +1,83 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import math +from typing import Iterable +import torch +import torch.distributed as dist from torch.optim import Optimizer +from internlm.accelerator import get_accelerator from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc +from internlm.solver.optimizer.base_optimizer import BaseOptimizer from internlm.solver.optimizer.utils import ( DynamicGradScaler, - reduce_tensor, + get_norm, release_param_grad, ) +from internlm.utils.common import get_tensor_norm, 
move_norm_to_cuda from internlm.utils.logger import get_logger -from .base_optimizer import BaseOptimizer -from .utils import compute_norm +try: + from torch.distributed.tensor import DTensor + + DTENSOR_SUPPORTED = True +except (ModuleNotFoundError, ImportError): + DTENSOR_SUPPORTED = False logger = get_logger(__file__) +inf = math.inf + +internlm_accelerator = get_accelerator() + + +def compute_norm( + gradients: Iterable[torch.Tensor], + parameters: Iterable[torch.Tensor], +) -> float: + """Get L2 norm + Arguments: + gradients (Iterable[Tensor]): The gradient value. + parameters (Iterable[Tensor]): The parameter each gradient corresponds to. + + Returns: + Total norm of the parameters, need total_norm**(1/norm) before using. + """ + + enable_cuda_kernels = gradients[0].device.type != "cpu" + + # Calculate norm. + tensor_parallel_grads = [g.data.float() for g, _ in zip(gradients, parameters)] + tensor_parallel_norm = get_norm(tensor_parallel_grads, float(2), enable_cuda_kernels) + # If norm is type of float, then we convert them into torch.Tensor. + total_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors + if not enable_cuda_kernels: + total_norm = move_norm_to_cuda(total_norm) + + if DTENSOR_SUPPORTED and isinstance(total_norm, DTensor): + total_norm = total_norm.full_tensor() + + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.GLOBAL)) + + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + if total_norm == float("inf") or total_norm == -float("inf"): + total_norm = -1 + + if math.isnan(total_norm): + total_norm = -2 + + return total_norm + class FSDPadaptOptimizer(BaseOptimizer): """ - optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file + optimizer for Pytorch FSDP if 'parallel.fsdp' is not None in config file reserve some necessary components of hybird-optim: grad_scaler; grad_clip and unscale; @@ -44,6 +101,7 @@ def __init__( growth_interval=grad_scal_cfg.fp16.growth_interval, hysteresis=grad_scal_cfg.hysteresis, max_scale=grad_scal_cfg.max_scale, + dtype=gpc.config.model.dtype, ) # clip gradient @@ -93,16 +151,6 @@ def zero_grad(self): param.grad = None def step(self): - # in case that fsdp-zero3 size is not equal to dp size - # FSDP module will only reduce gradient within FSDP process group - # so manually reduce grad is essential between two parallel FSDP process group - for group_idx in range(len(self.param_groups)): - params = self._fp16_param_groups[group_idx] - for param in params: - if param.requires_grad and param.grad is not None: - handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP) - handle.wait() - # compute norm found_inf = False norm_groups = {} @@ -207,13 +255,6 @@ def load_state_dict(self, states): self.grad_scaler.load_state_dict(grad_scaler) optim_states = states["base_optim_states"] - if gpc.config.get("only_load_lr", False): - if gpc.is_rank_for_log(): - logger.info("Only load lr in param_groups, skip loading weights in optimizer...") - for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]): - pg1["lr"] = pg2["lr"] - return - self.optim.load_state_dict(optim_states) # load fp32 optimizer weight diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py index 2ad60df09..f3c680da4 100644 --- a/internlm/train/__init__.py +++ b/internlm/train/__init__.py @@ -1,7 +1,7 @@ from .pipeline import ( get_scheduler_hooks, 
initialize_llm_profile, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, initialize_parallel_communicator, load_new_batch, @@ -12,7 +12,7 @@ __all__ = [ "initialize_llm_profile", - "initialize_model", + "initialize_model_and_parallel_communicator", "initialize_parallel_communicator", "initialize_optimizer", "load_new_batch", diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 784a5305a..945ee688a 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -1,15 +1,25 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import collections +import functools +import itertools import math import time -from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union import torch from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import DataLoader from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.checkpoint.utils import init_fsdp_v1 from internlm.core.context import ( IS_REPLICA_EXPERT_DATA_PARALLEL, IS_REPLICA_ZERO_PARALLEL, @@ -78,6 +88,7 @@ from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.train.utils import create_param_groups, map_param_block, timeout_input from internlm.utils.common import DummyProfile, SchedulerHook, get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import ( @@ -85,6 +96,8 @@ is_replica_zero_parallel_parameter, is_tensor_expert_data_parallel_parameter, is_tensor_zero_parallel_parameter, + is_using_fsdp, + is_using_hf, is_using_isp, is_weight_expert_data_parallel_parameter, is_weight_zero_parallel_parameter, @@ -99,6 +112,25 @@ except (ImportError, ModuleNotFoundError): pass +try: + from torch.distributed._composable.fsdp import fully_shard + + FSDP2_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + FSDP2_SUPPORTED = False + + +try: + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) + + DCP_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + DCP_SUPPORTED = False + + IS_INJECTED = "is_injected" LINEAR2NEWLINEAR_NAME_MAPPING = dict( @@ -220,44 +252,49 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) for _chunk in unwrap_naive_amp(model): - # special case for pure dp mode - if ( - isinstance(gpc.config.parallel["tensor"], dict) - and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name - and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) - ): - _check_module_func = _check_module_pure_dp - else: - _check_module_func = _check_module - # set param parallel attribute - for name, module in _chunk.named_modules(): - _check_module_func(name, module) - - for name, param in _chunk.named_parameters(): - assert ( - is_replica_zero_parallel_parameter(param) - or is_tensor_zero_parallel_parameter(param) - or is_weight_zero_parallel_parameter(param) - or is_tensor_expert_data_parallel_parameter(param) - or is_weight_expert_data_parallel_parameter(param) 
- or is_replica_expert_data_parallel_parameter(param) - ), f"parameter with name: {name} has no parallel attribution." - - -@llm_timeout(func_name="initialize_model") -def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None): + if not is_using_fsdp(): + # special case for pure dp mode + if ( + isinstance(gpc.config.parallel["tensor"], dict) + and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) + == TensorParallelMode.mtp.name + and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) + ): + _check_module_func = _check_module_pure_dp + else: + _check_module_func = _check_module + # set param parallel attribute + for name, module in _chunk.named_modules(): + _check_module_func(name, module) + + for name, param in _chunk.named_parameters(): + assert ( + is_replica_zero_parallel_parameter(param) + or is_tensor_zero_parallel_parameter(param) + or is_weight_zero_parallel_parameter(param) + or is_tensor_expert_data_parallel_parameter(param) + or is_weight_expert_data_parallel_parameter(param) + or is_replica_expert_data_parallel_parameter(param) + ), f"parameter with name: {name} has no parallel attribution." + + +@llm_timeout(func_name="initialize_model_and_parallel_communicator") +def initialize_model_and_parallel_communicator( + pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None +): """ Initialize model with Automatic Mixed Precision. Returns: torch.nn.Module: The neural network model to be trained or evaluated. + An isp communicator for managing comp/comm overlap. """ if pre_process_func: pre_process_output = pre_process_func() register_model_initializer() - model = create_model(model_type=gpc.config.model_type) + model = create_model() if post_process_func: post_process_func(pre_process_output) @@ -276,11 +313,18 @@ def inject_model(model): Returns: torch.nn.Module: The injected neural network model to be trained or evaluated. + An isp communicator for managing comp/comm overlap. """ if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED): return model - inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None)) + # For non-HF cases, set tracking name for parameters + if not is_using_hf(): + set_param_unique_tracking_name(model) + + # For non-fsdp cases, set model inject helper + if not is_using_fsdp(): + inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None)) # should be set before NaiveAMPModel set_fp32_attr_for_model(model) @@ -310,7 +354,8 @@ def inject_model(model): # This sync is very important, cause the model weights kept in optimizer are copied # from the origin parameters in the memory, so we should make sure the dp sync # does not influence the model weights in optimizer be different with the origin parameters. - sync_model_param(model) + if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda": + sync_model_param(model) # This function is needed to make sure parameters that are not splitted by tensor parallelism are # the same across tensor parallelism. 
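Note: the FSDP/HF code paths above and below are driven by two new config blocks. A minimal
illustrative sketch of these entries (the field names are the ones read by this patch; the
HuggingFace class paths and the wrapped block class are placeholders, not requirements):

    parallel = dict(
        zero1=dict(size=-1),
        tensor=dict(size=1, mode="mtp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        weight=dict(size=1, overlap=True),
        # torch FSDP: mode is "v1" or "v2", init_method is "cuda", "cpu" or "meta"
        fsdp=dict(enable=True, mode="v1", init_method="cuda"),
    )
    # only required when training a HuggingFace model; consumed by create_model_hf()
    hf = dict(
        cfg="transformers.models.llama.configuration_llama",
        cfg_cls="LlamaConfig",
        mod="transformers.models.llama.modeling_llama",
        mod_cls="LlamaForCausalLM",
        cfg_extra_kwargs=dict(),
    )
    # transformer block classes that wrap_FSDP_model() treats as FSDP wrapping units
    fsdp_wrap_cls = [dict(mod="transformers.models.llama.modeling_llama", mod_cls="LlamaDecoderLayer")]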
@@ -321,10 +366,15 @@ def inject_model(model): random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA set_mode(random_mode) + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + model = wrap_FSDP_model(model) + # set is_injected flag setattr(model, "IS_INJECTED", True) - return model + return model, isp_communicator _T = TypeVar("_T") @@ -360,7 +410,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): get_current_device(), gpc.config.model.checkpoint, ), - gpc.config.parallel.weight.overlap, + gpc.config.parallel.weight.overlap and not is_using_fsdp(), gpc.get_group(ParallelMode.WEIGHT), is_moe=False, selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), @@ -495,8 +545,9 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR) # register communitorc for embedding layer. - for embedding in _submodule_filter(model, Embedding1D): - _embedding_communicator.register_module_hook(embedding) + if not is_using_fsdp(): + for embedding in _submodule_filter(model, Embedding1D): + _embedding_communicator.register_module_hook(embedding) # register communictor for head layer. ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) @@ -554,7 +605,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato else: param_bcast_sync_handler = None - if not gpc.config.parallel.zero1.fsdp: + if not is_using_fsdp(): if ( "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim @@ -975,6 +1026,122 @@ def inject_config(model: nn.Module) -> None: gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads +def _get_modules_to_materialize( + root_module: nn.Module, + ignored_modules: Set[nn.Module], +) -> List[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. 
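+    # (These two helpers follow the meta-device materialization logic of torch FSDP's
+    # init utilities; they are only invoked below when FSDP v2 is configured with
+    # init_method="meta".)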
+ modules_to_materialize: List[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: Set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if child_module not in visited_modules and child_module not in ignored_modules: + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _materialize_meta_module( + root_module: nn.Module, + ignored_modules: Set[nn.Module], + device_id: Optional[torch.device], +) -> None: + # Run default meta device initialization + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False)) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=device_id, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + logger.warning( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] + ) + raise e + + +def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): + if is_using_fsdp(): + assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallel." + wrap_cls = tuple( + LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", []) + ) + fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1") + fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + + if fsdp_mode == "v1": + model = FSDP( + module=model, + process_group=gpc.get_group(ParallelMode.GLOBAL), + sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD + auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)), + sync_module_states=fsdp_init_method != "cuda", # sync model paramters + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + use_orig_params=True, + device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states + ) + # For FSDP v1, to get ckpt resuming work normally, we do dummy forward. + # This hack is needed due to FSDP v1 lazy initialization in model construction. 
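+            # (FSDP v1 finishes flattening/sharding its parameters lazily, on the first
+            # forward pass, so loading a checkpoint before any forward has run can fail.)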
+ # FYI: https://github.com/pytorch/pytorch/issues/113496 + model = init_fsdp_v1(model, get_current_device()) + elif FSDP2_SUPPORTED and fsdp_mode == "v2": + fsdp_kwargs = { + "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True + } + for module in model.modules(): + if isinstance(module, wrap_cls): + fully_shard(module, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) + if fsdp_init_method == "meta": + _materialize_meta_module(model, set(), get_current_device()) + elif fsdp_init_method == "cpu": + model.to(get_current_device()) + else: + raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}") + + if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False): + load_ckpt_info = gpc.config.ckpt.load_ckpt_info + load_ckpt_path = load_ckpt_info.get("path", None) + load_ckpt_content = load_ckpt_info.get("content", []) + if load_ckpt_path: + assert load_ckpt_content == ( + "model", + ), "If auto_resume=False and checkpoint path is given, only model can be loaded" + if DCP_SUPPORTED: + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + del state_dict + internlm_accelerator.empty_cache() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + return model + + def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None: """ Inject model helper functions. diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 0e75cc48b..2d173fa37 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -118,8 +118,6 @@ def warmup_process_group(): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1)) if gpc.is_initialized(ParallelMode.MODEL): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL)) - if gpc.is_initialized(ParallelMode.ZERO3_DP): - dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP)) if gpc.is_initialized(ParallelMode.EXPERT_DATA): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA)) if gpc.is_initialized(ParallelMode.EXPERT): diff --git a/internlm/utils/lazy.py b/internlm/utils/lazy.py new file mode 100644 index 000000000..e67c63aa2 --- /dev/null +++ b/internlm/utils/lazy.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import abc +import importlib +from typing import Any, Optional, Type, Union + + +def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None) -> bool: + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type or tuple): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. Defaults to None. + + Returns: + bool: Return True if ``seq`` is valid else False. 
+ + Examples: + >>> from mmengine.utils import is_seq_of + >>> seq = ['a', 'b', 'c'] + >>> is_seq_of(seq, str) + True + >>> is_seq_of(seq, int) + False + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +class LazyObject: + """LazyObject is used to lazily initialize the imported module during + parsing the configuration file. + + During parsing process, the syntax like: + + Examples: + >>> import torch.nn as nn + >>> from mmdet.models import RetinaNet + >>> import mmcls.models + >>> import mmcls.datasets + >>> import mmcls + + Will be parsed as: + + Examples: + >>> # import torch.nn as nn + >>> nn = lazyObject('torch.nn') + >>> # from mmdet.models import RetinaNet + >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') + >>> # import mmcls.models; import mmcls.datasets; import mmcls + >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) + + ``LazyObject`` records all module information and will be further + referenced by the configuration file. + + Args: + module (str or list or tuple): The module name to be imported. + imported (str, optional): The imported module name. Defaults to None. + location (str, optional): The filename and line number of the imported + module statement happened. + """ + + def __init__(self, module: Union[str, list, tuple], imported: Optional[str] = None, location: Optional[str] = None): + if not isinstance(module, str) and not is_seq_of(module, str): + raise TypeError( + "module should be `str`, `list`, or `tuple`" + f"but got {type(module)}, this might be " + "a bug of MMEngine, please report it to " + "https://github.com/open-mmlab/mmengine/issues" + ) + self._module: Union[str, list, tuple] = module + + if not isinstance(imported, str) and imported is not None: + raise TypeError( + "imported should be `str` or None, but got " + f"{type(imported)}, this might be " + "a bug of MMEngine, please report it to " + "https://github.com/open-mmlab/mmengine/issues" + ) + self._imported = imported + self.location = location + + def build(self) -> Any: + """Return imported object. + + Returns: + Any: Imported object + """ + if isinstance(self._module, str): + try: + module = importlib.import_module(self._module) + except Exception as e: + raise type(e)(f"Failed to import {self._module} " f"in {self.location} for {e}") + + if self._imported is not None: + if hasattr(module, self._imported): + module = getattr(module, self._imported) + else: + raise ImportError(f"Failed to import {self._imported} " f"from {self._module} in {self.location}") + + return module + else: + # import xxx.xxx + # import xxx.yyy + # import xxx.zzz + # return imported xxx + try: + for module in self._module: + importlib.import_module(module) # type: ignore + module_name = self._module[0].split(".")[0] + return importlib.import_module(module_name) + except Exception as e: + raise type(e)(f"Failed to import {self.module} " f"in {self.location} for {e}") + + @property + def module(self): + if isinstance(self._module, str): + return self._module + return self._module[0].split(".")[0] + + def __call__(self, *args, **kwargs): + raise RuntimeError() + + def __deepcopy__(self, memo): + return LazyObject(self._module, self._imported, self.location) + + def __getattr__(self, name): + # Cannot locate the line number of the getting attribute. 
+ # Therefore only record the filename. + if self.location is not None: + location = self.location.split(", line")[0] + else: + location = self.location + return LazyAttr(name, self, location) + + def __str__(self) -> str: + if self._imported is not None: + return self._imported + return self.module + + __repr__ = __str__ + + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. If these two methods are not defined, + # LazyObject will return a `__getstate__` LazyObject` or `__setstate__` + # LazyObject. + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + + +class LazyAttr: + """The attribute of the LazyObject. + + When parsing the configuration file, the imported syntax will be + parsed as the assignment ``LazyObject``. During the subsequent parsing + process, users may reference the attributes of the LazyObject. + To ensure that these attributes also contain information needed to + reconstruct the attribute itself, LazyAttr was introduced. + + Examples: + >>> models = LazyObject(['mmdet.models']) + >>> model = dict(type=models.RetinaNet) + >>> print(type(model['type'])) # + >>> print(model['type'].build()) # + """ # noqa: E501 + + def __init__(self, name: str, source: Union["LazyObject", "LazyAttr"], location=None): + self.name = name + self.source: Union[LazyAttr, LazyObject] = source + + if isinstance(self.source, LazyObject): + if isinstance(self.source._module, str): + if self.source._imported is None: + # source code: + # from xxx.yyy import zzz + # equivalent code: + # zzz = LazyObject('xxx.yyy', 'zzz') + # The source code of get attribute: + # eee = zzz.eee + # Then, `eee._module` should be "xxx.yyy.zzz" + self._module = self.source._module + else: + # source code: + # import xxx.yyy as zzz + # equivalent code: + # zzz = LazyObject('xxx.yyy') + # The source code of get attribute: + # eee = zzz.eee + # Then, `eee._module` should be "xxx.yyy" + self._module = f"{self.source._module}.{self.source}" + else: + # The source code of LazyObject should be + # 1. import xxx.yyy + # 2. import xxx.zzz + # Equivalent to + # xxx = LazyObject(['xxx.yyy', 'xxx.zzz']) + + # The source code of LazyAttr should be + # eee = xxx.eee + # Then, eee._module = xxx + self._module = str(self.source) + elif isinstance(self.source, LazyAttr): + # 1. import xxx + # 2. zzz = xxx.yyy.zzz + + # Equivalent to: + # xxx = LazyObject('xxx') + # zzz = xxx.yyy.zzz + # zzz._module = xxx.yyy._module + zzz.name + self._module = f"{self.source._module}.{self.source.name}" + self.location = location + + @property + def module(self): + return self._module + + def __call__(self, *args, **kwargs: Any) -> Any: + raise RuntimeError() + + def __getattr__(self, name: str) -> "LazyAttr": + return LazyAttr(name, self) + + def __deepcopy__(self, memo): + return LazyAttr(self.name, self.source) + + def build(self) -> Any: + """Return the attribute of the imported object. + + Returns: + Any: attribute of the imported object. + """ + obj = self.source.build() + try: + return getattr(obj, self.name) + except AttributeError: + raise ImportError(f"Failed to import {self.module}.{self.name} in " f"{self.location}") + except ImportError as e: + raise e + + def __str__(self) -> str: + return self.name + + __repr__ = __str__ + + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. 
If these two methods are not defined, + # LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__` + # LazyAttr. + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 665353070..129b99366 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -16,6 +16,18 @@ from internlm.utils.utils import TensorParallelMode +def is_using_hf(): + return "hf" in gpc.config + + +def is_using_fsdp(): + return ( + "fsdp" in gpc.config.parallel + and isinstance(gpc.config.parallel["fsdp"], dict) + and gpc.config.parallel["fsdp"].get("enable", False) + ) + + def is_using_sequence_parallel(): return ( isinstance(gpc.config.parallel["tensor"], dict) diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py index 55b354c4d..5b09f9d5a 100644 --- a/internlm/utils/timeout.py +++ b/internlm/utils/timeout.py @@ -41,7 +41,7 @@ def __exit__(self, error_type, value, traceback): timeout_threshold_dict = { "initialize_distributed_env": 240, "nopp_forward_backward_step": 360, - "initialize_model": 60, + "initialize_model_and_parallel_communicator": 60, "initialize_optimizer": 60, "optim_step": 60, "build_train_loader_with_data_type": 600, diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 180fe4b71..efa5d7b71 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -21,7 +21,7 @@ dict( gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")], parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=8, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 6beeb7a7f..7600b7637 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -152,7 +152,7 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py index bbb804a32..14741b494 100644 --- a/tests/test_infer/test_generate.py +++ b/tests/test_infer/test_generate.py @@ -6,7 +6,7 @@ from internlm.apis.inference import SequenceGenerator, batch_tokenize from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator def set_seed(seed: int = 1024): @@ -30,7 +30,7 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""): model_type=model_type, model=model_config, parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), pipeline=dict(size=1, interleaved_overlap=True), tensor=dict(size=1, mode="mtp"), sequence_parallel=0, @@ -50,8 +50,7 @@ def convert_to_str(output_ids): all_output_str.append(cur_sent) return all_output_str - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # Directly get the origin model without NativeAMP wrapper. 
model = model.model diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index 537a40777..c3149dda3 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -13,17 +13,15 @@ from internlm.model.losses import InternLoss # noqa: E402 from internlm.train import ( # noqa: E402 get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) def setup_generator(config, tokenizer): initialize_distributed_env(config=config) - model = initialize_model() - isp_communicator = initialize_parallel_communicator(model) + model, isp_communicator = initialize_model_and_parallel_communicator() criterion = InternLoss() diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index 3ce6f530e..e2655d291 100644 --- a/tests/test_model/test_model_internlm.py +++ b/tests/test_model/test_model_internlm.py @@ -33,7 +33,7 @@ config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index d01b876c8..ab81dbeed 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -18,9 +18,8 @@ from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -171,8 +170,7 @@ def train_check_output(args): seed_all(1024) # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index ddbb24a08..f9516c279 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -45,9 +45,8 @@ SchedulerMetricHook, ) from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position @@ -67,7 +66,7 @@ dict( VOCAB_SIZE=103168, parallel=dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=dict(size=1, mode="mtp"), @@ -220,8 +219,7 @@ def train_model(args): current_time = objs[0] # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 967398e17..cdee0b18b 100644 --- a/tests/test_training/test_loss.py +++ 
b/tests/test_training/test_loss.py @@ -16,9 +16,8 @@ from internlm.model.losses import InternLoss from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import BatchSkipper, launch_time @@ -167,11 +166,8 @@ def train( dist.broadcast_object_list(objs, src=0) current_time = objs[0] - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp_communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index 5f0782b4b..0b0493bb2 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -12,9 +12,8 @@ from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.logger import get_logger from tests.common_fixture import ( @@ -51,11 +50,8 @@ def train_check(args): # set seed seed_all(1024) - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 990b334a6..1306da69b 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -15,9 +15,8 @@ from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -71,11 +70,8 @@ def train_check_norm_weight(args): # set seed seed_all(1024) - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 13c01b1c5..84b79d9f0 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -24,9 +24,8 @@ from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -271,8 +270,7 @@ def 
exam_loss(args): seed_all(1024) # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index c7da6f85c..623f0ccec 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -36,9 +36,8 @@ from internlm.monitor.monitor import monitor_manager as mm # noqa: E402 from internlm.train import ( # noqa: E402 initialize_llm_profile, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, record_current_batch_training_metrics, ) from internlm.utils.common import ( # noqa: E402 @@ -116,8 +115,7 @@ def main(args): current_time = objs[0] # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model , _ = initialize_model_and_parallel_communicator() with open(args.config, "r") as f: config_lines = f.readlines() diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index f4b34ddee..d3405122d 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -46,7 +46,7 @@ init_config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), @@ -91,7 +91,7 @@ def init_naive_model(): register_model_initializer() - model = create_model(model_type=gpc.config.model_type) + model = create_model() model = NaiveAMPModel( model=model, output_to_fp32=False, diff --git a/tools/load_internlm2_model.py b/tools/load_internlm2_model.py index 70900cad5..4b639003e 100644 --- a/tools/load_internlm2_model.py +++ b/tools/load_internlm2_model.py @@ -11,7 +11,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.initialize.launch import initialize_distributed_env -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -185,7 +185,7 @@ def initialize_internlm_model( model_type=model_type, model=model_config, parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), pipeline=dict(size=1, interleaved_overlap=True), tensor=dict(size=get_tp_world_size(), mode="mtp"), sequence_parallel=0, @@ -197,8 +197,7 @@ def initialize_internlm_model( args_check=False, ) # Directly get the origin model without NativeAMP wrapper. 
- model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() model = model.model state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix) diff --git a/train.py b/train.py index 437774b1d..6e5e1399f 100755 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ @internevo_monitor(feishu_alert=True, clean_run=True) def main(args): # initialize model - model = create_model(model_type=gpc.config.model_type) + model = create_model() # initialize train dataloader train_dl, dataset_types = build_train_loader_with_data_type() From a70aaf60de0b152428bb17234aeb9e8f47b65a2b Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 14 Feb 2025 15:19:04 +0800 Subject: [PATCH 02/32] initial refactor: (1) reorg src structure to avoid cyclic imports (2)remove legacy or history codes (3)refine initializer interface --- .github/workflows/demo_in_readme.yaml | 2 + .github/workflows/lint_check.yaml | 8 +- README-ja-JP.md | 2 +- README-zh-Hans.md | 2 +- README.md | 2 +- ci_scripts/train/generate_config.py | 2 +- ci_scripts/train/load_ckpt.sh | 2 +- ci_scripts/train/slurm_train.sh | 2 +- ci_scripts/train/torchrun.sh | 2 +- .../locales/en/LC_MESSAGES/training.po | 4 +- doc/code-docs/source/example/20B_demo.rst | 2 +- doc/code-docs/source/example/7B_demo.rst | 2 +- doc/code-docs/source/initialize.rst | 2 +- doc/code-docs/source/training.rst | 6 +- doc/en/usage.md | 2 +- doc/usage.md | 2 +- generate.py | 10 +- internlm/__init__.py | 9 - internlm/accelerator/abstract_accelerator.py | 25 +- internlm/accelerator/cuda_accelerator.py | 1 + internlm/accelerator/dipu_accelerator.py | 1 + internlm/accelerator/ditorch_accelerator.py | 1 + internlm/accelerator/npu_accelerator.py | 1 + internlm/apis/inference_utils.py | 4 +- internlm/checkpoint/checkpoint_manager.py | 32 +- internlm/checkpoint/components.py | 2 +- internlm/checkpoint/load_funcs.py | 10 +- internlm/checkpoint/utils.py | 73 - internlm/core/context/__init__.py | 4 +- internlm/core/context/parallel_context.py | 96 +- internlm/core/engine.py | 4 +- internlm/core/fsdp.py | 224 +++ internlm/core/naive_amp.py | 2 +- internlm/core/parallel/comm/__init__.py | 47 +- internlm/core/parallel/comm/isp.py | 14 +- internlm/core/parallel/comm/tensor.py | 31 +- internlm/core/parallel/comm/zero.py | 13 +- internlm/core/parallel/shard.py | 2 +- internlm/core/trainer.py | 232 +++- internlm/core/trainer_builder.py | 32 +- internlm/data/tokenized/dummy_dataset.py | 2 +- internlm/data/utils.py | 8 +- internlm/eval/__init__.py | 8 +- internlm/eval/evaluation.py | 2 +- internlm/initialize/__init__.py | 14 +- internlm/initialize/constants.py | 9 + .../initialize/initialize_communicator.py | 214 +++ .../{launch.py => initialize_launcher.py} | 81 +- internlm/initialize/initialize_model.py | 228 ++++ internlm/initialize/initialize_optimizer.py | 189 +++ internlm/initialize/initialize_profiler.py | 61 + internlm/initialize/initialize_trainer.py | 4 +- internlm/initialize/legacy/launch.py | 40 - .../legacy => launcher}/__init__.py | 0 train.py => internlm/launcher/launch.py | 8 +- .../__init__.py | 0 .../{ => model_implementations}/builder.py | 14 +- .../{ => model_implementations}/registry.py | 33 +- .../transformers}/__init__.py | 0 .../transformers}/base_model.py | 6 +- .../transformers}/modeling_baichuan2.py | 22 +- .../transformers}/modeling_gemma.py | 22 +- .../transformers}/modeling_internlm.py | 25 +- .../transformers}/modeling_internlm2.py | 22 +- .../transformers}/modeling_llama.py | 22 +- 
.../transformers}/modeling_llava.py | 24 +- .../transformers}/modeling_mixtral.py | 27 +- .../transformers}/modeling_moe.py | 27 +- .../transformers}/modeling_qwen2.py | 22 +- .../transformers}/modeling_qwen2_moe.py | 24 +- .../transformers/utils.py} | 0 internlm/model/{ops => model_ops}/__init__.py | 0 internlm/model/model_ops/llava/__init__.py | 0 .../{ => model_ops}/llava/clip_builder.py | 0 .../{ => model_ops}/llava/clip_encoder.py | 0 .../llava/projector_builder.py | 0 .../model/{ => model_ops}/losses/__init__.py | 0 .../model/{ => model_ops}/losses/ce_loss.py | 2 +- internlm/model/{ => model_ops}/metrics.py | 2 +- internlm/model/model_ops/modules/__init__.py | 0 .../{ => model_ops}/modules/embedding.py | 2 +- .../model/{ => model_ops}/modules/linear.py | 5 +- internlm/model/{ => model_ops}/modules/mha.py | 8 +- internlm/model/{ => model_ops}/modules/mlp.py | 4 +- .../model/{ => model_ops}/modules/norm.py | 2 +- .../model/{ => model_ops}/modules/utils.py | 0 .../model/{ => model_ops}/moe/__init__.py | 0 .../model/{ => model_ops}/moe/base_layer.py | 2 +- .../{ => model_ops}/moe/dropless_layer.py | 2 +- internlm/model/{ => model_ops}/moe/experts.py | 0 .../model/{ => model_ops}/moe/gshard_layer.py | 2 +- .../moe/megablocks/__init__.py | 0 .../moe/megablocks/megablock_dmoe.py | 8 +- .../moe/megablocks/megablock_moe.py | 6 +- .../{ => model_ops}/moe/megablocks/mlp.py | 4 +- .../{ => model_ops}/moe/megablocks/utils.py | 2 +- internlm/model/{ => model_ops}/moe/moe.py | 10 +- internlm/model/{ => model_ops}/moe/utils.py | 0 internlm/model/model_ops/ops/__init__.py | 0 .../model/{ => model_ops}/ops/_flash_attn.py | 0 .../model/{ => model_ops}/ops/attention.py | 9 +- .../{ => model_ops}/ops/cross_entropy.py | 2 +- .../ops/cross_entropy_ops/__init__.py | 0 .../ops/cross_entropy_ops/apex_naive_loss.py | 0 .../ops/cross_entropy_ops/py_naive_loss.py | 0 .../py_vocab_parallel_loss.py | 0 .../sequence_parallel_loss.py | 0 .../{ => model_ops}/ops/fused_rmsnorm.py | 0 internlm/model/{ => model_ops}/ops/linear.py | 0 internlm/model/{ => model_ops}/ops/norm.py | 0 .../ops/ring_flash_attn/__init__.py | 0 .../ops/ring_flash_attn/utils.py | 0 ...zag_ring_flash_attn_with_sliding_window.py | 2 +- .../model/{ => model_ops}/ops/rotary_emb.py | 0 internlm/model/{ => model_ops}/ops/utils.py | 0 internlm/model/{ => model_ops}/utils.py | 4 +- internlm/monitor/__init__.py | 12 +- internlm/monitor/monitor.py | 6 +- internlm/monitor/utils.py | 4 - internlm/solver/activation_checkpoint.py | 14 +- internlm/solver/optimizer/__init__.py | 3 +- internlm/solver/optimizer/fsdp_optimizer.py | 5 +- .../solver/optimizer/hybrid_zero_optim.py | 12 +- .../solver/optimizer/hybrid_zero_optim_v2.py | 9 +- internlm/train/__init__.py | 23 - internlm/train/pipeline.py | 1210 ----------------- internlm/train/utils.py | 116 -- internlm/utils/common.py | 39 +- internlm/utils/config.py | 103 ++ internlm/utils/lazy.py | 5 +- internlm/utils/timeout.py | 2 +- requirements/runtime.txt | 18 +- setup.py | 68 +- tests/common_fixture.py | 9 +- tests/test_core/test_pipeline.py | 2 +- tests/test_core/utils.py | 9 +- tests/test_data/test_batch_sampler.py | 15 +- tests/test_infer/test_generate.py | 8 +- tests/test_infer/test_trainer_generate.py | 18 +- tests/test_model/test_embedding.py | 2 +- tests/test_model/test_feed_forward.py | 2 +- .../test_fused_precision.py | 8 +- tests/test_model/test_model_internlm.py | 17 +- tests/test_model/test_norm.py | 2 +- .../test_npu_ops/test_flash_attention.py | 11 +- .../test_npu_ops/test_npu_rmsnorm.py | 4 +- 
.../test_npu_ops/test_rotary_embed.py | 2 +- tests/test_solver/test_optimizer.py | 10 +- .../test_forward_output_no_fa.py | 19 +- tests/test_training/test_load_ckpt_loss.py | 34 +- tests/test_training/test_loss.py | 22 +- tests/test_training/test_no_fa_train_temp.py | 16 +- tests/test_training/test_norm_weight.py | 14 +- .../test_swap_nb_loss_and_gradnorm.py | 22 +- tests/test_training/train_CI.py | 36 +- tests/test_utils/common_fixture.py | 16 +- tests/test_utils/test_model_checkpoint.py | 44 +- tests/test_utils/test_storage_manager.py | 3 +- tests/test_utils/test_timeout.py | 4 +- tools/load_internlm2_model.py | 10 +- tools/moe_group_ckpt_converter.py | 1 - version.txt | 2 +- 162 files changed, 1974 insertions(+), 2187 deletions(-) create mode 100644 internlm/core/fsdp.py create mode 100644 internlm/initialize/constants.py create mode 100644 internlm/initialize/initialize_communicator.py rename internlm/initialize/{launch.py => initialize_launcher.py} (91%) create mode 100644 internlm/initialize/initialize_model.py create mode 100644 internlm/initialize/initialize_optimizer.py create mode 100644 internlm/initialize/initialize_profiler.py delete mode 100644 internlm/initialize/legacy/launch.py rename internlm/{initialize/legacy => launcher}/__init__.py (100%) rename train.py => internlm/launcher/launch.py (75%) mode change 100755 => 100644 rename internlm/model/{llava => model_implementations}/__init__.py (100%) rename internlm/model/{ => model_implementations}/builder.py (90%) rename internlm/model/{ => model_implementations}/registry.py (77%) rename internlm/model/{modules => model_implementations/transformers}/__init__.py (100%) rename internlm/model/{ => model_implementations/transformers}/base_model.py (74%) rename internlm/model/{ => model_implementations/transformers}/modeling_baichuan2.py (97%) rename internlm/model/{ => model_implementations/transformers}/modeling_gemma.py (98%) rename internlm/model/{ => model_implementations/transformers}/modeling_internlm.py (98%) rename internlm/model/{ => model_implementations/transformers}/modeling_internlm2.py (98%) rename internlm/model/{ => model_implementations/transformers}/modeling_llama.py (98%) rename internlm/model/{ => model_implementations/transformers}/modeling_llava.py (93%) rename internlm/model/{ => model_implementations/transformers}/modeling_mixtral.py (96%) rename internlm/model/{ => model_implementations/transformers}/modeling_moe.py (96%) rename internlm/model/{ => model_implementations/transformers}/modeling_qwen2.py (98%) rename internlm/model/{ => model_implementations/transformers}/modeling_qwen2_moe.py (97%) rename internlm/{initialize/initialize_tensor.py => model/model_implementations/transformers/utils.py} (100%) rename internlm/model/{ops => model_ops}/__init__.py (100%) create mode 100644 internlm/model/model_ops/llava/__init__.py rename internlm/model/{ => model_ops}/llava/clip_builder.py (100%) rename internlm/model/{ => model_ops}/llava/clip_encoder.py (100%) rename internlm/model/{ => model_ops}/llava/projector_builder.py (100%) rename internlm/model/{ => model_ops}/losses/__init__.py (100%) rename internlm/model/{ => model_ops}/losses/ce_loss.py (97%) rename internlm/model/{ => model_ops}/metrics.py (99%) create mode 100644 internlm/model/model_ops/modules/__init__.py rename internlm/model/{ => model_ops}/modules/embedding.py (99%) rename internlm/model/{ => model_ops}/modules/linear.py (99%) rename internlm/model/{ => model_ops}/modules/mha.py (99%) rename internlm/model/{ => model_ops}/modules/mlp.py 
(98%) rename internlm/model/{ => model_ops}/modules/norm.py (91%) rename internlm/model/{ => model_ops}/modules/utils.py (100%) rename internlm/model/{ => model_ops}/moe/__init__.py (100%) rename internlm/model/{ => model_ops}/moe/base_layer.py (95%) rename internlm/model/{ => model_ops}/moe/dropless_layer.py (99%) rename internlm/model/{ => model_ops}/moe/experts.py (100%) rename internlm/model/{ => model_ops}/moe/gshard_layer.py (99%) rename internlm/model/{ => model_ops}/moe/megablocks/__init__.py (100%) rename internlm/model/{ => model_ops}/moe/megablocks/megablock_dmoe.py (96%) rename internlm/model/{ => model_ops}/moe/megablocks/megablock_moe.py (98%) rename internlm/model/{ => model_ops}/moe/megablocks/mlp.py (95%) rename internlm/model/{ => model_ops}/moe/megablocks/utils.py (99%) rename internlm/model/{ => model_ops}/moe/moe.py (96%) rename internlm/model/{ => model_ops}/moe/utils.py (100%) create mode 100644 internlm/model/model_ops/ops/__init__.py rename internlm/model/{ => model_ops}/ops/_flash_attn.py (100%) rename internlm/model/{ => model_ops}/ops/attention.py (99%) rename internlm/model/{ => model_ops}/ops/cross_entropy.py (98%) rename internlm/model/{ => model_ops}/ops/cross_entropy_ops/__init__.py (100%) rename internlm/model/{ => model_ops}/ops/cross_entropy_ops/apex_naive_loss.py (100%) rename internlm/model/{ => model_ops}/ops/cross_entropy_ops/py_naive_loss.py (100%) rename internlm/model/{ => model_ops}/ops/cross_entropy_ops/py_vocab_parallel_loss.py (100%) rename internlm/model/{ => model_ops}/ops/cross_entropy_ops/sequence_parallel_loss.py (100%) rename internlm/model/{ => model_ops}/ops/fused_rmsnorm.py (100%) rename internlm/model/{ => model_ops}/ops/linear.py (100%) rename internlm/model/{ => model_ops}/ops/norm.py (100%) rename internlm/model/{ => model_ops}/ops/ring_flash_attn/__init__.py (100%) rename internlm/model/{ => model_ops}/ops/ring_flash_attn/utils.py (100%) rename internlm/model/{ => model_ops}/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py (99%) rename internlm/model/{ => model_ops}/ops/rotary_emb.py (100%) rename internlm/model/{ => model_ops}/ops/utils.py (100%) rename internlm/model/{ => model_ops}/utils.py (98%) delete mode 100644 internlm/train/__init__.py delete mode 100644 internlm/train/pipeline.py delete mode 100644 internlm/train/utils.py create mode 100644 internlm/utils/config.py diff --git a/.github/workflows/demo_in_readme.yaml b/.github/workflows/demo_in_readme.yaml index a764a39f6..5a1fa3a85 100644 --- a/.github/workflows/demo_in_readme.yaml +++ b/.github/workflows/demo_in_readme.yaml @@ -63,6 +63,7 @@ jobs: export GITHUB_WORKSPACE=$GITHUB_WORKSPACE export SLURM_PARTITION=$SLURM_PARTITION source activate ${evo_env_torch21_flash2} + export PYTHONPATH=$PWD:$PYTHONPATH sh ./ci_scripts/train/slurm_train.sh ${GITHUB_RUN_ID}-${GITHUB_JOB} EOF @@ -97,6 +98,7 @@ jobs: export GITHUB_WORKSPACE=$GITHUB_WORKSPACE export SLURM_PARTITION=$SLURM_PARTITION source activate ${evo_env_torch21_flash2} + export PYTHONPATH=$PWD:$PYTHONPATH sh ./ci_scripts/train/torchrun.sh ${GITHUB_RUN_ID}-${GITHUB_JOB} rm -rf $GITHUB_WORKSPACE/llm_ckpts EOF diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml index fe86bd05a..1d881cd2b 100644 --- a/.github/workflows/lint_check.yaml +++ b/.github/workflows/lint_check.yaml @@ -18,25 +18,21 @@ jobs: run: | pip install flake8==v3.8.4 FLAKE_DISABLE_LIST="F403,F405,W504,W503,E203" - flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST 
--exclude=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* - flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST ./train.py + flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST --exclude=./internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* - name: lint-isort run: | pip install isort==5.12.0 isort --check --profile=black ./internlm/* - isort --check --profile=black ./train.py - name: lint-black run: | pip install black==22.8.0 BLACK_EXCLUDE_SETTINGS='\.venv/|\.local/|\.cache/|\.git/' black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./internlm/* - black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./train.py - name: lint-pylint run: | pip install pylint==v2.17.2 PYLINT_DISABLE_LIST="C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203" - pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST --ignore=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* - pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST ./train.py + pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST --ignore=./internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* diff --git a/README-ja-JP.md b/README-ja-JP.md index bb4c9c201..3a2711611 100644 --- a/README-ja-JP.md +++ b/README-ja-JP.md @@ -99,7 +99,7 @@ data = dict( Slurm環境で2ノード16カードを使用する場合、コマンドは以下の通りです: ```bash -$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py ``` torchを使用し、1ノード8カードで実行する場合、コマンドは以下の通りです: diff --git a/README-zh-Hans.md b/README-zh-Hans.md index 98a9caab0..d955d3d1e 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -99,7 +99,7 @@ data = dict( slurm环境,双机16卡,启动训练命令如下: ```bash -$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py ``` torch环境,单机8卡,启动训练命令如下: diff --git a/README.md b/README.md index 8a9b96612..7f3247b40 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ Training can be started on slurm or torch distributed environment. On slurm, using 2 nodes and 16 cards, the command is as follows: ```bash -$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py ``` On torch, using 1 node and 8 cards, the command is as follows: diff --git a/ci_scripts/train/generate_config.py b/ci_scripts/train/generate_config.py index 096334d06..a2a0aaf0d 100644 --- a/ci_scripts/train/generate_config.py +++ b/ci_scripts/train/generate_config.py @@ -5,7 +5,7 @@ import os from ci_scripts.common import com_func -from internlm.core.context import Config +from internlm.utils.config import Config def generate_new_config(config_py_file, test_config_json, case_name): diff --git a/ci_scripts/train/load_ckpt.sh b/ci_scripts/train/load_ckpt.sh index 287adbd89..50b293da4 100644 --- a/ci_scripts/train/load_ckpt.sh +++ b/ci_scripts/train/load_ckpt.sh @@ -22,7 +22,7 @@ if [[ ! 
-f ${file} ]]; then exit_code=$(($exit_code + 1)) fi -srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$2 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ${file} +srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$2 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python internlm/launcher/launch.py --config ${file} [[ $? -ne 0 ]] && { echo "test slurm training failed."; exit_code=$(($exit_code + 1)); } diff --git a/ci_scripts/train/slurm_train.sh b/ci_scripts/train/slurm_train.sh index b3117a165..753feaab1 100644 --- a/ci_scripts/train/slurm_train.sh +++ b/ci_scripts/train/slurm_train.sh @@ -22,7 +22,7 @@ if [[ -d ${CKPTS20_PATH} ]]; then fi fi -srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./ci_scripts/train/ci_7B_sft.py +srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python internlm/launcher/launch.py --config ./ci_scripts/train/ci_7B_sft.py [[ $? -ne 0 ]] && { echo "test slurm training failed."; exit_code=$(($exit_code + 1)); } num=$(num_files "${CKPTS20_OUTPUT}") diff --git a/ci_scripts/train/torchrun.sh b/ci_scripts/train/torchrun.sh index 31681d02c..5c928d84c 100644 --- a/ci_scripts/train/torchrun.sh +++ b/ci_scripts/train/torchrun.sh @@ -22,7 +22,7 @@ if [[ -d ${CKPTS20_PATH} ]]; then fi fi -srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -N 1 torchrun --nnodes=1 --nproc_per_node=8 --master_port=29501 train.py --config ./ci_scripts/train/ci_7B_sft.py --launcher torch +srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -N 1 torchrun --nnodes=1 --nproc_per_node=8 --master_port=29501 internlm/launcher/launch.py --config ./ci_scripts/train/ci_7B_sft.py --launcher torch [[ $? -ne 0 ]] && { echo "test torch training failed."; exit_code=$(($exit_code + 1)); } num=$(num_files "${CKPTS_OUTPUT}") diff --git a/doc/code-docs/locales/en/LC_MESSAGES/training.po b/doc/code-docs/locales/en/LC_MESSAGES/training.po index 25b4a4927..fc59d8c13 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/training.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/training.po @@ -68,10 +68,10 @@ msgstr "Initialize Distributed Training Environment" #: ../../source/training.rst:23 msgid "" -"调用 ``initialize_distributed_env`` 函数,支持通过 slurm 或 torch " +"调用 ``init_distributed`` 函数,支持通过 slurm 或 torch " "方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下:" msgstr "" -"Call the initialize_distributed_env function, which supports launching " +"Call the init_distributed function, which supports launching " "the training script through Slurm or Torch, and pass in information such " "as the configuration file, port number, and process random seed. Detailed" " description of the function is as follows:" diff --git a/doc/code-docs/source/example/20B_demo.rst b/doc/code-docs/source/example/20B_demo.rst index da7f1d2df..0fd0d0221 100644 --- a/doc/code-docs/source/example/20B_demo.rst +++ b/doc/code-docs/source/example/20B_demo.rst @@ -167,7 +167,7 @@ .. 
code-block:: bash - srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/20B_sft.py + srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/20B_sft.py 训练结果 ---------------- diff --git a/doc/code-docs/source/example/7B_demo.rst b/doc/code-docs/source/example/7B_demo.rst index 78154175e..67df4261e 100644 --- a/doc/code-docs/source/example/7B_demo.rst +++ b/doc/code-docs/source/example/7B_demo.rst @@ -165,7 +165,7 @@ .. code-block:: bash - srun -p internllm -N 1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py + srun -p internllm -N 1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py 训练结果 ---------------- diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index 721eec006..9b7ee3b3c 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -43,7 +43,7 @@ InternEvo 使用 `argparse `_ 模型初始化 ------------------------- -.. autofunction:: internlm.train.initialize_model_and_parallel_communicator +.. autofunction:: internlm.initialize.initialize_model.initialize_model_and_parallel_communicator InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下: diff --git a/doc/code-docs/source/training.rst b/doc/code-docs/source/training.rst index f43bfe4af..22b0ed2ba 100644 --- a/doc/code-docs/source/training.rst +++ b/doc/code-docs/source/training.rst @@ -18,11 +18,11 @@ - 初始化分布式训练环境 .. code-block:: python - initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + init_distributed(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) -调用 ``initialize_distributed_env`` 函数,支持通过 slurm 或 torch 方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下: +调用 ``init_distributed`` 函数,支持通过 slurm 或 torch 方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下: -.. autofunction:: internlm.initialize.initialize_distributed_env +.. autofunction:: internlm.initialize.init_distributed - 初始化模型 .. 
code-block:: python diff --git a/doc/en/usage.md b/doc/en/usage.md index 8e1670c2f..f8ae268a1 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -407,7 +407,7 @@ After completing the data preparation and relevant training configurations menti If you want to start distributed training on slurm with 16 GPUs across multiple nodes, use the following command: ```bash -$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py ``` If you want to start distributed training on torch with 8 GPUs on a single node, use the following command: diff --git a/doc/usage.md b/doc/usage.md index 7c28d6d3e..cba2b4be2 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -453,7 +453,7 @@ parallel = dict( 若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示: ```bash -$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py ``` 若在 torch 上启动分布式运行环境,单节点 8 卡的运行命令如下所示: diff --git a/generate.py b/generate.py index 48efa8b3f..69d4f1c51 100644 --- a/generate.py +++ b/generate.py @@ -18,10 +18,12 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import global_context as gpc from internlm.data import build_generation_loader_with_data_type -from internlm.initialize import initialize_distributed_env +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_model import ( + initialize_model_and_parallel_communicator, +) from internlm.monitor import initialize_monitor_manager -from internlm.monitor.monitor import monitor_manager as mm -from internlm.train import initialize_model_and_parallel_communicator +from internlm.monitor import monitor_manager as mm from internlm.utils.common import ( enable_pytorch_expandable_segments, launch_time, @@ -219,7 +221,7 @@ def main(): hostname = socket.gethostname() # initialize distributed environment - initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file" assert ( diff --git a/internlm/__init__.py b/internlm/__init__.py index dc34a3167..e69de29bb 100644 --- a/internlm/__init__.py +++ b/internlm/__init__.py @@ -1,9 +0,0 @@ -from .initialize.initialize_trainer import initialize_trainer -from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch - -__all__ = [ - "get_default_parser", - "initialize_trainer", - "launch_from_slurm", - "launch_from_torch", -] diff --git a/internlm/accelerator/abstract_accelerator.py b/internlm/accelerator/abstract_accelerator.py index ffedaed59..734d45da5 100644 --- a/internlm/accelerator/abstract_accelerator.py +++ b/internlm/accelerator/abstract_accelerator.py @@ -1,8 +1,10 @@ """ Universal accelerator interface implementation, inspired by DeepSpeed. 
""" +import abc import enum import os +from abc import ABC class AcceleratorType(enum.Enum): @@ -17,57 +19,72 @@ class AcceleratorType(enum.Enum): internlm_accelerator = None -class Accelerator: +class Accelerator(ABC): """ Abstract base class for accelerator """ def __init__(self) -> None: - pass + self._name_str = None + self._communication_backend_name = None + @abc.abstractmethod def get_backend_name(self): """ Return the name of the accelerator. """ raise NotImplementedError + @abc.abstractmethod def get_accelerator_backend(self): """ - Return the name of the backend. + Return the name of the accelerator backend. """ raise NotImplementedError - # Device APIs + @abc.abstractmethod + def communication_backend_name(self): + """ + Return the name of the communication backend. + """ + raise NotImplementedError + + @abc.abstractmethod def device_name(self, device_index=None): """ Return the name of the device. """ raise NotImplementedError + @abc.abstractmethod def set_device(self, device_index): """ Bind the current process to a device. """ raise NotImplementedError + @abc.abstractmethod def get_device_id(self): """ Return the current device index. """ raise NotImplementedError + @abc.abstractmethod def current_device_name(self): """ Return the name of the current device. """ raise NotImplementedError + @abc.abstractmethod def device_count(self): """ Return the number of devices on the machine. """ raise NotImplementedError + @abc.abstractmethod def synchronize(self, device_index=None): """ Synchronize the current process. diff --git a/internlm/accelerator/cuda_accelerator.py b/internlm/accelerator/cuda_accelerator.py index 48a471657..d5986077c 100644 --- a/internlm/accelerator/cuda_accelerator.py +++ b/internlm/accelerator/cuda_accelerator.py @@ -14,6 +14,7 @@ class CUDA_Accelerator(Accelerator): """ def __init__(self) -> None: + super().__init__() self._name_str = "cuda" self._communication_backend_name = "nccl" self.amp = self.get_amp() diff --git a/internlm/accelerator/dipu_accelerator.py b/internlm/accelerator/dipu_accelerator.py index 7943b4c7f..b5383eded 100644 --- a/internlm/accelerator/dipu_accelerator.py +++ b/internlm/accelerator/dipu_accelerator.py @@ -14,6 +14,7 @@ class DIPU_Accelerator(Accelerator): """ def __init__(self) -> None: + super().__init__() self._name_str = "cuda" self._communication_backend_name = "nccl" self.amp = self.get_amp() diff --git a/internlm/accelerator/ditorch_accelerator.py b/internlm/accelerator/ditorch_accelerator.py index 528b858e2..e4a2fca54 100644 --- a/internlm/accelerator/ditorch_accelerator.py +++ b/internlm/accelerator/ditorch_accelerator.py @@ -14,6 +14,7 @@ class DITORCH_Accelerator(Accelerator): """ def __init__(self) -> None: + super().__init__() self._name_str = "cuda" self._communication_backend_name = "nccl" self.amp = self.get_amp() diff --git a/internlm/accelerator/npu_accelerator.py b/internlm/accelerator/npu_accelerator.py index e1bd3549d..e078014e6 100644 --- a/internlm/accelerator/npu_accelerator.py +++ b/internlm/accelerator/npu_accelerator.py @@ -14,6 +14,7 @@ class ASCEND_Accelerator(Accelerator): """ def __init__(self) -> None: + super().__init__() self._name_str = "npu" self._communication_backend_name = "hccl" self.amp = self.get_amp() diff --git a/internlm/apis/inference_utils.py b/internlm/apis/inference_utils.py index 423e7aafe..931d10537 100644 --- a/internlm/apis/inference_utils.py +++ b/internlm/apis/inference_utils.py @@ -2,7 +2,7 @@ from internlm.core.context import ParallelMode # noqa: E402 from 
internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.parallel.comm.utils import _gather as gather +from internlm.core.parallel.comm.utils import _gather class InferenceParams: @@ -64,6 +64,6 @@ def process_parallel_output(model_output): # gather tp parallel output if gpc.config.model.parallel_output and gpc.is_initialized(ParallelMode.TENSOR): - return gather(model_output, ParallelMode.TENSOR, -1) + return _gather(model_output, ParallelMode.TENSOR, -1) else: return model_output diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 1f9bf6a6c..eb25f19b6 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -11,16 +11,14 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState -from internlm.initialize.launch import get_config_value -from internlm.initialize.legacy.launch import ( - auto_resume_sanity_check, - ckpt_info_sanity_check, +from internlm.model.model_implementations.registry import model_initializer +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, ) -from internlm.model.base_model import BaseModel -from internlm.model.registry import model_initializer from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device +from internlm.utils.config import get_config_value from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import is_using_fsdp, is_using_hf @@ -287,7 +285,7 @@ def __init__( k: partial(try_load_internlm_ckpt_func, func=v) for k, v in LOAD_FUNC_DICT.items() } # Register huggingface ckpt load type - if isinstance(model, BaseModel): + if isinstance(model, BaseTransformerModel): self.defalut_load_type_func.update( { "hf": partial( @@ -309,14 +307,10 @@ def __init__( f.write("0") self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None) - if self.load_ckpt_info is None: # (legacy): Try Compatible with old interfaces - self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config) # Auto-reload latest checkpoint, it will overwrite the setting of 'load_ckpt_info'. 
- self.auto_resume = get_config_value(ckpt_config, "auto_resume", None) - if self.auto_resume is None: # (legacy): Try Compatible with old interfaces - self.auto_resume = auto_resume_sanity_check(ckpt_config) - if self.auto_resume: + self.auto_resume = get_config_value(ckpt_config, "auto_resume", False) + if self.auto_resume and self.save_ckpt_folder and self.has_available_ckpt(self.save_ckpt_folder): self.load_ckpt_info = self.query_lastest_ckpt() if self.stop_file_path is None and gpc.is_rank_for_log(): @@ -391,6 +385,16 @@ def quit_signal_handler(self, train_state) -> bool: return now_break, now_save_ckpt, save_type + def has_available_ckpt(self, folder) -> bool: + """Check if there is an available ckpt in the folder.""" + folder = folder.split(":")[-1] + for _, _, files in os.walk(folder, followlinks=True): + for fn in files: + fn = fn.strip("/") + if fn.endswith(".step"): + return True + return False + def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSaveType, bool): save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False if force: @@ -444,7 +448,7 @@ def try_save_checkpoint(self, train_state, force=False): ) if ( - isinstance(self.model, BaseModel) + isinstance(self.model, BaseTransformerModel) and self.enable_internevo2hf_ckpt and save_type == CheckpointSaveType.NORMAL_CHECKPOINT and gpc.is_rank_for_log() diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index d96bb65c5..a94237ac5 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -9,7 +9,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState -from internlm.model.moe import MoE +from internlm.model.model_ops.moe import MoE from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device from internlm.utils.lazy import LazyObject diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index dde4bc523..5b9ad74de 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -1,8 +1,12 @@ # Copyright (c) InternLM. All rights reserved. 
-from internlm.model.modeling_internlm import InternLM1 -from internlm.model.modeling_internlm2 import InternLM2 -from internlm.model.modeling_llama import Llama2 +from internlm.model.model_implementations.transformers.modeling_internlm import ( + InternLM1, +) +from internlm.model.model_implementations.transformers.modeling_internlm2 import ( + InternLM2, +) +from internlm.model.model_implementations.transformers.modeling_llama import Llama2 from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py index cd8fae4bf..401bd54ec 100644 --- a/internlm/checkpoint/utils.py +++ b/internlm/checkpoint/utils.py @@ -1,17 +1,8 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import itertools - -import numpy as np -import torch -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from internlm.core.context import global_context as gpc -from internlm.core.parallel.shard import split_data_for_sequence_parallel -from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp logger = get_logger(__file__) @@ -53,67 +44,3 @@ def process_load_info(load_info): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") return load_content_str, load_ckpt_folder, load_content - - -def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP: - """ - Initialize Fully Sharded Data Parallel (FSDP) for the model. - This function is needed to properly initialize FSDP when resuming from a checkpoint. - It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. - - References: - https://github.com/pytorch/pytorch/issues/113496 - https://github.com/huggingface/transformers/pull/34032 - https://github.com/huggingface/transformers/issues/31892 - - Args: - model: The model to initialize with FSDP. - device: The device to run the model on. - - Returns: - The initialized FSDP model. 
- """ - model.train() - with torch.no_grad(): - # generate dummy packed sequence - seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz - input_ids = [1] * seq_len - label = input_ids[1:] + [-100] - cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len)) - - input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) - label = torch.tensor(label, device=device).unsqueeze(0) - indexes = torch.tensor( - list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])), - device=device, - ).unsqueeze(0) - cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0) - - data = { - "input_ids": input_ids, - "cu_seqlens": cu_seqlens, - "indexes": indexes, - "max_seqlen": seq_len, - } - - data_fns = [] - - # default data process function - if gpc.config.data.use_packed_dataset: - data_fns.append(packed_data_normalizer) - else: - data_fns.append(unpack_data) - - # support sequence parallel for isp - if is_using_isp(): - data_fns.append(split_data_for_sequence_parallel) - - # generate dummy_input - _data, _label = data, label - for fn in data_fns: - _data, _label = fn(_data, _label) - dummy_input = _data - - # run a forward pass with dummy_input to initialize FSDP - _ = model(**dummy_input) - return model diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py index b2fc95cc9..be444ea30 100644 --- a/internlm/core/context/__init__.py +++ b/internlm/core/context/__init__.py @@ -5,7 +5,6 @@ IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_EXPERT_DATA_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, - Config, ParallelContext, global_context, ) @@ -19,6 +18,7 @@ ProcessGroupInitializer, ) from .random import ( + _SEED_MANAGER, add_seed, get_current_mode, get_seeds, @@ -30,7 +30,6 @@ ) __all__ = [ - "Config", "IS_REPLICA_EXPERT_DATA_PARALLEL", "IS_TENSOR_ZERO_PARALLEL", "IS_REPLICA_ZERO_PARALLEL", @@ -54,4 +53,5 @@ "get_current_mode", "set_seed_states", "sync_states", + "_SEED_MANAGER", ] diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 5278426ed..fbe2b6247 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -3,12 +3,8 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context -import inspect import random import socket -import sys -from importlib.machinery import SourceFileLoader -from pathlib import Path from typing import Union import numpy as np @@ -17,6 +13,7 @@ from internlm.accelerator import get_accelerator from internlm.utils.common import SingletonMeta +from internlm.utils.config import Config from internlm.utils.logger import get_logger from internlm.utils.timeout import LLM_NCCL_TIMEOUT from internlm.utils.utils import TensorParallelMode @@ -46,97 +43,6 @@ internlm_accelerator = get_accelerator() -class Config(dict): - """This is a wrapper class for dict objects so that values of which can be - accessed as attributes. - - Args: - config (dict): The dict object to be wrapped. 
- """ - - def __init__(self, config: dict = None): # pylint: disable=W0231 - if config is not None: - for k, v in config.items(): - self._add_item(k, v) - - def __missing__(self, key): - raise KeyError(key) - - def __getattr__(self, key): - try: - value = super().__getitem__(key) - return value - except KeyError: - raise AttributeError(key) - - def __setattr__(self, key, value): - super().__setitem__(key, value) - - def _add_item(self, key, value): - if isinstance(value, dict): - self.__setattr__(key, Config(value)) - else: - self.__setattr__(key, value) - - def update(self, config): - assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." - for k, v in config.items(): - self._add_item(k, v) - return self - - @staticmethod - def from_file(filename: str): - """Reads a python file and constructs a corresponding :class:`Config` object. - - Args: - filename (str): Name of the file to construct the return object. - - Returns: - :class:`Config`: A :class:`Config` object constructed with information in the file. - - Raises: - AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file - """ - - # check config path - if isinstance(filename, str): - filepath = Path(filename).absolute() - elif isinstance(filename, Path): - filepath = filename.absolute() - - assert filepath.exists(), f"{filename} is not found, please check your configuration path" - - # check extension - extension = filepath.suffix - assert extension == ".py", "only .py files are supported" - - # import the config as module - remove_path = False - if filepath.parent not in sys.path: - sys.path.insert(0, (filepath)) - remove_path = True - - module_name = filepath.stem - source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) - module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 - - # load into config - config = Config() - - for k, v in module.__dict__.items(): - if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): - continue - else: - config._add_item(k, v) - - # remove module - del sys.modules[module_name] - if remove_path: - sys.path.pop(0) - - return config - - class ParallelContext(metaclass=SingletonMeta): """This class provides interface functions for users to get the parallel context, such as the global rank, the local rank, the world size, etc. of each device. 
diff --git a/internlm/core/engine.py b/internlm/core/engine.py index 5989536dc..97cb41db0 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import _LRScheduler from internlm.core.gradient_handler import BaseGradientHandler -from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer -from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler +from internlm.solver.optimizer import BaseOptimizer +from internlm.solver.schedulers import Beta2Scheduler from internlm.utils.common import get_batch_size, move_to_device diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py new file mode 100644 index 000000000..3f74b7b34 --- /dev/null +++ b/internlm/core/fsdp.py @@ -0,0 +1,224 @@ +import collections +import functools +import itertools +from typing import List, Optional, Set, Union + +import numpy as np +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + +from internlm.accelerator.abstract_accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import split_data_for_sequence_parallel +from internlm.data.utils import packed_data_normalizer, unpack_data +from internlm.utils.common import get_current_device +from internlm.utils.lazy import LazyObject +from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp + +internlm_accelerator = get_accelerator() +logger = get_logger(__file__) + +try: + from torch.distributed._composable.fsdp import fully_shard + + FSDP2_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + FSDP2_SUPPORTED = False + +try: + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) + + DCP_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + DCP_SUPPORTED = False + + +def _get_modules_to_materialize( + root_module: nn.Module, + ignored_modules: Set[nn.Module], +) -> List[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. 
+ modules_to_materialize: List[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: Set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if child_module not in visited_modules and child_module not in ignored_modules: + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _materialize_meta_module( + root_module: nn.Module, + ignored_modules: Set[nn.Module], + device_id: Optional[torch.device], +) -> None: + # Run default meta device initialization + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False)) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=device_id, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + logger.warning( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] + ) + raise e + + +def _init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP: + """ + Initialize Fully Sharded Data Parallel (FSDP) for the model. + This function is needed to properly initialize FSDP when resuming from a checkpoint. + It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. + + References: + https://github.com/pytorch/pytorch/issues/113496 + https://github.com/huggingface/transformers/pull/34032 + https://github.com/huggingface/transformers/issues/31892 + + Args: + model: The model to initialize with FSDP. + device: The device to run the model on. + + Returns: + The initialized FSDP model. 
+ """ + model.train() + with torch.no_grad(): + # generate dummy packed sequence + seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz + input_ids = [1] * seq_len + label = input_ids[1:] + [-100] + cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len)) + + input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) + label = torch.tensor(label, device=device).unsqueeze(0) + indexes = torch.tensor( + list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])), + device=device, + ).unsqueeze(0) + cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0) + + data = { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": indexes, + "max_seqlen": seq_len, + } + + data_fns = [] + + # default data process function + if gpc.config.data.use_packed_dataset: + data_fns.append(packed_data_normalizer) + else: + data_fns.append(unpack_data) + + # support sequence parallel for isp + if is_using_isp(): + data_fns.append(split_data_for_sequence_parallel) + + # generate dummy_input + _data, _label = data, label + for fn in data_fns: + _data, _label = fn(_data, _label) + dummy_input = _data + + # run a forward pass with dummy_input to initialize FSDP + _ = model(**dummy_input) + return model + + +def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): + if is_using_fsdp(): + assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallel." + wrap_cls = tuple( + LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", []) + ) + fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1") + fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + + if fsdp_mode == "v1": + model = FSDP( + module=model, + process_group=gpc.get_group(ParallelMode.GLOBAL), + sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD + auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)), + sync_module_states=fsdp_init_method != "cuda", # sync model paramters + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + use_orig_params=True, + device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states + ) + # For FSDP v1, to get ckpt resuming work normally, we do dummy forward. + # This hack is needed due to FSDP v1 lazy initialization in model construction. 
+ # FYI: https://github.com/pytorch/pytorch/issues/113496 + model = _init_fsdp_v1(model, get_current_device()) + elif FSDP2_SUPPORTED and fsdp_mode == "v2": + fsdp_kwargs = { + "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True + } + for module in model.modules(): + if isinstance(module, wrap_cls): + fully_shard(module, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) + if fsdp_init_method == "meta": + _materialize_meta_module(model, set(), get_current_device()) + elif fsdp_init_method == "cpu": + model.to(get_current_device()) + else: + raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}") + + if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False): + load_ckpt_info = gpc.config.ckpt.load_ckpt_info + load_ckpt_path = load_ckpt_info.get("path", None) + load_ckpt_content = load_ckpt_info.get("content", []) + if load_ckpt_path: + assert load_ckpt_content == ( + "model", + ), "If auto_resume=False and checkpoint path is given, only model can be loaded" + if DCP_SUPPORTED: + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + del state_dict + internlm_accelerator.empty_cache() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + return model diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 7cac640da..177a5c1c4 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -14,7 +14,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc internlm_accelerator = get_accelerator() diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py index be170f286..422578ed1 100644 --- a/internlm/core/parallel/comm/__init__.py +++ b/internlm/core/parallel/comm/__init__.py @@ -1,3 +1,48 @@ from .attn_offload import get_offload_manager, initialize_offload_manager +from .isp import ( + EmbeddingWeightParallelCommunicator, + HeadWeightParallelCommunicator, + ISPCommModelConfig, + ISPCommunicator, + ISPCommunicatorSchedulerHook, + ISPCommunicatorWrapper, + WPCommunicator, + auto_wrap_distributed_attention, + auto_wrap_func_distributed_attention, +) +from .tensor import ( + EmbeddingSequenceParallelCommunicator, + EmbeddingTensorParallelCommunicator, + HeadSequenceParallelCommunicator, + HeadTensorParallelCommunicator, + LinearRole, + MoESequenceParallelCommunicator, + SequenceParallelCommunicator, + TensorParallelCommunicator, + TPCommunicator, +) +from .zero import ParamAsyncBcastHandler -__all__ = ["initialize_offload_manager", "get_offload_manager"] +__all__ = [ + "initialize_offload_manager", + "get_offload_manager", + "EmbeddingWeightParallelCommunicator", + "HeadWeightParallelCommunicator", + "ISPCommModelConfig", + "ISPCommunicator", + "ISPCommunicatorWrapper", + "ISPCommunicatorSchedulerHook", + "WPCommunicator", + "auto_wrap_distributed_attention", + "auto_wrap_func_distributed_attention", + "EmbeddingSequenceParallelCommunicator", + "EmbeddingTensorParallelCommunicator", + "HeadSequenceParallelCommunicator", + 
"HeadTensorParallelCommunicator", + "LinearRole", + "MoESequenceParallelCommunicator", + "SequenceParallelCommunicator", + "TensorParallelCommunicator", + "TPCommunicator", + "ParamAsyncBcastHandler", +] diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 23a92980c..0590f03ca 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -25,9 +25,8 @@ expandKVPacked, reduce_scatter_raw, ) -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import ParallelLinearWithCommExt -from internlm.model.modules.utils import is_moe_param +from internlm.model.model_ops.modules.linear import ParallelLinearWithCommExt +from internlm.model.model_ops.modules.utils import is_moe_param from internlm.utils.common import SchedulerHook, UniqueChainMap, get_current_device from internlm.utils.utils import ( CuSeqlenType, @@ -179,14 +178,19 @@ class EmbeddingWeightParallelCommunicator: """ def __init__(self, parallel_mode: ParallelMode) -> None: + from internlm.model.model_ops.modules.embedding import Embedding1D + + self.embedding1d_cls = Embedding1D self.parallel_mode = parallel_mode self.gather_dim = 0 self._cur_micro_step = 0 self._num_micro_step = gpc.config.data.micro_num - def register_module_hook(self, module: Embedding1D) -> None: - assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D" + def register_module_hook(self, module: nn.Module) -> None: + assert isinstance( + module, self.embedding1d_cls + ), "Embbeding weight parallel communicator is only support Embedding1D" module.weight.evo_tensor = None self.gather_dim = 0 if module.vocab_parallel else 1 diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index 229f9a9b0..a9c8b1f44 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -10,7 +10,7 @@ from torch import distributed as dist from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from internlm.core.parallel.comm.utils import ( DUMMY_HANDLE_CONST, AsyncCommHandle, @@ -23,8 +23,7 @@ reduce_scatter_raw, split_forward_gather_backward, ) -from internlm.model.modules.embedding import Embedding1D -from internlm.model.moe.moe import MoE +from internlm.model.model_ops.moe.moe import MoE # input gather dim _GATHER_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim] @@ -339,14 +338,21 @@ class EmbeddingTensorParallelCommunicator: """ def __init__(self, parallel_mode: ParallelMode) -> None: + from internlm.model.model_ops.modules.embedding import Embedding1D + + self.embedding1d_class = Embedding1D self._parallel_mode = parallel_mode - def register_module_hook(self, module: Embedding1D) -> None: - assert isinstance(module, Embedding1D), "Embbeding tensor parallel communicator is only support Embedding1D" + def register_module_hook(self, module: torch.nn.Module) -> None: + assert isinstance( + module, self.embedding1d_class + ), "Embbeding tensor parallel communicator is only support Embedding1D" module.register_forward_hook(self.output_hook) - def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + def output_hook( + self, module: torch.nn.Module, args: Any, output: Tuple[Any] # pylint: disable=W0613 + ) -> Tuple[Any]: """ split output after forward and allgather grad_output 
before backward. """ @@ -366,14 +372,21 @@ class EmbeddingSequenceParallelCommunicator: """ def __init__(self, parallel_mode: ParallelMode) -> None: + from internlm.model.model_ops.modules.embedding import Embedding1D + + self.embedding1d_class = Embedding1D self._parallel_mode = parallel_mode - def register_module_hook(self, module: Embedding1D) -> None: - assert isinstance(module, Embedding1D), "Embbeding sequence parallel communicator is only support Embedding1D" + def register_module_hook(self, module: torch.nn.Module) -> None: + assert isinstance( + module, self.embedding1d_class + ), "Embbeding sequence parallel communicator is only support Embedding1D" module.register_forward_hook(self.output_hook) - def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + def output_hook( + self, module: torch.nn.Module, args: Any, output: Tuple[Any] # pylint: disable=W0613 + ) -> Tuple[Any]: """ split output after forward and allgather grad_output before backward. """ diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py index 58929290f..72218a556 100644 --- a/internlm/core/parallel/comm/zero.py +++ b/internlm/core/parallel/comm/zero.py @@ -11,9 +11,8 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.naive_amp import unwrap_naive_amp -from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import ScaleColumnParallelLinear +from internlm.core.parallel.comm import ISPCommunicatorWrapper +from internlm.model.model_ops.modules.linear import ScaleColumnParallelLinear from internlm.solver.optimizer.utils import flatten @@ -28,6 +27,10 @@ def __init__( model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicatorWrapper = None, ) -> None: + from internlm.model.model_ops.modules.embedding import Embedding1D + + self.embedding1d_cls = Embedding1D + self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict() self._param_to_rank: Dict[nn.Parameter, int] = {} self._block_to_rank: Dict[nn.Module, int] = {} @@ -121,7 +124,7 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06 # NOTE: Although the layernorm layer does not have explicit processing, # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, # so everything is fine. - if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)): + if isp_communicator is None or isinstance(block, (self.embedding1d_cls, ScaleColumnParallelLinear)): block.register_forward_pre_hook(_pre_forward_hook) if isp_communicator: isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) @@ -170,7 +173,7 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06 # NOTE: Although the layernorm layer does not have explicit processing, # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, # so everything is fine. 
- if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)): + if isp_communicator is None or isinstance(block, (self.embedding1d_cls, ScaleColumnParallelLinear)): block.register_forward_pre_hook(_pre_forward_hook) if isp_communicator: isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 308fbb897..22e297c3c 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -34,7 +34,7 @@ def _split_data_for_sequence_parallel(data, label): data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim) # NOTICE: For compatibility where the shape of position_ids is [batch, seqlen, ...] - if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): + if is_using_hf(): _position_ids_seq_dim = 1 data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_position_ids_seq_dim) diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 3b01d3afd..6fd85886d 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -4,12 +4,25 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine import json +import math import os +import time from collections import deque -from typing import Iterable, Optional +from typing import Iterable, List, Optional +from torch.utils.data import DataLoader + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc from internlm.core.engine import Engine +from internlm.core.parallel.comm import ISPCommunicatorSchedulerHook from internlm.core.scheduler import BaseScheduler, NonPipelineScheduler +from internlm.data.utils import unpack_type_ids +from internlm.model.model_ops.metrics import SchedulerMetricHook +from internlm.monitor import monitor_manager as mm +from internlm.utils.common import SchedulerHook, set_env_var +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.timeout import llm_timeout class TrainState: @@ -206,3 +219,220 @@ def execute_schedule(self, data_iter: Iterable, **kwargs): Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss). """ return self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) + + +def get_scheduler_hooks(metric, zero_optim, isp_communicator_wrapper) -> List[SchedulerHook]: + scheduler_hooks: List[SchedulerHook] = [] + + if metric is not None: + scheduler_hooks.append( + SchedulerMetricHook( + metric=metric, + skip=( + gpc.is_using_parallel_mode(ParallelMode.PIPELINE) + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) + ), + ), + ) + + if isp_communicator_wrapper is not None: + for isp_communicator in isp_communicator_wrapper.isp_communicators: + if isp_communicator is not None and isp_communicator.overlap: + scheduler_hooks.append(ISPCommunicatorSchedulerHook(isp_communicator, zero_optim)) + + return scheduler_hooks + + +@llm_timeout(func_name="load_new_batch") +def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): + """ + Load and return the new batch data based on training data loader. + + Args: + train_dl (torch.utils.data.DataLoader): Dataloader for training. + train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). 
+ train_state (TrainState): Current training state. + + Returns: A batch data and the updated train_iter. + """ + + timer("batch-gen").start() + try: + batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor) + if hasattr(train_state, "batch_sampler_iter"): + next(train_state.batch_sampler_iter) + except StopIteration: + train_iter = iter(train_dl) + batch = next(train_iter) + train_state.num_consumed_samples_in_epoch = 0 + if hasattr(train_state, "batch_sampler"): + train_state.batch_sampler.batch_count = 0 + train_state.batch_sampler.num_consumed_samples_in_epoch = 0 + train_state.batch_sampler_iter = iter(train_state.batch_sampler) + next(train_state.batch_sampler_iter) + timer("batch-gen").stop() + + if batch[0].get("type_ids", None) is not None: + # if use_packed_dataset is False, we need to unpack type_ids + if not gpc.config.data.use_packed_dataset: + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) + + return batch, train_iter + + +@llm_timeout(func_name="record_current_batch_training_metrics") +def record_current_batch_training_metrics( + get_tflops_func, + logger, + writer, + success_update, + batch_count, + batch, + train_state, + optimizer, + beta2_scheduler, + engine, + start_time, + very_begining_time, + loss, + moe_loss, + grad_norm, + metric, +): + """ + Print some training metrics of current batch. + """ + + set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time())) + + timer.store_last_timers() + if success_update in (0, True): + train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) + if gpc.is_no_pp_or_last_stage(): + acc_perplex = metric.get_metric() + + if success_update and gpc.is_rank_for_log(): + lr = optimizer.param_groups[0]["lr"] + if hasattr(engine.optimizer, "grad_scaler"): + scaler = engine.optimizer.grad_scaler._scale.item() + elif hasattr(engine.optimizer.optim, "grad_scaler"): + scaler = engine.optimizer.optim.grad_scaler._scale.item() + + num_tokens_in_batch = batch[1].nelement() + real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL)) + num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) + max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + time_cost = time.time() - start_time + tk_per_gpu = round( + num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL), + 4, + ) + tgs_statistic = train_state.tgs_statistic + tgs_statistic["sum_step"] += 1 + tgs_statistic["sum_tg"] += tk_per_gpu + tgs_statistic["total_time"] = time.time() - very_begining_time + tgs_statistic["sum_last_tg_10"] += tk_per_gpu + tgs_statistic["sum_last_time_10"] += time_cost + tgs_statistic["sum_last_tg_50"] += tk_per_gpu + tgs_statistic["sum_last_time_50"] += time_cost + tgs_statistic["SMA_tg_50"] += tk_per_gpu + tgs_statistic["SMA_time_50"] += time_cost + tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu) + tgs_statistic["SMA_time_50_list"].append(time_cost) + if tgs_statistic["sum_step"] > 50: + tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0] + tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0] + tgs_statistic["SMA_tg_50_list"].popleft() + tgs_statistic["SMA_time_50_list"].popleft() + + last_tgs_1 = round(tk_per_gpu / time_cost, 2) + 
tgs_statistic["sum_tgs"] += last_tgs_1 + + if tgs_statistic["sum_step"] % 10 == 0: + tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2) + tgs_statistic["sum_last_tg_10"] = 0 + tgs_statistic["sum_last_time_10"] = 0 + + if tgs_statistic["sum_step"] % 50 == 0: + tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2) + tgs_statistic["sum_last_tg_50"] = 0 + tgs_statistic["sum_last_time_50"] = 0 + + last_tgs_10 = tgs_statistic["last_tgs_10"] + last_tgs_50 = tgs_statistic["last_tgs_50"] + + tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["total_time"], 2) + tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2) + tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2) + + tflops = get_tflops_func(time_cost) + + tgs_origin = round( + num_tokens_in_batch + * gpc.get_world_size(ParallelMode.DATA) + / gpc.get_world_size(ParallelMode.GLOBAL) + / time_cost, + 2, + ) + + real_tgs = round( + real_num_tokens / time_cost, + 2, + ) + + infos = { + "tflops": tflops, + "step": batch_count, + "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), + "real_tgs": real_tgs, + "tgs (tokens/gpu/second)": tgs_origin, + "tgs/last_tgs_1": last_tgs_1, + "tgs/tgs_all": tgs_all, + "tgs/tgs_avg": tgs_avg, + "tgs/tgs_SMA": tgs_SMA, + "tgs/last_tgs_10": last_tgs_10, + "tgs/last_tgs_50": last_tgs_50, + "lr": lr, + "loss_scale": scaler, + "grad_norm": grad_norm, + } + if moe_loss is not None: + infos["moe_loss"] = moe_loss.item() + + infos["micro_num"] = len(batch[1]) + infos["num_consumed_tokens"] = train_state.num_consumed_tokens + infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches + infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples + infos["largest_length"] = max_length_in_batch # the longest input + infos["largest_batch"] = max_samples_in_batch # the batch with the most samples + infos["smallest_batch"] = min_samples_in_batch + infos["adam_beta2"] = beta2_scheduler.get_beta2() + + fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) + infos["fwd_bwd_time"] = fwd_bwd_time + bwd_time = round(timer("bwd").elapsed(), 2) + infos["bwd_time"] = bwd_time + + for key, value in acc_perplex.items(): + infos[key] = value + + line = "" + for key, value in infos.items(): + line += f"{key}={value} " + if isinstance(value, dict): + writer.add_scalars(key=key, value=value, step=train_state.step_count) + else: + writer.add_scalar(key=key, value=value, step=train_state.step_count) + + logger.info(line) + + # if loss spike occurs, send alert info to feishu + mm.monitor_loss_spike( + alert_address=gpc.config.monitor.alert.feishu_alert_address, + step_count=batch_count, + cur_step_loss=loss.item(), + ) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 4c18fd326..532da9494 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -9,25 +9,27 @@ from torch.utils.data import DataLoader from internlm.checkpoint.checkpoint_manager import CheckpointManager +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer import ParallelMode from internlm.core.parallel.comm import initialize_offload_manager -from internlm.core.trainer import Trainer -from internlm.data.streaming.utils import streaming_simple_resume -from internlm.data.train_state 
import get_train_state -from internlm.eval.evaluation import evaluate_on_val_dls -from internlm.initialize.initialize_trainer import initialize_trainer -from internlm.model.losses.ce_loss import InternLoss -from internlm.model.metrics import AccPerplex -from internlm.monitor.monitor import send_alert_message -from internlm.train.pipeline import ( +from internlm.core.trainer import ( + Trainer, get_scheduler_hooks, - initialize_llm_profile, - initialize_optimizer, - inject_model, load_new_batch, record_current_batch_training_metrics, ) +from internlm.data.streaming.utils import streaming_simple_resume +from internlm.data.train_state import get_train_state +from internlm.eval import evaluate_on_val_dls +from internlm.initialize import initialize_trainer +from internlm.initialize.initialize_model import ( + initialize_model_and_parallel_communicator, +) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize.initialize_profiler import initialize_llm_profile +from internlm.model.model_ops.losses.ce_loss import InternLoss +from internlm.model.model_ops.metrics import AccPerplex +from internlm.monitor import send_alert_message from internlm.utils.common import ( BatchSkipper, check_cuda_env, @@ -99,8 +101,8 @@ def __init__( # load config_lines config_lines = self._read_config(kwargs["config"]) - # inject model for amp, parallel setting, parameter syncing and others - model, isp_communicator = inject_model(model) + # initialize model and communicators + model, isp_communicator = initialize_model_and_parallel_communicator(model) # check cuda env check_cuda_env() diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py index dcb6c027d..f057941bc 100644 --- a/internlm/data/tokenized/dummy_dataset.py +++ b/internlm/data/tokenized/dummy_dataset.py @@ -4,7 +4,7 @@ import numpy as np from torch.utils.data import Dataset -# from internlm.core.context.parallel_context import global_context as gpc +# from internlm.core.context import global_context as gpc class RandomDataset(Dataset): diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 352273c79..74e860997 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,8 +5,8 @@ import torch +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer import ParallelMode from internlm.utils.parallel import is_using_hf @@ -64,8 +64,7 @@ def unpack_data(data, label): # per batch's index should be equal, so we select first batch data["indexes"] = data["indexes"][0] - # If model has inject_info and data_helper is enabled, we provide position_ids - if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): + if is_using_hf(): data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] @@ -81,8 +80,7 @@ def packed_data_normalizer(data, label): data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0) data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() - # If model has inject_info and data_helper is enabled, we provide position_ids, cu_seqlens, max_seqlen - if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): + if is_using_hf(): gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens") 
gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] diff --git a/internlm/eval/__init__.py b/internlm/eval/__init__.py index dc70e4d45..208779157 100644 --- a/internlm/eval/__init__.py +++ b/internlm/eval/__init__.py @@ -1,5 +1,11 @@ -from .evaluation import evaluate_on_val_dls +from .evaluation import ( + evaluate_on_val_dls, + switch_evaluation_mode, + switch_evaluation_pipeline_scheduler, +) __all__ = [ "evaluate_on_val_dls", + "switch_evaluation_mode", + "switch_evaluation_pipeline_scheduler", ] diff --git a/internlm/eval/evaluation.py b/internlm/eval/evaluation.py index 862057a3d..2b8dd08a3 100644 --- a/internlm/eval/evaluation.py +++ b/internlm/eval/evaluation.py @@ -9,7 +9,7 @@ from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import split_data_for_sequence_parallel from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape -from internlm.model.metrics import AccPerplex, SchedulerMetricHook +from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook from internlm.utils.common import get_current_device from internlm.utils.parallel import is_using_isp diff --git a/internlm/initialize/__init__.py b/internlm/initialize/__init__.py index 14fe06bbb..c7d474957 100644 --- a/internlm/initialize/__init__.py +++ b/internlm/initialize/__init__.py @@ -1,17 +1,7 @@ +from .initialize_launcher import initialize_launcher from .initialize_trainer import initialize_trainer -from .launch import ( - get_default_parser, - initialize_distributed_env, - launch_from_slurm, - launch_from_torch, - try_bind_numa, -) __all__ = [ - "get_default_parser", + "initialize_launcher", "initialize_trainer", - "launch_from_slurm", - "launch_from_torch", - "initialize_distributed_env", - "try_bind_numa", ] diff --git a/internlm/initialize/constants.py b/internlm/initialize/constants.py new file mode 100644 index 000000000..28474d075 --- /dev/null +++ b/internlm/initialize/constants.py @@ -0,0 +1,9 @@ +############################################# +# Default Distributed Master Port # +############################################# +DEFAULT_DISTRIBUTED_PORT = 8888 + +############################################# +# Default Universal Random Seed # +############################################# +DEFAULT_RANDOM_SEED = 1024 diff --git a/internlm/initialize/initialize_communicator.py b/internlm/initialize/initialize_communicator.py new file mode 100644 index 000000000..74fe06af8 --- /dev/null +++ b/internlm/initialize/initialize_communicator.py @@ -0,0 +1,214 @@ +from typing import Iterable, Tuple, TypeVar, Union + +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm import ( + EmbeddingSequenceParallelCommunicator, + EmbeddingTensorParallelCommunicator, + EmbeddingWeightParallelCommunicator, + HeadSequenceParallelCommunicator, + HeadTensorParallelCommunicator, + HeadWeightParallelCommunicator, + ISPCommModelConfig, + ISPCommunicator, + ISPCommunicatorWrapper, + LinearRole, + MoESequenceParallelCommunicator, + SequenceParallelCommunicator, + TensorParallelCommunicator, +) +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import ( + ColumnParallelLinear, + GroupedColumnLinear, + GroupedRowLinear, + 
GroupedWPLinear, + RewardModelLinear, + RowParallelLinear, + ScaleColumnParallelLinear, +) +from internlm.model.model_ops.moe import Experts, MoE +from internlm.utils.common import get_current_device +from internlm.utils.parallel import is_using_fsdp, is_using_isp +from internlm.utils.utils import TensorParallelMode + +_T = TypeVar("_T") + + +def submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]: + for _chunk in unwrap_naive_amp(model): + for _module in _chunk.modules(): + if not isinstance(_module, target_cls): + continue + + yield _module + + +def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): + """ + Initialize communicator for isp tensor parallel mode. + + Args: + model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated. + + Returns: + An isp communicator for managing comp/comm overlap. + """ + isp_communicator_wrapper = None + _retain_out_sharded = gpc.config.model.get("parallel_output", True) + + if is_using_isp(): + isp_communicator = ISPCommunicator( + model, + ISPCommModelConfig( + gpc.config.model.dtype, + get_current_device(), + gpc.config.model.checkpoint, + ), + gpc.config.parallel.weight.overlap and not is_using_fsdp(), + gpc.get_group(ParallelMode.WEIGHT), + is_moe=False, + selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), + early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release, + ) + # register communicator for isp column parallel linear. + ColumnParallelLinear.register_cls_communicator(isp_communicator) + # row parallel linear will not be used. + RowParallelLinear.register_cls_communicator(None) + _head_communicator = HeadWeightParallelCommunicator( + weight_process_group=gpc.get_group(ParallelMode.WEIGHT), + seq_process_group=gpc.get_group(ParallelMode.TENSOR), + retain_out_sharded=_retain_out_sharded, + ) + _embedding_communicator = EmbeddingWeightParallelCommunicator(ParallelMode.WEIGHT) + + if gpc.config.model.get("num_experts", 1) > 1: + # register communicator for moe isp column parallel linear. + # NOTE: this wil overwrite registed communicator + moe_isp_communicator = ISPCommunicator( + model, + ISPCommModelConfig( + gpc.config.model.dtype, + get_current_device(), + gpc.config.model.checkpoint, + ), + gpc.config.parallel.expert_weight.overlap, + gpc.get_group(ParallelMode.EXPERT_WEIGHT), + is_moe=True, + early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release, + ) + for moe in submodule_filter(model, Experts): + for column_linear in submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)): + column_linear.register_communicator(moe_isp_communicator) + for row_linear in submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(None) + + isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator]) + else: + isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator]) + + # register communictor for mtp/msp/fsp linear. 
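# --- Editor's sketch (not part of the patch) --------------------------------
# submodule_filter(), defined above, is the traversal primitive this function
# uses to register communicators on specific module classes. A minimal usage
# example under the same assumptions (the RMSNorm import path is taken from
# this patch):
from internlm.model.model_ops.ops.norm import RMSNorm


def count_rmsnorm_modules(model) -> int:
    """Count RMSNorm layers across all (possibly AMP-wrapped) model chunks."""
    return sum(1 for _ in submodule_filter(model, RMSNorm))
# -----------------------------------------------------------------------------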
+ + # tensor parallel + if gpc.config.parallel.tensor.mode == TensorParallelMode.mtp.name: + ColumnParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + RowParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + + if gpc.config.model.get("num_experts", 1) > 1: + GroupedColumnLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + GroupedRowLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + GroupedWPLinear.register_cls_communicator(None) + # treat as sequence paralle if no_tp + if gpc.config.parallel.expert.no_tp: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + ) + for moe in submodule_filter(model, MoE): + # 1. the linear in MoE degrades as no tp communication pattern + for column_linear in submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) + # 2. register MoESequenceParallelCommunicator for MoE layer + MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe) + + _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR) + # sequence parallel + if gpc.config.parallel.tensor.mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name): + save_total_input_as_activation = gpc.config.parallel.tensor.mode == TensorParallelMode.msp.name + + ColumnParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + RowParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + if gpc.config.model.get("num_experts", 1) > 1: + GroupedColumnLinear.register_cls_communicator( + SequenceParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + GroupedRowLinear.register_cls_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + GroupedWPLinear.register_cls_communicator(None) + if gpc.config.parallel.expert.no_tp: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW + ) + for moe in submodule_filter(model, MoE): + # 1. 
the linear in MoE degrades as no tp communication pattern + for column_linear in submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) + + _head_communicator = HeadSequenceParallelCommunicator( + ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation + ) + + _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR) + + # register communitorc for embedding layer. + if not is_using_fsdp(): + for embedding in submodule_filter(model, Embedding1D): + _embedding_communicator.register_module_hook(embedding) + + # register communictor for head layer. + ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) + RewardModelLinear.register_cls_communicator(_head_communicator) + + return isp_communicator_wrapper diff --git a/internlm/initialize/launch.py b/internlm/initialize/initialize_launcher.py similarity index 91% rename from internlm/initialize/launch.py rename to internlm/initialize/initialize_launcher.py index f9e5b6ff0..a9e925005 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/initialize_launcher.py @@ -2,7 +2,6 @@ # -*- encoding: utf-8 -*- # Copyright (c) InternLM. All rights reserved. -import argparse import os from pathlib import Path from typing import Dict, Union @@ -10,10 +9,11 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer import ParallelMode +from internlm.initialize.constants import DEFAULT_DISTRIBUTED_PORT, DEFAULT_RANDOM_SEED from internlm.utils.common import get_master_node +from internlm.utils.config import Config from internlm.utils.gputest import warmup_process_group from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger @@ -35,43 +35,8 @@ internlm_accelerator = get_accelerator() -def get_default_parser(): - """Reads user command line and uses an argument parser to parse the input arguments. - Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. - - Returns: - Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser. 
- """ - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, help="path to the config file") - parser.add_argument( - "--launcher", - type=str, - default="slurm", - choices=["slurm", "torch"], - help="launcher for launching distributed environment", - ) - parser.add_argument("--host", type=str, help="the master address for distributed training") - parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training") - parser.add_argument("--world_size", type=int, help="world size for distributed training") - parser.add_argument("--rank", type=int, help="rank for the default process group") - parser.add_argument("--local_rank", type=int, help="local rank on the node") - parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") - parser.add_argument("--seed", type=int, default=1024) - parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") - parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.") - parser.add_argument( - "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology." - ) - return parser - - -def inject_hf_config_before_launch(hf: dict): - # get HuggingFace model config - cfg = LazyObject(hf.cfg, hf.cfg_cls) - cfg = cfg.build() - model_config = cfg(**hf.cfg_extra_kwargs) - # inject HuggingFace model config into InternTrain as much as we know +def dispatch_hf_config_before_launch(model_config) -> None: + # dispatch HuggingFace model config into InternEvo model config as much as we know if hasattr(model_config, "vocab_size"): gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size if hasattr(model_config, "num_hidden_layers"): @@ -100,10 +65,12 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) - # inject HuggingFace model config into IntrainTrain + # dispatch HuggingFace model config into InternEvo model config if is_using_hf(): - inject_hf_config_before_launch(gpc.config.hf) - gpc.config.model_type = "hf" + cfg = LazyObject(gpc.config.hf.cfg, gpc.config.hf.cfg_cls) + cfg = cfg.build() + model_config = cfg(**gpc.config.hf.cfg_extra_kwargs) + dispatch_hf_config_before_launch(model_config) if gpc.config.model_type == "InternLM3_M": # TODO: need check for isp overlap @@ -766,14 +733,14 @@ def launch_from_torch( ) -@llm_timeout(func_name="initialize_distributed_env") -def initialize_distributed_env( +@llm_timeout(func_name="init_distributed") +def initialize_launcher( config: str, launcher: str = "slurm", - master_port: int = 8888, - seed: int = 1024, - args_check=True, - backend: str = "nccl", + distributed_port: int = DEFAULT_DISTRIBUTED_PORT, + seed: int = DEFAULT_RANDOM_SEED, + args_check: bool = True, + dist_backend: str = "nccl", ): """ Initialize distributed environment for distributed training. @@ -781,18 +748,18 @@ def initialize_distributed_env( Args: config (str): Config file path. launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default. - master_port (str): The master port for distributed training. 8888 by default. + distributed_port (str): Distributed backend port. 8888 by default. seed (int, optional): Specified random seed for every process. 1024 by default. 
""" - backend = internlm_accelerator._communication_backend_name + dist_backend = internlm_accelerator.communication_backend_name() if launcher == "torch": - launch_from_torch(config=config, seed=seed, backend=backend) + launch_from_torch(config=config, seed=seed, backend=dist_backend) elif launcher == "slurm": launch_from_slurm( config=config, host=get_master_node(), - port=master_port, + port=distributed_port, seed=seed, ) else: @@ -802,14 +769,6 @@ def initialize_distributed_env( args_sanity_check() -def get_config_value(config, key, defalut): - try: - value = config[key] - except KeyError: - value = defalut - return value - - def try_bind_numa(global_rank, world_size, local_rank=None): # Early return if numa module not available if not get_numa: diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py new file mode 100644 index 000000000..9e8c46342 --- /dev/null +++ b/internlm/initialize/initialize_model.py @@ -0,0 +1,228 @@ +from typing import Optional, Union + +import torch +from torch import nn + +from internlm.core.context import ( + IS_REPLICA_EXPERT_DATA_PARALLEL, + IS_REPLICA_ZERO_PARALLEL, + IS_TENSOR_EXPERT_DATA_PARALLEL, + IS_TENSOR_ZERO_PARALLEL, + IS_WEIGHT_EXPERT_DATA_PARALLEL, + IS_WEIGHT_ZERO_PARALLEL, + ParallelMode, +) +from internlm.core.context import global_context as gpc +from internlm.core.context import set_mode +from internlm.core.fsdp import wrap_FSDP_model +from internlm.core.naive_amp import ( + NaiveAMPModel, + set_fp32_attr_to_module, + unwrap_naive_amp, +) +from internlm.initialize.initialize_communicator import initialize_parallel_communicator +from internlm.model.model_implementations.builder import create_model +from internlm.model.model_implementations.registry import register_model_initializer +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import ( + ParallelLinearWithCommExt, + ScaleColumnParallelLinear, +) +from internlm.model.model_ops.moe import Experts, MoE +from internlm.model.model_ops.moe.moe import Qwen2MoE +from internlm.model.model_ops.ops.norm import RMSNorm +from internlm.utils.parallel import ( + is_replica_expert_data_parallel_parameter, + is_replica_zero_parallel_parameter, + is_tensor_expert_data_parallel_parameter, + is_tensor_zero_parallel_parameter, + is_using_fsdp, + is_using_hf, + is_using_isp, + is_weight_expert_data_parallel_parameter, + is_weight_zero_parallel_parameter, + sync_model_param, + sync_model_replica_param_group, +) +from internlm.utils.timeout import llm_timeout + + +def set_param_unique_tracking_name(model): + for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): + # Important: only works for llama-class models + childrens = chunk.named_children() + for _, children in childrens: + if isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + for name, child in block.named_modules(): + if isinstance(child, (ParallelLinearWithCommExt)): + full_name = f"{chunk_id}.{idx}.{name}" + setattr( + child.weight, + "tracking_name", + f"{full_name}.weight", + ) + if child.bias is not None: + setattr( + child.bias, + "tracking_name", + f"{full_name}.bias", + ) + else: + if isinstance(children, Embedding1D): + setattr( + children.weight, + "tracking_name", + f"{chunk_id}_embedding.weight", + ) + else: + setattr( + children.weight, + "tracking_name", + f"{chunk_id}_head.weight", + ) + + +def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): + if not isinstance(model, nn.ModuleList): + 
model = [model] + + for _chunk in model: + for _, module in _chunk.named_modules(): + if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.get("use_fp32_norm", False): + set_fp32_attr_to_module(module) + + +def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]): + def _check_module(name, module): + # layer_norm + if isinstance(module, (RMSNorm, nn.LayerNorm)): + for param in module.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + + if isinstance(module, (MoE, Qwen2MoE)): + for param in module.moe_layer.gate.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + if hasattr(module, "coefficient"): + for param in module.coefficient.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + + # embedding and head + if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)): + for param in module.parameters(): + if gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): + setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): + setattr(param, IS_TENSOR_ZERO_PARALLEL, True) + + # for moe linear module + if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt): + for param in module.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + + if isinstance(module, Experts): + for param in module.parameters(): + if ( + gpc.is_initialized(ParallelMode.TENSOR) + and not is_using_isp() + and getattr(gpc.config.parallel.expert, "no_tp", False) + ): + setattr(param, IS_REPLICA_EXPERT_DATA_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): + setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): + setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True) + # for non-moe linear module + elif isinstance(module, ParallelLinearWithCommExt): + for param in module.parameters(): + if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): + setattr(param, IS_TENSOR_ZERO_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): + setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) + + # for vit and vit project + if "vision_tower" in name.lower() or "vision_proj" in name.lower(): + for param in module.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + + for _chunk in unwrap_naive_amp(model): + if not is_using_fsdp(): + # set param parallel attribute + for name, module in _chunk.named_modules(): + _check_module(name, module) + + for name, param in _chunk.named_parameters(): + assert ( + is_replica_zero_parallel_parameter(param) + or is_tensor_zero_parallel_parameter(param) + or is_weight_zero_parallel_parameter(param) + or is_tensor_expert_data_parallel_parameter(param) + or is_weight_expert_data_parallel_parameter(param) + or is_replica_expert_data_parallel_parameter(param) + ), f"parameter with name: {name} has no parallel attribution." + + +@llm_timeout(func_name="initialize_model_and_parallel_communicator") +def initialize_model_and_parallel_communicator(model: Optional[Union[nn.Module, nn.ModuleList]] = None): + """ + initialize model with Automatic Mixed Precision. + + Returns: + torch.nn.Module: + The neural network model to be trained or evaluated. + An isp communicator for managing comp/comm overlap. 
+ """ + if model is None: + register_model_initializer() + model = create_model() + + # For non-HF cases, set tracking name for parameters + if not is_using_hf(): + set_param_unique_tracking_name(model) + + # should be set before NaiveAMPModel + set_fp32_attr_for_model(model) + + if isinstance(model, nn.ModuleList): + model = nn.ModuleList( + [ + NaiveAMPModel( + model=_m, + output_to_fp32=False, # manually controlled by interleaved pipleline scheduler + dtype=gpc.config.model.get("dtype", torch.half), + sync_buffer=False, + ) + for _m in model + ] + ) + else: + model = NaiveAMPModel( + model=model, + output_to_fp32=gpc.is_no_pp_or_last_stage(), + dtype=gpc.config.model.get("dtype", torch.half), + sync_buffer=False, + ) + + set_parallel_attr_for_param_groups(model) + + # This sync is very important, cause the model weights kept in optimizer are copied + # from the origin parameters in the memory, so we should make sure the dp sync + # does not influence the model weights in optimizer be different with the origin parameters. + if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda": + sync_model_param(model) + + # This function is needed to make sure parameters that are not splitted by tensor parallelism are + # the same across tensor parallelism. + sync_model_replica_param_group(model) + + # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random + # state in the same dp group are all the same. + random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA + set_mode(random_mode) + + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + model = wrap_FSDP_model(model) + + return model, isp_communicator diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py new file mode 100644 index 000000000..28082cb85 --- /dev/null +++ b/internlm/initialize/initialize_optimizer.py @@ -0,0 +1,189 @@ +from typing import Dict, Tuple, Union + +import torch +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler +from internlm.model.model_ops.modules.utils import is_moe_param +from internlm.solver.optimizer import ( + FSDPadaptOptimizer, + HybridZeroOptimizer, + HybridZeroOptimizer_v2, +) +from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw +from internlm.solver.schedulers import Beta2Scheduler, FineTuneCosineAnnealingWarmupLR +from internlm.utils.parallel import is_using_fsdp +from internlm.utils.timeout import llm_timeout + + +def split_params_into_different_groups_for_optimizer( + param_groups: Tuple[Dict], +) -> Tuple[Dict]: + """Split parameters into different groups for optimizer + + Args: + param_groups (Tuple[Dict]): The list of parameter groups to split + Input Example: + >>> ( + >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, + >>> ) + + Returns: + Tuple[Dict]: list of params groups for optimizer + Output Example: + >>> ( + >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, + >>> {'name': 'embed_head', 'params': [tensor], 'weight_decay' :xxx}, + >>> {'name': 'fp32', 'params': [tensor], 'weight_decay' :xxx}, + >>> ) + """ + + if isinstance(param_groups, tuple): + param_groups = list(param_groups) # Tuple cannot be modified + elif isinstance(param_groups, dict): 
+ param_groups = [param_groups] + elif not isinstance(param_groups, list): + raise ValueError(f"Unknown param group type of {type(param_groups)}") + + new_groups = {} + # create new groups for fp32 parameter group + new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1} + + if gpc.config.model.get("num_experts", 1) > 1: + for key in gpc.expert_parallel_group_names: + new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": ParallelMode.EXPERT_DATA} + + for pgroup in param_groups: + # copy attribute from origin group, we assume the input param_groups only + # have one group, so the attribute will not be copyed multiple times. + for ori_key in pgroup.keys(): + if ori_key not in ("name", "params"): + for _, group in new_groups.items(): + group[ori_key] = pgroup[ori_key] + # assign param + origin_params = [] + for param in pgroup["params"]: + # moe param means MoE is enabled + if is_moe_param(param): + new_groups[param.group_name]["params"].append(param) + elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: + new_groups["fp32"]["params"].append(param) + else: + origin_params.append(param) + + # default param group, which is the first group in the param groups + pgroup["params"] = origin_params + pgroup["optimizer_mode"] = ParallelMode.ZERO1 + + # param groups may contain empty groups, such as fp32 + param_groups.extend(new_groups.values()) + + return tuple(param_groups) + + +def create_param_groups(model, weight_decay): + parameters = { + "params": [param for param in model.parameters() if param.requires_grad], + "name": "default", + "weight_decay": weight_decay, + } + return split_params_into_different_groups_for_optimizer(parameters) + + +def map_param_block(model): + for _chunk in unwrap_naive_amp(model): + for name, children in _chunk.named_children(): + if isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + block_name = name + f"_{idx}" + for param in block.parameters(): + setattr(param, "block_name", block_name) + else: + for param in children.parameters(): + setattr(param, "block_name", name) + + +@llm_timeout(func_name="initialize_optimizer") +def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicatorWrapper = None): + """ + Initialize optimizer. + + Args: + model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated. + + Returns: + A tuple of (optimizer, beta2_scheduler, lr_scheduler). + """ + + adam_cfg = gpc.config.adam + zero_cfg = gpc.config.hybrid_zero_optimizer + grad_scal_cfg = gpc.config.grad_scaler + use_apex_adam = getattr(gpc.config, "use_apex_adam", False) + + if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim: + map_param_block(model) + + params = create_param_groups(model, adam_cfg.weight_decay) + + naive_optimizer = new_compatible_adamw( + params=params, + lr=adam_cfg.lr, + betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), + eps=adam_cfg.adam_eps, + use_apex_adam=use_apex_adam, + ) + + if ( + zero_cfg.overlap_sync_grad + and gpc.is_using_parallel_mode(ParallelMode.PIPELINE) + and gpc.is_pipeline_first_stage() is False + ): + # When pipeline parallelism is enabled, we prefer to only enable optimizer + # gradient communication overlap in the first stage, to avoid amplifying + # the communication overhead stage by stage in cases where the optimizer + # communication overhead is greater than the compute overhead. 
+ # For pipeline stages except the first, even if overlap is not enabled, + # their gradient synchronization overhead can be well hidden by + # the inherent bubbles of pipeline parallelism. + zero_cfg.overlap_sync_grad = False + + if zero_cfg.overlap_sync_param: + param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator) + else: + param_bcast_sync_handler = None + + if not is_using_fsdp(): + if ( + "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer + or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim + ): + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=grad_scal_cfg, + zero_cfg=zero_cfg, + param_bcast_sync_handler=param_bcast_sync_handler, + isp_communicator=isp_communicator, + ) + else: + optimizer = HybridZeroOptimizer_v2( + naive_optimizer, + grad_scal_cfg=grad_scal_cfg, + zero_cfg=zero_cfg, + param_bcast_sync_handler=param_bcast_sync_handler, + isp_communicator=isp_communicator, + ) + else: + optimizer = FSDPadaptOptimizer( + naive_optimizer, + grad_scal_cfg=grad_scal_cfg, + zero_cfg=zero_cfg, + ) + + beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) + + lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) + + return optimizer, beta2_scheduler, lr_scheduler diff --git a/internlm/initialize/initialize_profiler.py b/internlm/initialize/initialize_profiler.py new file mode 100644 index 000000000..eb9b41a19 --- /dev/null +++ b/internlm/initialize/initialize_profiler.py @@ -0,0 +1,61 @@ +import torch + +from internlm.accelerator import AcceleratorType +from internlm.accelerator.abstract_accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.common import DummyProfile +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +try: + import torch_npu +except (ModuleNotFoundError, ImportError): + pass + + +def initialize_llm_profile(profiling: bool = False, start_time: str = None): + """Initialize and return the profiler context manager instance.""" + + if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3} + trace_path = ( + f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_" + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}" + ) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + ) + llm_profile = torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(**schedule_config), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path), + record_shapes=True, + profile_memory=True, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) + logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!") + else: + llm_profile = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, 
torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(**schedule_config), + on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), + with_stack=True, + with_modules=True, + profile_memory=True, + ) + logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!") + else: + llm_profile = DummyProfile() + + return llm_profile diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index 48487c5fb..5b8cd9f35 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -26,8 +26,8 @@ from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape from internlm.core.trainer import Trainer from internlm.data.utils import packed_data_normalizer, unpack_data -from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer -from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler +from internlm.solver.optimizer import BaseOptimizer +from internlm.solver.schedulers import Beta2Scheduler from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.parallel import is_using_isp diff --git a/internlm/initialize/legacy/launch.py b/internlm/initialize/legacy/launch.py deleted file mode 100644 index 3a8ccedee..000000000 --- a/internlm/initialize/legacy/launch.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from internlm.initialize.launch import get_config_value -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -def auto_resume_sanity_check(ckpt_config): - load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None) - if load_given_ckpt is None: - return True # default value is True - else: - return not load_given_ckpt - - -def ckpt_info_sanity_check(ckpt_config): - load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None) - - load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None) - - if load_model_only_folder is not None: - assert ( - load_ckpt_folder is None - ), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ -# and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" - return dict(path=load_model_only_folder, content=("model",), ckpt_type="internevo") - else: - load_optimizer = get_config_value(ckpt_config, "load_optimizer", True) - - if isinstance(load_ckpt_folder, str): - if load_optimizer: - return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internevo") - else: - return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internevo") - elif load_ckpt_folder is None: - return None - else: - assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'" diff --git a/internlm/initialize/legacy/__init__.py b/internlm/launcher/__init__.py similarity index 100% rename from internlm/initialize/legacy/__init__.py rename to internlm/launcher/__init__.py diff --git a/train.py b/internlm/launcher/launch.py old mode 100755 new mode 100644 similarity index 75% rename from train.py rename to internlm/launcher/launch.py index 6e5e1399f..16eec6c68 --- a/train.py +++ b/internlm/launcher/launch.py @@ -7,8 +7,9 @@ build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.initialize import initialize_distributed_env -from internlm.model.builder import create_model +from internlm.initialize import initialize_launcher +from 
internlm.model.model_implementations.builder import create_model +from internlm.model.model_implementations.registry import register_model_initializer from internlm.monitor import internevo_monitor from internlm.utils.common import parse_args @@ -16,6 +17,7 @@ @internevo_monitor(feishu_alert=True, clean_run=True) def main(args): # initialize model + register_model_initializer() model = create_model() # initialize train dataloader @@ -36,7 +38,7 @@ def main(args): args = parse_args() # Initialize distributed environment - initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None # Run the main function with parsed arguments diff --git a/internlm/model/llava/__init__.py b/internlm/model/model_implementations/__init__.py similarity index 100% rename from internlm/model/llava/__init__.py rename to internlm/model/model_implementations/__init__.py diff --git a/internlm/model/builder.py b/internlm/model/model_implementations/builder.py similarity index 90% rename from internlm/model/builder.py rename to internlm/model/model_implementations/builder.py index e8d3f11b9..168a00df5 100644 --- a/internlm/model/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -6,12 +6,14 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper -from internlm.model.base_model import BaseModel -from internlm.model.modules.linear import ( +from internlm.model.model_implementations.registry import model_initializer +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_ops.modules.linear import ( ParallelLinearWithCommExt, ScaleColumnParallelLinear, ) -from internlm.model.registry import model_initializer from internlm.utils.common import get_current_device from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger @@ -55,8 +57,10 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: else: model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, **kwargs) - if not isinstance(model, BaseModel) and gpc.is_rank_for_log(): - logger.warning(f"To load/save huggingface ckpt, built-in model should inherited from {BaseModel.__name__}") + if not isinstance(model, BaseTransformerModel) and gpc.is_rank_for_log(): + logger.warning( + f"To load/save huggingface ckpt, built-in model should inherited from {BaseTransformerModel.__name__}" + ) return model diff --git a/internlm/model/registry.py b/internlm/model/model_implementations/registry.py similarity index 77% rename from internlm/model/registry.py rename to internlm/model/model_implementations/registry.py index 68013d268..a7857e7f8 100644 --- a/internlm/model/registry.py +++ b/internlm/model/model_implementations/registry.py @@ -4,16 +4,26 @@ from typing import Callable -from internlm.model.modeling_baichuan2 import Baichuan2 -from internlm.model.modeling_gemma import Gemma -from internlm.model.modeling_internlm import InternLM1 -from internlm.model.modeling_internlm2 import InternLM2 -from internlm.model.modeling_llama import Llama2 -from internlm.model.modeling_llava import Llava -from internlm.model.modeling_mixtral import MixtralMoE -from internlm.model.modeling_moe import 
Internlm1MoE -from internlm.model.modeling_qwen2 import Qwen2 -from internlm.model.modeling_qwen2_moe import Qwen2Moe +from internlm.model.model_implementations.transformers.modeling_baichuan2 import ( + Baichuan2, +) +from internlm.model.model_implementations.transformers.modeling_gemma import Gemma +from internlm.model.model_implementations.transformers.modeling_internlm import ( + InternLM1, +) +from internlm.model.model_implementations.transformers.modeling_internlm2 import ( + InternLM2, +) +from internlm.model.model_implementations.transformers.modeling_llama import Llama2 +from internlm.model.model_implementations.transformers.modeling_llava import Llava +from internlm.model.model_implementations.transformers.modeling_mixtral import ( + MixtralMoE, +) +from internlm.model.model_implementations.transformers.modeling_moe import Internlm1MoE +from internlm.model.model_implementations.transformers.modeling_qwen2 import Qwen2 +from internlm.model.model_implementations.transformers.modeling_qwen2_moe import ( + Qwen2Moe, +) from internlm.utils.common import SingletonMeta from internlm.utils.utils import ModelType @@ -95,6 +105,3 @@ def register_model_initializer() -> None: model_initializer.register_module(ModelType.GEMMA.name, Gemma) model_initializer.register_module(ModelType.QWEN2MOE.name, Qwen2Moe) model_initializer.register_module(ModelType.MIXTRALMOE.name, MixtralMoE) - - -register_model_initializer() diff --git a/internlm/model/modules/__init__.py b/internlm/model/model_implementations/transformers/__init__.py similarity index 100% rename from internlm/model/modules/__init__.py rename to internlm/model/model_implementations/transformers/__init__.py diff --git a/internlm/model/base_model.py b/internlm/model/model_implementations/transformers/base_model.py similarity index 74% rename from internlm/model/base_model.py rename to internlm/model/model_implementations/transformers/base_model.py index cdbd04d6e..17bb0155e 100644 --- a/internlm/model/base_model.py +++ b/internlm/model/model_implementations/transformers/base_model.py @@ -2,12 +2,12 @@ from torch import nn -from internlm.model.utils import load_src_states, merge_pp_src_states +from internlm.model.model_ops.utils import load_src_states, merge_pp_src_states -class BaseModel(nn.Module, metaclass=ABCMeta): +class BaseTransformerModel(nn.Module, metaclass=ABCMeta): """ - Base class for all models. + Base class for InternEvo transformer models. 
""" @staticmethod diff --git a/internlm/model/modeling_baichuan2.py b/internlm/model/model_implementations/transformers/modeling_baichuan2.py similarity index 97% rename from internlm/model/modeling_baichuan2.py rename to internlm/model/model_implementations/transformers/modeling_baichuan2.py index 7dd632351..09bde6c3e 100644 --- a/internlm/model/modeling_baichuan2.py +++ b/internlm/model/model_implementations/transformers/modeling_baichuan2.py @@ -10,20 +10,22 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import MHA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import MHA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -271,7 +273,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Baichuan2(BaseModel): +class Baichuan2(BaseTransformerModel): """ 1D Packed Flash Llama. 
diff --git a/internlm/model/modeling_gemma.py b/internlm/model/model_implementations/transformers/modeling_gemma.py similarity index 98% rename from internlm/model/modeling_gemma.py rename to internlm/model/model_implementations/transformers/modeling_gemma.py index 74d71796e..5e8bd0a6d 100644 --- a/internlm/model/modeling_gemma.py +++ b/internlm/model/model_implementations/transformers/modeling_gemma.py @@ -9,20 +9,22 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import GQA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import GQA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -308,7 +310,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Gemma(BaseModel): +class Gemma(BaseTransformerModel): """ 1D Packed Flash Llama. 
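# --- Editor's sketch (not part of the patch) --------------------------------
# How the renamed/new initialization entry points introduced by this patch fit
# together, mirroring internlm/launcher/launch.py and trainer_builder.py above.
# The config path and argument values are hypothetical.
from internlm.initialize import initialize_launcher
from internlm.initialize.initialize_model import (
    initialize_model_and_parallel_communicator,
)
from internlm.initialize.initialize_optimizer import initialize_optimizer
from internlm.model.model_implementations.builder import create_model
from internlm.model.model_implementations.registry import register_model_initializer

# 1. set up the distributed environment (replaces initialize_distributed_env)
initialize_launcher(config="configs/7B_sft.py", launcher="torch", distributed_port=8888, seed=1024)

# 2. build the model; registration is now an explicit call instead of an import-time side effect
register_model_initializer()
model = create_model()

# 3. AMP wrapping, parallel attributes, communicators and (optionally) FSDP wrapping
model, isp_communicator = initialize_model_and_parallel_communicator(model)

# 4. optimizer and schedulers; FSDPadaptOptimizer is selected automatically when FSDP is enabled
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)
# -----------------------------------------------------------------------------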
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/model_implementations/transformers/modeling_internlm.py similarity index 98% rename from internlm/model/modeling_internlm.py rename to internlm/model/model_implementations/transformers/modeling_internlm.py index 367ba524a..6a29124d7 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm.py @@ -11,17 +11,22 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module from internlm.core.parallel.shard import partition_uniform -from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import MHA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( + normal_, + scaled_init_method_normal, +) +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import MHA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, @@ -230,7 +235,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual -class InternLM1(BaseModel): +class InternLM1(BaseTransformerModel): """ 1D Packed Flash InternLm. 
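# --- Editor's sketch (not part of the patch) --------------------------------
# Typical use of initialize_llm_profile() from
# internlm/initialize/initialize_profiler.py earlier in this patch. The return
# value is either a torch/NPU profiler or a DummyProfile; the sketch assumes
# DummyProfile mirrors the profiler's context-manager/step() interface. The
# training step below is a placeholder.
from internlm.initialize.initialize_profiler import initialize_llm_profile

total_steps = 5


def train_one_step():
    pass  # placeholder for forward/backward/optimizer.step()


with initialize_llm_profile(profiling=True, start_time="2025-01-16_15-16-22") as prof:
    for _ in range(total_steps):
        train_one_step()
        prof.step()  # advances the wait/warmup/active profiling schedule
# -----------------------------------------------------------------------------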
diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/model_implementations/transformers/modeling_internlm2.py similarity index 98% rename from internlm/model/modeling_internlm2.py rename to internlm/model/model_implementations/transformers/modeling_internlm2.py index 0453b9dcb..62d760fec 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm2.py @@ -10,21 +10,23 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import partition_uniform -from internlm.initialize.initialize_tensor import ( +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import GQA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import GQA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, get_parallel_size_from_file, @@ -293,7 +295,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class InternLM2(BaseModel): +class InternLM2(BaseTransformerModel): """ InternLM2 Model. 
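# --- Editor's sketch (not part of the patch) --------------------------------
# The parallel-attribute flags set by set_parallel_attr_for_param_groups() in
# internlm/initialize/initialize_model.py earlier in this patch are attribute
# names attached to parameters via setattr(), so they can be inspected the same
# way, e.g. to audit how a parameter will be handled. Note that under
# is_using_fsdp() the flags are not set at all (see the guard in that function).
from internlm.core.context import (
    IS_REPLICA_ZERO_PARALLEL,
    IS_TENSOR_ZERO_PARALLEL,
    IS_WEIGHT_ZERO_PARALLEL,
)


def describe_param_sharding(param) -> str:
    if getattr(param, IS_TENSOR_ZERO_PARALLEL, False):
        return "tensor-parallel (mtp/msp/fsp), zero-sharded"
    if getattr(param, IS_WEIGHT_ZERO_PARALLEL, False):
        return "weight-parallel (isp), zero-sharded"
    if getattr(param, IS_REPLICA_ZERO_PARALLEL, False):
        return "replicated within the parallel group"
    return "no parallel attribute set (e.g. the FSDP path)"
# -----------------------------------------------------------------------------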
diff --git a/internlm/model/modeling_llama.py b/internlm/model/model_implementations/transformers/modeling_llama.py similarity index 98% rename from internlm/model/modeling_llama.py rename to internlm/model/model_implementations/transformers/modeling_llama.py index 56b88e83e..03dec8b1b 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/model_implementations/transformers/modeling_llama.py @@ -8,21 +8,23 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module -from internlm.initialize.initialize_tensor import ( +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import GQA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import GQA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -281,7 +283,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Llama2(BaseModel): +class Llama2(BaseTransformerModel): """ Llama2 Model. 
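# --- Editor's sketch (not part of the patch) --------------------------------
# The modeling_* files above keep their contents but move under
# internlm.model.model_implementations.transformers, and BaseModel becomes
# BaseTransformerModel. Downstream imports migrate roughly like this
# (old paths shown for comparison only):
#
#   before: from internlm.model.modeling_internlm2 import InternLM2
#   before: from internlm.model.base_model import BaseModel
from internlm.model.model_implementations.transformers.base_model import (
    BaseTransformerModel,
)
from internlm.model.model_implementations.transformers.modeling_internlm2 import (
    InternLM2,
)

assert issubclass(InternLM2, BaseTransformerModel)
# -----------------------------------------------------------------------------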
diff --git a/internlm/model/modeling_llava.py b/internlm/model/model_implementations/transformers/modeling_llava.py similarity index 93% rename from internlm/model/modeling_llava.py rename to internlm/model/model_implementations/transformers/modeling_llava.py index 4c2bb1745..614578d43 100644 --- a/internlm/model/modeling_llava.py +++ b/internlm/model/model_implementations/transformers/modeling_llava.py @@ -4,22 +4,26 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module -from internlm.initialize.initialize_tensor import normal_, uniform_ -from internlm.model.base_model import BaseModel -from internlm.model.llava.clip_builder import build_vision_tower -from internlm.model.llava.projector_builder import build_vision_projector -from internlm.model.modeling_llama import Llama2Decoder -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.norm import new_layer_norm +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.modeling_llama import ( + Llama2Decoder, +) +from internlm.model.model_implementations.transformers.utils import normal_, uniform_ +from internlm.model.model_ops.llava.clip_builder import build_vision_tower +from internlm.model.model_ops.llava.projector_builder import build_vision_projector +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.norm import new_layer_norm from internlm.utils.logger import get_logger logger = get_logger(__file__) -class Llava(BaseModel): +class Llava(BaseTransformerModel): """ 1D Packed Flash Llava. 
diff --git a/internlm/model/modeling_mixtral.py b/internlm/model/model_implementations/transformers/modeling_mixtral.py similarity index 96% rename from internlm/model/modeling_mixtral.py rename to internlm/model/model_implementations/transformers/modeling_mixtral.py index 8e8767ced..340da871e 100644 --- a/internlm/model/modeling_mixtral.py +++ b/internlm/model/model_implementations/transformers/modeling_mixtral.py @@ -8,16 +8,21 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import SWA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.moe.moe import MoE -from internlm.model.utils import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( + normal_, + scaled_init_method_normal, +) +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import SWA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.moe.moe import MoE +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -252,7 +257,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual, moe_loss -class MixtralMoE(BaseModel): +class MixtralMoE(BaseTransformerModel): """ InternLM1 MoE. 
diff --git a/internlm/model/modeling_moe.py b/internlm/model/model_implementations/transformers/modeling_moe.py similarity index 96% rename from internlm/model/modeling_moe.py rename to internlm/model/model_implementations/transformers/modeling_moe.py index f40d35f32..9bf01df9e 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/model_implementations/transformers/modeling_moe.py @@ -8,16 +8,21 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import MHA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.moe.moe import MoE -from internlm.model.utils import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( + normal_, + scaled_init_method_normal, +) +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import MHA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.moe.moe import MoE +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, @@ -243,7 +248,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual, moe_loss -class Internlm1MoE(BaseModel): +class Internlm1MoE(BaseTransformerModel): """ InternLM1 MoE. 
diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/model_implementations/transformers/modeling_qwen2.py similarity index 98% rename from internlm/model/modeling_qwen2.py rename to internlm/model/model_implementations/transformers/modeling_qwen2.py index 5a4bde534..b1d18e634 100644 --- a/internlm/model/modeling_qwen2.py +++ b/internlm/model/model_implementations/transformers/modeling_qwen2.py @@ -9,20 +9,22 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import SWA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import SWA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -288,7 +290,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Qwen2(BaseModel): +class Qwen2(BaseTransformerModel): """ 1D Packed Flash Qwen. 
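A second recurring change in these hunks is that context helpers are imported from the internlm.core.context facade instead of its submodules (parallel_context, random); the activation_checkpoint.py hunk later in this patch likewise re-imports set_mode, set_seed_states and sync_states from the facade. A minimal sketch of the post-patch import style follows; the attribute accesses mirror usage elsewhere in this patch, while the small helper function and its fallback value are illustrative only and not part of the codebase.

    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc

    def data_parallel_world_size() -> int:
        # Illustrative helper: read the DATA-group world size via the facade,
        # falling back to 1 when the group has not been initialized.
        if gpc.is_initialized(ParallelMode.DATA):
            return gpc.get_world_size(ParallelMode.DATA)
        return 1
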
diff --git a/internlm/model/modeling_qwen2_moe.py b/internlm/model/model_implementations/transformers/modeling_qwen2_moe.py similarity index 97% rename from internlm/model/modeling_qwen2_moe.py rename to internlm/model/model_implementations/transformers/modeling_qwen2_moe.py index cfa98098a..ec3978bb1 100644 --- a/internlm/model/modeling_qwen2_moe.py +++ b/internlm/model/model_implementations/transformers/modeling_qwen2_moe.py @@ -7,21 +7,23 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.initialize.initialize_tensor import ( +from internlm.core.context import global_context as gpc +from internlm.model.model_implementations.transformers.base_model import ( + BaseTransformerModel, +) +from internlm.model.model_implementations.transformers.utils import ( normal_, scaled_init_method_normal, scaled_init_method_uniform, uniform_, ) -from internlm.model.base_model import BaseModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import new_linear -from internlm.model.modules.mha import SWA -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.modules.norm import new_layer_norm -from internlm.model.moe.moe import Qwen2MoE -from internlm.model.utils import ( +from internlm.model.model_ops.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.mha import SWA +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.norm import new_layer_norm +from internlm.model.model_ops.moe.moe import Qwen2MoE +from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) @@ -314,7 +316,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states, moe_loss -class Qwen2Moe(BaseModel): +class Qwen2Moe(BaseTransformerModel): """ 1D Packed Flash Qwen. 
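The next rename relocates the tensor-initializer helpers from internlm.initialize.initialize_tensor to internlm.model.model_implementations.transformers.utils. A hedged usage sketch, assuming the relocated helpers keep their closure-returning signatures (each call returns a function that initializes a tensor in place); the layer size, std, and num_layers values are illustrative only.

    from torch import nn

    from internlm.model.model_implementations.transformers.utils import (
        normal_,
        scaled_init_method_normal,
    )

    proj = nn.Linear(4096, 4096, bias=False)
    # Plain normal init with a small std, e.g. for an input projection.
    normal_(std=0.02)(proj.weight)
    # Depth-scaled normal init, as typically applied to output projections.
    scaled_init_method_normal(sigma=0.02, num_layers=32)(proj.weight)
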
diff --git a/internlm/initialize/initialize_tensor.py b/internlm/model/model_implementations/transformers/utils.py similarity index 100% rename from internlm/initialize/initialize_tensor.py rename to internlm/model/model_implementations/transformers/utils.py diff --git a/internlm/model/ops/__init__.py b/internlm/model/model_ops/__init__.py similarity index 100% rename from internlm/model/ops/__init__.py rename to internlm/model/model_ops/__init__.py diff --git a/internlm/model/model_ops/llava/__init__.py b/internlm/model/model_ops/llava/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/model/llava/clip_builder.py b/internlm/model/model_ops/llava/clip_builder.py similarity index 100% rename from internlm/model/llava/clip_builder.py rename to internlm/model/model_ops/llava/clip_builder.py diff --git a/internlm/model/llava/clip_encoder.py b/internlm/model/model_ops/llava/clip_encoder.py similarity index 100% rename from internlm/model/llava/clip_encoder.py rename to internlm/model/model_ops/llava/clip_encoder.py diff --git a/internlm/model/llava/projector_builder.py b/internlm/model/model_ops/llava/projector_builder.py similarity index 100% rename from internlm/model/llava/projector_builder.py rename to internlm/model/model_ops/llava/projector_builder.py diff --git a/internlm/model/losses/__init__.py b/internlm/model/model_ops/losses/__init__.py similarity index 100% rename from internlm/model/losses/__init__.py rename to internlm/model/model_ops/losses/__init__.py diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/model_ops/losses/ce_loss.py similarity index 97% rename from internlm/model/losses/ce_loss.py rename to internlm/model/model_ops/losses/ce_loss.py index 5b2a380e8..e5645aba4 100644 --- a/internlm/model/losses/ce_loss.py +++ b/internlm/model/model_ops/losses/ce_loss.py @@ -2,7 +2,7 @@ from torch import nn from internlm.accelerator import get_accelerator -from internlm.model.ops.cross_entropy import new_cross_entropy +from internlm.model.model_ops.ops.cross_entropy import new_cross_entropy internlm_accelerator = get_accelerator() diff --git a/internlm/model/metrics.py b/internlm/model/model_ops/metrics.py similarity index 99% rename from internlm/model/metrics.py rename to internlm/model/model_ops/metrics.py index a7f6c9668..e67079534 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/model_ops/metrics.py @@ -4,7 +4,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import global_context as gpc -from internlm.model.ops.cross_entropy import new_cross_entropy +from internlm.model.model_ops.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer diff --git a/internlm/model/model_ops/modules/__init__.py b/internlm/model/model_ops/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/model/modules/embedding.py b/internlm/model/model_ops/modules/embedding.py similarity index 99% rename from internlm/model/modules/embedding.py rename to internlm/model/model_ops/modules/embedding.py index 93fcd6b23..d7c2850ae 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/model_ops/modules/embedding.py @@ -10,7 +10,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.rotary_emb import apply_rotary_emb +from 
internlm.model.model_ops.ops.rotary_emb import apply_rotary_emb from internlm.utils.parallel import is_using_isp diff --git a/internlm/model/modules/linear.py b/internlm/model/model_ops/modules/linear.py similarity index 99% rename from internlm/model/modules/linear.py rename to internlm/model/model_ops/modules/linear.py index 4a30967b0..a9cbd9b10 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/model_ops/modules/linear.py @@ -19,7 +19,7 @@ get_parallel_strategies_split_mode, get_tensor_split_parallel_mode, ) -from internlm.model.ops.linear import ( +from internlm.model.model_ops.ops.linear import ( gmm_backward_op, gmm_forward_op, linear_backward_op, @@ -28,8 +28,7 @@ from internlm.utils.logger import get_logger if TYPE_CHECKING: - from internlm.core.parallel.comm.isp import WPCommunicator - from internlm.core.parallel.comm.tensor import TPCommunicator + from internlm.core.parallel.comm import TPCommunicator, WPCommunicator logger = get_logger(__file__) internlm_accelerator = get_accelerator() diff --git a/internlm/model/modules/mha.py b/internlm/model/model_ops/modules/mha.py similarity index 99% rename from internlm/model/modules/mha.py rename to internlm/model/model_ops/modules/mha.py index 42418a212..227df0c8b 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/model_ops/modules/mha.py @@ -11,10 +11,10 @@ from torch.nn import functional as F from internlm.core.context import global_context as gpc -from internlm.model.modules.embedding import new_rotary_embedding -from internlm.model.modules.linear import new_linear -from internlm.model.modules.utils import update_kv_cache -from internlm.model.ops.attention import CrossAttention, SelfAttention +from internlm.model.model_ops.modules.embedding import new_rotary_embedding +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.utils import update_kv_cache +from internlm.model.model_ops.ops.attention import CrossAttention, SelfAttention from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/model/modules/mlp.py b/internlm/model/model_ops/modules/mlp.py similarity index 98% rename from internlm/model/modules/mlp.py rename to internlm/model/model_ops/modules/mlp.py index e51e5897f..c802d4d99 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/model_ops/modules/mlp.py @@ -6,8 +6,8 @@ import torch from torch import nn -from internlm.model.modules.linear import new_linear -from internlm.model.modules.utils import Gelu, Silu +from internlm.model.model_ops.modules.linear import new_linear +from internlm.model.model_ops.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger from internlm.utils.utils import ActivationType diff --git a/internlm/model/modules/norm.py b/internlm/model/model_ops/modules/norm.py similarity index 91% rename from internlm/model/modules/norm.py rename to internlm/model/model_ops/modules/norm.py index 2a9700f8d..cab90e0f5 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/model_ops/modules/norm.py @@ -8,7 +8,7 @@ import torch from torch import nn -from internlm.model.ops.norm import RMSNorm +from internlm.model.model_ops.ops.norm import RMSNorm Shape = Union[int, List[int], torch.Size] diff --git a/internlm/model/modules/utils.py b/internlm/model/model_ops/modules/utils.py similarity index 100% rename from internlm/model/modules/utils.py rename to internlm/model/model_ops/modules/utils.py diff --git a/internlm/model/moe/__init__.py 
b/internlm/model/model_ops/moe/__init__.py similarity index 100% rename from internlm/model/moe/__init__.py rename to internlm/model/model_ops/moe/__init__.py diff --git a/internlm/model/moe/base_layer.py b/internlm/model/model_ops/moe/base_layer.py similarity index 95% rename from internlm/model/moe/base_layer.py rename to internlm/model/model_ops/moe/base_layer.py index 7811e056d..a99a7b3b6 100644 --- a/internlm/model/moe/base_layer.py +++ b/internlm/model/model_ops/moe/base_layer.py @@ -4,7 +4,7 @@ from torch.nn import Module, ModuleList from internlm.core.context import global_context as gpc -from internlm.model.moe.experts import Experts +from internlm.model.model_ops.moe.experts import Experts if TYPE_CHECKING: Base = Module[Tensor] diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/model_ops/moe/dropless_layer.py similarity index 99% rename from internlm/model/moe/dropless_layer.py rename to internlm/model/model_ops/moe/dropless_layer.py index 031c23065..c2868d7bc 100644 --- a/internlm/model/moe/dropless_layer.py +++ b/internlm/model/model_ops/moe/dropless_layer.py @@ -15,7 +15,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.mlp import new_feed_forward from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger diff --git a/internlm/model/moe/experts.py b/internlm/model/model_ops/moe/experts.py similarity index 100% rename from internlm/model/moe/experts.py rename to internlm/model/model_ops/moe/experts.py diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/model_ops/moe/gshard_layer.py similarity index 99% rename from internlm/model/moe/gshard_layer.py rename to internlm/model/model_ops/moe/gshard_layer.py index a102b8c9e..c15810070 100644 --- a/internlm/model/moe/gshard_layer.py +++ b/internlm/model/model_ops/moe/gshard_layer.py @@ -15,7 +15,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.modules.mlp import new_feed_forward +from internlm.model.model_ops.modules.mlp import new_feed_forward from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer diff --git a/internlm/model/moe/megablocks/__init__.py b/internlm/model/model_ops/moe/megablocks/__init__.py similarity index 100% rename from internlm/model/moe/megablocks/__init__.py rename to internlm/model/model_ops/moe/megablocks/__init__.py diff --git a/internlm/model/moe/megablocks/megablock_dmoe.py b/internlm/model/model_ops/moe/megablocks/megablock_dmoe.py similarity index 96% rename from internlm/model/moe/megablocks/megablock_dmoe.py rename to internlm/model/model_ops/moe/megablocks/megablock_dmoe.py index 46e1a81cd..ee80a07d8 100644 --- a/internlm/model/moe/megablocks/megablock_dmoe.py +++ b/internlm/model/model_ops/moe/megablocks/megablock_dmoe.py @@ -5,10 +5,10 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.moe.base_layer import BaseMoELayer -from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE -from internlm.model.moe.megablocks.mlp import MegaBlockGroupedFeedForward -from internlm.model.moe.megablocks.utils import promote_scalar +from internlm.model.model_ops.moe.base_layer import BaseMoELayer +from 
internlm.model.model_ops.moe.megablocks.megablock_moe import MegaBlockMoE +from internlm.model.model_ops.moe.megablocks.mlp import MegaBlockGroupedFeedForward +from internlm.model.model_ops.moe.megablocks.utils import promote_scalar try: import stk diff --git a/internlm/model/moe/megablocks/megablock_moe.py b/internlm/model/model_ops/moe/megablocks/megablock_moe.py similarity index 98% rename from internlm/model/moe/megablocks/megablock_moe.py rename to internlm/model/model_ops/moe/megablocks/megablock_moe.py index 257585da0..86a87fff6 100644 --- a/internlm/model/moe/megablocks/megablock_moe.py +++ b/internlm/model/model_ops/moe/megablocks/megablock_moe.py @@ -6,9 +6,9 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.moe.base_layer import BaseMoELayer -from internlm.model.moe.megablocks.mlp import MegaBlockFeedForward -from internlm.model.moe.utils import all_to_all +from internlm.model.model_ops.moe.base_layer import BaseMoELayer +from internlm.model.model_ops.moe.megablocks.mlp import MegaBlockFeedForward +from internlm.model.model_ops.moe.utils import all_to_all try: from megablocks import ops diff --git a/internlm/model/moe/megablocks/mlp.py b/internlm/model/model_ops/moe/megablocks/mlp.py similarity index 95% rename from internlm/model/moe/megablocks/mlp.py rename to internlm/model/model_ops/moe/megablocks/mlp.py index 374793d6c..91519a890 100644 --- a/internlm/model/moe/megablocks/mlp.py +++ b/internlm/model/model_ops/moe/megablocks/mlp.py @@ -3,8 +3,8 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.modules.utils import Silu -from internlm.model.moe.megablocks.utils import ( +from internlm.model.model_ops.modules.utils import Silu +from internlm.model.model_ops.moe.megablocks.utils import ( act_fn, dsd_nn, sdd_nt, diff --git a/internlm/model/moe/megablocks/utils.py b/internlm/model/model_ops/moe/megablocks/utils.py similarity index 99% rename from internlm/model/moe/megablocks/utils.py rename to internlm/model/model_ops/moe/megablocks/utils.py index 857dd8b73..5c40dd619 100644 --- a/internlm/model/moe/megablocks/utils.py +++ b/internlm/model/model_ops/moe/megablocks/utils.py @@ -1,7 +1,7 @@ import torch from internlm.accelerator import get_accelerator -from internlm.model.modules.utils import Silu +from internlm.model.model_ops.modules.utils import Silu try: import stk diff --git a/internlm/model/moe/moe.py b/internlm/model/model_ops/moe/moe.py similarity index 96% rename from internlm/model/moe/moe.py rename to internlm/model/model_ops/moe/moe.py index 67fc40b56..ba96ecbca 100644 --- a/internlm/model/moe/moe.py +++ b/internlm/model/model_ops/moe/moe.py @@ -4,11 +4,11 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.naive_amp import set_fp32_attr_to_module -from internlm.model.modules.mlp import new_feed_forward -from internlm.model.moe.dropless_layer import DroplessMoELayer -from internlm.model.moe.gshard_layer import GShardMoELayer -from internlm.model.moe.megablocks.megablock_dmoe import MegaBlockdMoE -from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE +from internlm.model.model_ops.modules.mlp import new_feed_forward +from internlm.model.model_ops.moe.dropless_layer import DroplessMoELayer +from internlm.model.model_ops.moe.gshard_layer import GShardMoELayer +from internlm.model.model_ops.moe.megablocks.megablock_dmoe import 
MegaBlockdMoE +from internlm.model.model_ops.moe.megablocks.megablock_moe import MegaBlockMoE from internlm.utils.logger import get_logger # global llm logger diff --git a/internlm/model/moe/utils.py b/internlm/model/model_ops/moe/utils.py similarity index 100% rename from internlm/model/moe/utils.py rename to internlm/model/model_ops/moe/utils.py diff --git a/internlm/model/model_ops/ops/__init__.py b/internlm/model/model_ops/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/model_ops/ops/_flash_attn.py similarity index 100% rename from internlm/model/ops/_flash_attn.py rename to internlm/model/model_ops/ops/_flash_attn.py diff --git a/internlm/model/ops/attention.py b/internlm/model/model_ops/ops/attention.py similarity index 99% rename from internlm/model/ops/attention.py rename to internlm/model/model_ops/ops/attention.py index 3aec51f55..e2b622412 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/model_ops/ops/attention.py @@ -17,11 +17,14 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.parallel.comm.isp import ( +from internlm.core.parallel.comm import ( auto_wrap_distributed_attention, auto_wrap_func_distributed_attention, ) -from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn +from internlm.model.model_ops.ops.utils import ( + pack_output_after_attn, + unpack_qkv_before_attn, +) from internlm.utils.common import get_current_device from internlm.utils.utils import ( CuSeqlenType, @@ -41,7 +44,7 @@ pass else: try: - from internlm.model.ops.ring_flash_attn import ( + from internlm.model.model_ops.ops.ring_flash_attn import ( zigzag_ring_flash_attn_kvpacked_func_with_sliding_window, zigzag_ring_flash_attn_qkvpacked_func_with_sliding_window, zigzag_ring_flash_attn_qkvsplited_func_with_sliding_window, diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/model_ops/ops/cross_entropy.py similarity index 98% rename from internlm/model/ops/cross_entropy.py rename to internlm/model/model_ops/ops/cross_entropy.py index 99bf1e047..35de1b6ef 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/model_ops/ops/cross_entropy.py @@ -14,7 +14,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.cross_entropy_ops import ( +from internlm.model.model_ops.ops.cross_entropy_ops import ( CrossEntropyApexVocabParallel, CrossEntropyLossApex, CrossEntropyPython, diff --git a/internlm/model/ops/cross_entropy_ops/__init__.py b/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py similarity index 100% rename from internlm/model/ops/cross_entropy_ops/__init__.py rename to internlm/model/model_ops/ops/cross_entropy_ops/__init__.py diff --git a/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/apex_naive_loss.py similarity index 100% rename from internlm/model/ops/cross_entropy_ops/apex_naive_loss.py rename to internlm/model/model_ops/ops/cross_entropy_ops/apex_naive_loss.py diff --git a/internlm/model/ops/cross_entropy_ops/py_naive_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/py_naive_loss.py similarity index 100% rename from internlm/model/ops/cross_entropy_ops/py_naive_loss.py rename to 
internlm/model/model_ops/ops/cross_entropy_ops/py_naive_loss.py diff --git a/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/py_vocab_parallel_loss.py similarity index 100% rename from internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py rename to internlm/model/model_ops/ops/cross_entropy_ops/py_vocab_parallel_loss.py diff --git a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/sequence_parallel_loss.py similarity index 100% rename from internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py rename to internlm/model/model_ops/ops/cross_entropy_ops/sequence_parallel_loss.py diff --git a/internlm/model/ops/fused_rmsnorm.py b/internlm/model/model_ops/ops/fused_rmsnorm.py similarity index 100% rename from internlm/model/ops/fused_rmsnorm.py rename to internlm/model/model_ops/ops/fused_rmsnorm.py diff --git a/internlm/model/ops/linear.py b/internlm/model/model_ops/ops/linear.py similarity index 100% rename from internlm/model/ops/linear.py rename to internlm/model/model_ops/ops/linear.py diff --git a/internlm/model/ops/norm.py b/internlm/model/model_ops/ops/norm.py similarity index 100% rename from internlm/model/ops/norm.py rename to internlm/model/model_ops/ops/norm.py diff --git a/internlm/model/ops/ring_flash_attn/__init__.py b/internlm/model/model_ops/ops/ring_flash_attn/__init__.py similarity index 100% rename from internlm/model/ops/ring_flash_attn/__init__.py rename to internlm/model/model_ops/ops/ring_flash_attn/__init__.py diff --git a/internlm/model/ops/ring_flash_attn/utils.py b/internlm/model/model_ops/ops/ring_flash_attn/utils.py similarity index 100% rename from internlm/model/ops/ring_flash_attn/utils.py rename to internlm/model/model_ops/ops/ring_flash_attn/utils.py diff --git a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py b/internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py similarity index 99% rename from internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py rename to internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py index 5c22fed3d..8b908e4b1 100644 --- a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py +++ b/internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py @@ -4,7 +4,7 @@ import torch.distributed from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward -from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.context import global_context as gpc from .utils import RingComm, update_out_and_lse diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/model_ops/ops/rotary_emb.py similarity index 100% rename from internlm/model/ops/rotary_emb.py rename to internlm/model/model_ops/ops/rotary_emb.py diff --git a/internlm/model/ops/utils.py b/internlm/model/model_ops/ops/utils.py similarity index 100% rename from internlm/model/ops/utils.py rename to internlm/model/model_ops/ops/utils.py diff --git a/internlm/model/utils.py b/internlm/model/model_ops/utils.py similarity index 98% rename from internlm/model/utils.py rename to internlm/model/model_ops/utils.py index 7c974abeb..e3035a102 100644 --- a/internlm/model/utils.py +++ b/internlm/model/model_ops/utils.py @@ -4,8 +4,8 @@ from tqdm import tqdm -from internlm.core.context.parallel_context import 
global_context as gpc -from internlm.model.modules.mha import MHA +from internlm.core.context import global_context as gpc +from internlm.model.model_ops.modules.mha import MHA from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load from internlm.utils.utils import TensorParallelMode diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py index 2bcfa2ccf..6f5e511f3 100644 --- a/internlm/monitor/__init__.py +++ b/internlm/monitor/__init__.py @@ -1,9 +1,15 @@ -from .monitor import initialize_monitor_manager, internevo_monitor, send_alert_message -from .utils import set_env_var +from .alert import send_feishu_msg_with_webhook +from .monitor import ( + initialize_monitor_manager, + internevo_monitor, + monitor_manager, + send_alert_message, +) __all__ = [ "send_alert_message", "initialize_monitor_manager", - "set_env_var", "internevo_monitor", + "monitor_manager", + "send_feishu_msg_with_webhook", ] diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index fc33de62a..252a0380b 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -12,10 +12,10 @@ from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import global_context as gpc -from internlm.monitor.alert import send_feishu_msg_with_webhook -from internlm.utils.common import SingletonMeta +from internlm.monitor import send_feishu_msg_with_webhook +from internlm.utils.common import SingletonMeta, set_env_var -from .utils import get_job_key, set_env_var +from .utils import get_job_key logger = logging.getLogger(__file__) internlm_accelerator = get_accelerator() diff --git a/internlm/monitor/utils.py b/internlm/monitor/utils.py index 34360b521..0bdd3db2e 100644 --- a/internlm/monitor/utils.py +++ b/internlm/monitor/utils.py @@ -6,10 +6,6 @@ def now_time(): return datetime.now().strftime("%b%d_%H-%M-%S") -def set_env_var(key, value): - os.environ[str(key)] = str(value) - - def get_job_id(): job_id = "none" if os.getenv("SLURM_JOB_ID") is not None: diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 2b5c9e4ed..93d7a1ba1 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -10,16 +10,10 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable from internlm.accelerator import get_accelerator -from internlm.core.context.parallel_context import global_context as gpc -from internlm.core.context.random import ( - get_current_mode, - get_states, - set_mode, - set_seed_states, - sync_states, -) - -from ..utils.common import get_current_device +from internlm.core.context import get_current_mode, get_states +from internlm.core.context import global_context as gpc +from internlm.core.context import set_mode, set_seed_states, sync_states +from internlm.utils.common import get_current_device internlm_accelerator = get_accelerator() diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py index 55070fc33..7f848c9bd 100644 --- a/internlm/solver/optimizer/__init__.py +++ b/internlm/solver/optimizer/__init__.py @@ -1,8 +1,9 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from .base_optimizer import BaseOptimizer from .fsdp_optimizer import FSDPadaptOptimizer from .hybrid_zero_optim import HybridZeroOptimizer from .hybrid_zero_optim_v2 import HybridZeroOptimizer_v2 -__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "HybridZeroOptimizer_v2"] +__all__ 
= ["FSDPadaptOptimizer", "HybridZeroOptimizer", "BaseOptimizer", "HybridZeroOptimizer_v2"] diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 94cc411c6..2d9bb755b 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -8,15 +8,16 @@ from torch.optim import Optimizer from internlm.accelerator import get_accelerator -from internlm.core.context import Config, ParallelMode +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.solver.optimizer.base_optimizer import BaseOptimizer +from internlm.solver.optimizer import BaseOptimizer from internlm.solver.optimizer.utils import ( DynamicGradScaler, get_norm, release_param_grad, ) from internlm.utils.common import get_tensor_norm, move_norm_to_cuda +from internlm.utils.config import Config from internlm.utils.logger import get_logger try: diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 8d3ce3add..89fe346a4 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,19 +11,18 @@ from torch.optim import Optimizer from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config, ParallelMode -from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import ( +from internlm.core.context import ( IS_REPLICA_EXPERT_DATA_PARALLEL, IS_REPLICA_ZERO_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_EXPERT_DATA_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, + ParallelMode, ) -from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper -from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler -from internlm.model.modules.utils import is_gate_param, is_moe_param +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler +from internlm.model.model_ops.modules.utils import is_gate_param, is_moe_param from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, @@ -42,6 +41,7 @@ sync_param, ) from internlm.utils.common import get_current_device +from internlm.utils.config import Config from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index 36e5f073f..fd54d5d15 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -7,15 +7,15 @@ import torch.distributed as dist from torch.optim import Optimizer -from internlm.core.context import Config, ParallelMode -from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import ( +from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, + ParallelMode, ) -from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import ParamAsyncBcastHandler from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( 
BucketStore_v2, @@ -30,6 +30,7 @@ sync_param, ) from internlm.utils.common import get_current_device +from internlm.utils.config import Config from internlm.utils.logger import get_logger from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py deleted file mode 100644 index f3c680da4..000000000 --- a/internlm/train/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .pipeline import ( - get_scheduler_hooks, - initialize_llm_profile, - initialize_model_and_parallel_communicator, - initialize_optimizer, - initialize_parallel_communicator, - load_new_batch, - record_current_batch_training_metrics, - set_fp32_attr_for_model, - set_parallel_attr_for_param_groups, -) - -__all__ = [ - "initialize_llm_profile", - "initialize_model_and_parallel_communicator", - "initialize_parallel_communicator", - "initialize_optimizer", - "load_new_batch", - "record_current_batch_training_metrics", - "get_scheduler_hooks", - "set_parallel_attr_for_param_groups", - "set_fp32_attr_for_model", -] diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py deleted file mode 100644 index 945ee688a..000000000 --- a/internlm/train/pipeline.py +++ /dev/null @@ -1,1210 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import collections -import functools -import itertools -import math -import time -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union - -import torch -from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, - ShardingStrategy, -) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from torch.utils.data import DataLoader - -from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.checkpoint.utils import init_fsdp_v1 -from internlm.core.context import ( - IS_REPLICA_EXPERT_DATA_PARALLEL, - IS_REPLICA_ZERO_PARALLEL, - IS_TENSOR_EXPERT_DATA_PARALLEL, - IS_TENSOR_ZERO_PARALLEL, - IS_WEIGHT_EXPERT_DATA_PARALLEL, - IS_WEIGHT_ZERO_PARALLEL, - ParallelMode, -) -from internlm.core.context import global_context as gpc -from internlm.core.context.random import set_mode -from internlm.core.naive_amp import ( - NaiveAMPModel, - set_fp32_attr_to_module, - unwrap_naive_amp, -) -from internlm.core.parallel.comm.isp import ( - EmbeddingWeightParallelCommunicator, - HeadWeightParallelCommunicator, - ISPCommModelConfig, - ISPCommunicator, - ISPCommunicatorSchedulerHook, - ISPCommunicatorWrapper, -) -from internlm.core.parallel.comm.tensor import ( - EmbeddingSequenceParallelCommunicator, - EmbeddingTensorParallelCommunicator, - HeadSequenceParallelCommunicator, - HeadTensorParallelCommunicator, - LinearRole, - MoESequenceParallelCommunicator, - SequenceParallelCommunicator, - TensorParallelCommunicator, -) -from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler -from internlm.core.trainer import TrainState -from internlm.data.utils import unpack_type_ids -from internlm.model.builder import create_model -from internlm.model.metrics import SchedulerMetricHook -from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.linear import ( - ColumnParallelLinear, - GroupedColumnLinear, - GroupedRowLinear, - GroupedWPLinear, - ParallelLinearWithCommExt, - RewardModelLinear, - RowParallelLinear, - ScaleColumnParallelLinear, - new_linear, -) -from internlm.model.modules.norm import new_layer_norm 
-from internlm.model.moe import Experts, MoE -from internlm.model.moe.moe import Qwen2MoE -from internlm.model.ops.norm import RMSNorm -from internlm.model.registry import register_model_initializer -from internlm.monitor import set_env_var -from internlm.monitor.monitor import monitor_manager as mm -from internlm.solver.optimizer import ( - FSDPadaptOptimizer, - HybridZeroOptimizer, - HybridZeroOptimizer_v2, -) -from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw -from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler -from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR -from internlm.train.utils import create_param_groups, map_param_block, timeout_input -from internlm.utils.common import DummyProfile, SchedulerHook, get_current_device -from internlm.utils.lazy import LazyObject -from internlm.utils.logger import get_logger -from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.parallel import ( - is_replica_expert_data_parallel_parameter, - is_replica_zero_parallel_parameter, - is_tensor_expert_data_parallel_parameter, - is_tensor_zero_parallel_parameter, - is_using_fsdp, - is_using_hf, - is_using_isp, - is_weight_expert_data_parallel_parameter, - is_weight_zero_parallel_parameter, - sync_model_param, - sync_model_replica_param_group, -) -from internlm.utils.timeout import llm_timeout -from internlm.utils.utils import TensorParallelMode - -try: - import torch_npu -except (ImportError, ModuleNotFoundError): - pass - -try: - from torch.distributed._composable.fsdp import fully_shard - - FSDP2_SUPPORTED = True -except (ImportError, ModuleNotFoundError): - FSDP2_SUPPORTED = False - - -try: - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, - ) - - DCP_SUPPORTED = True -except (ImportError, ModuleNotFoundError): - DCP_SUPPORTED = False - - -IS_INJECTED = "is_injected" - -LINEAR2NEWLINEAR_NAME_MAPPING = dict( - q_proj="wq", - k_proj="wk", - v_proj="wv", - o_proj="wo", - gate_proj="w1", - down_proj="w2", - up_proj="w3", - lm_head="head", - W_pack="wqkv", -) - -logger = get_logger(__file__) -internlm_accelerator = get_accelerator() - - -def set_param_unique_tracking_name(model): - for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): - # Important: only works for llama-class models - childrens = chunk.named_children() - for _, children in childrens: - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - for name, child in block.named_modules(): - if isinstance(child, (ParallelLinearWithCommExt)): - full_name = f"{chunk_id}.{idx}.{name}" - setattr( - child.weight, - "tracking_name", - f"{full_name}.weight", - ) - if child.bias is not None: - setattr( - child.bias, - "tracking_name", - f"{full_name}.bias", - ) - else: - if isinstance(children, Embedding1D): - setattr( - children.weight, - "tracking_name", - f"{chunk_id}_embedding.weight", - ) - else: - setattr( - children.weight, - "tracking_name", - f"{chunk_id}_head.weight", - ) - - -def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - for _, module in _chunk.named_modules(): - if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.get("use_fp32_norm", False): - set_fp32_attr_to_module(module) - - -def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]): - def _check_module_pure_dp(name, module): # pylint: disable=W0613 - for 
param in module.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - - def _check_module(name, module): - # layer_norm - if isinstance(module, (RMSNorm, nn.LayerNorm)): - for param in module.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - - if isinstance(module, (MoE, Qwen2MoE)): - for param in module.moe_layer.gate.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - if hasattr(module, "coefficient"): - for param in module.coefficient.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - - # embedding and head - if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)): - for param in module.parameters(): - if gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): - setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): - setattr(param, IS_TENSOR_ZERO_PARALLEL, True) - - # for moe linear module - if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt): - for param in module.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - - if isinstance(module, Experts): - for param in module.parameters(): - if ( - gpc.is_initialized(ParallelMode.TENSOR) - and not is_using_isp() - and getattr(gpc.config.parallel.expert, "no_tp", False) - ): - setattr(param, IS_REPLICA_EXPERT_DATA_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): - setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): - setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True) - # for non-moe linear module - elif isinstance(module, ParallelLinearWithCommExt): - for param in module.parameters(): - if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): - setattr(param, IS_TENSOR_ZERO_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): - setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) - - # for vit and vit project - if "vision_tower" in name.lower() or "vision_proj" in name.lower(): - for param in module.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - - for _chunk in unwrap_naive_amp(model): - if not is_using_fsdp(): - # special case for pure dp mode - if ( - isinstance(gpc.config.parallel["tensor"], dict) - and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) - == TensorParallelMode.mtp.name - and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) - ): - _check_module_func = _check_module_pure_dp - else: - _check_module_func = _check_module - # set param parallel attribute - for name, module in _chunk.named_modules(): - _check_module_func(name, module) - - for name, param in _chunk.named_parameters(): - assert ( - is_replica_zero_parallel_parameter(param) - or is_tensor_zero_parallel_parameter(param) - or is_weight_zero_parallel_parameter(param) - or is_tensor_expert_data_parallel_parameter(param) - or is_weight_expert_data_parallel_parameter(param) - or is_replica_expert_data_parallel_parameter(param) - ), f"parameter with name: {name} has no parallel attribution." - - -@llm_timeout(func_name="initialize_model_and_parallel_communicator") -def initialize_model_and_parallel_communicator( - pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None -): - """ - Initialize model with Automatic Mixed Precision. - Returns: - torch.nn.Module: - The neural network model to be trained or evaluated. 
- An isp communicator for managing comp/comm overlap. - """ - if pre_process_func: - pre_process_output = pre_process_func() - - register_model_initializer() - - model = create_model() - - if post_process_func: - post_process_func(pre_process_output) - - return inject_model(model) - - -def inject_model(model): - """ - Inject model with Automatic Mixed Precision. - - Args: - torch.nn.Module: - The bare neural network model to be trained or evaluated. - - Returns: - torch.nn.Module: - The injected neural network model to be trained or evaluated. - An isp communicator for managing comp/comm overlap. - """ - if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED): - return model - - # For non-HF cases, set tracking name for parameters - if not is_using_hf(): - set_param_unique_tracking_name(model) - - # For non-fsdp cases, set model inject helper - if not is_using_fsdp(): - inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None)) - - # should be set before NaiveAMPModel - set_fp32_attr_for_model(model) - - if isinstance(model, nn.ModuleList): - model = nn.ModuleList( - [ - NaiveAMPModel( - model=_m, - output_to_fp32=False, # manually controlled by interleaved pipleline scheduler - dtype=gpc.config.model.get("dtype", torch.half), - sync_buffer=False, - ) - for _m in model - ] - ) - else: - model = NaiveAMPModel( - model=model, - output_to_fp32=gpc.is_no_pp_or_last_stage(), - dtype=gpc.config.model.get("dtype", torch.half), - sync_buffer=False, - ) - - set_parallel_attr_for_param_groups(model) - - # This sync is very important, cause the model weights kept in optimizer are copied - # from the origin parameters in the memory, so we should make sure the dp sync - # does not influence the model weights in optimizer be different with the origin parameters. - if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda": - sync_model_param(model) - - # This function is needed to make sure parameters that are not splitted by tensor parallelism are - # the same across tensor parallelism. - sync_model_replica_param_group(model) - - # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random - # state in the same dp group are all the same. - random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA - set_mode(random_mode) - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) - - model = wrap_FSDP_model(model) - - # set is_injected flag - setattr(model, "IS_INJECTED", True) - - return model, isp_communicator - - -_T = TypeVar("_T") - - -def _submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]: - for _chunk in unwrap_naive_amp(model): - for _module in _chunk.modules(): - if not isinstance(_module, target_cls): - continue - - yield _module - - -def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): - """ - Initialize communicator for isp tensor parallel mode. - - Args: - model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated. - - Returns: - An isp communicator for managing comp/comm overlap. 
- """ - isp_communicator_wrapper = None - _retain_out_sharded = gpc.config.model.get("parallel_output", True) - - if is_using_isp(): - isp_communicator = ISPCommunicator( - model, - ISPCommModelConfig( - gpc.config.model.dtype, - get_current_device(), - gpc.config.model.checkpoint, - ), - gpc.config.parallel.weight.overlap and not is_using_fsdp(), - gpc.get_group(ParallelMode.WEIGHT), - is_moe=False, - selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), - early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release, - ) - # register communicator for isp column parallel linear. - ColumnParallelLinear.register_cls_communicator(isp_communicator) - # row parallel linear will not be used. - RowParallelLinear.register_cls_communicator(None) - _head_communicator = HeadWeightParallelCommunicator( - weight_process_group=gpc.get_group(ParallelMode.WEIGHT), - seq_process_group=gpc.get_group(ParallelMode.TENSOR), - retain_out_sharded=_retain_out_sharded, - ) - _embedding_communicator = EmbeddingWeightParallelCommunicator(ParallelMode.WEIGHT) - - if gpc.config.model.get("num_experts", 1) > 1: - # register communicator for moe isp column parallel linear. - # NOTE: this wil overwrite registed communicator - moe_isp_communicator = ISPCommunicator( - model, - ISPCommModelConfig( - gpc.config.model.dtype, - get_current_device(), - gpc.config.model.checkpoint, - ), - gpc.config.parallel.expert_weight.overlap, - gpc.get_group(ParallelMode.EXPERT_WEIGHT), - is_moe=True, - early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release, - ) - for moe in _submodule_filter(model, Experts): - for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)): - column_linear.register_communicator(moe_isp_communicator) - for row_linear in _submodule_filter(moe, RowParallelLinear): - row_linear.register_communicator(None) - - isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator]) - else: - isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator]) - - # register communictor for mtp/msp/fsp linear. - - # tensor parallel - if gpc.config.parallel.tensor.mode == TensorParallelMode.mtp.name: - ColumnParallelLinear.register_cls_communicator( - TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) - ) - RowParallelLinear.register_cls_communicator( - TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) - ) - - if gpc.config.model.get("num_experts", 1) > 1: - GroupedColumnLinear.register_cls_communicator( - TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) - ) - GroupedRowLinear.register_cls_communicator( - TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) - ) - GroupedWPLinear.register_cls_communicator(None) - # treat as sequence paralle if no_tp - if gpc.config.parallel.expert.no_tp: - _column_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN - ) - _row_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW - ) - for moe in _submodule_filter(model, MoE): - # 1. 
the linear in MoE degrades as no tp communication pattern - for column_linear in _submodule_filter(moe, ColumnParallelLinear): - column_linear.register_communicator(_column_communicator) - for row_linear in _submodule_filter(moe, RowParallelLinear): - row_linear.register_communicator(_row_communicator) - # 2. register MoESequenceParallelCommunicator for MoE layer - MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe) - - _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) - _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR) - # sequence parallel - if gpc.config.parallel.tensor.mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name): - save_total_input_as_activation = gpc.config.parallel.tensor.mode == TensorParallelMode.msp.name - - ColumnParallelLinear.register_cls_communicator( - SequenceParallelCommunicator( - process_group=gpc.get_group(ParallelMode.TENSOR), - role=LinearRole.COLUMN, - save_total_input_as_activation=save_total_input_as_activation, - ) - ) - RowParallelLinear.register_cls_communicator( - SequenceParallelCommunicator( - gpc.get_group(ParallelMode.TENSOR), - role=LinearRole.ROW, - save_total_input_as_activation=save_total_input_as_activation, - ) - ) - if gpc.config.model.get("num_experts", 1) > 1: - GroupedColumnLinear.register_cls_communicator( - SequenceParallelCommunicator( - process_group=gpc.get_group(ParallelMode.TENSOR), - role=LinearRole.COLUMN, - save_total_input_as_activation=save_total_input_as_activation, - ) - ) - GroupedRowLinear.register_cls_communicator( - SequenceParallelCommunicator( - gpc.get_group(ParallelMode.TENSOR), - role=LinearRole.ROW, - save_total_input_as_activation=save_total_input_as_activation, - ) - ) - GroupedWPLinear.register_cls_communicator(None) - if gpc.config.parallel.expert.no_tp: - _column_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN - ) - _row_communicator = TensorParallelCommunicator( - process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW - ) - for moe in _submodule_filter(model, MoE): - # 1. the linear in MoE degrades as no tp communication pattern - for column_linear in _submodule_filter(moe, ColumnParallelLinear): - column_linear.register_communicator(_column_communicator) - for row_linear in _submodule_filter(moe, RowParallelLinear): - row_linear.register_communicator(_row_communicator) - - _head_communicator = HeadSequenceParallelCommunicator( - ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation - ) - - _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR) - - # register communitorc for embedding layer. - if not is_using_fsdp(): - for embedding in _submodule_filter(model, Embedding1D): - _embedding_communicator.register_module_hook(embedding) - - # register communictor for head layer. - ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) - RewardModelLinear.register_cls_communicator(_head_communicator) - - return isp_communicator_wrapper - - -@llm_timeout(func_name="initialize_optimizer") -def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicatorWrapper = None): - """ - Initialize optimizer. - - Args: - model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated. - - Returns: - A tuple of (optimizer, beta2_scheduler, lr_scheduler). 
- """ - - adam_cfg = gpc.config.adam - zero_cfg = gpc.config.hybrid_zero_optimizer - grad_scal_cfg = gpc.config.grad_scaler - use_apex_adam = getattr(gpc.config, "use_apex_adam", False) - - if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim: - map_param_block(model) - - params = create_param_groups(model, adam_cfg.weight_decay) - - naive_optimizer = new_compatible_adamw( - params=params, - lr=adam_cfg.lr, - betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), - eps=adam_cfg.adam_eps, - use_apex_adam=use_apex_adam, - ) - - if ( - zero_cfg.overlap_sync_grad - and gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and gpc.is_pipeline_first_stage() is False - ): - # When pipeline parallelism is enabled, we prefer to only enable optimizer - # gradient communication overlap in the first stage, to avoid amplifying - # the communication overhead stage by stage in cases where the optimizer - # communication overhead is greater than the compute overhead. - # For pipeline stages except the first, even if overlap is not enabled, - # their gradient synchronization overhead can be well hidden by - # the inherent bubbles of pipeline parallelism. - zero_cfg.overlap_sync_grad = False - - if zero_cfg.overlap_sync_param: - param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator) - else: - param_bcast_sync_handler = None - - if not is_using_fsdp(): - if ( - "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer - or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim - ): - optimizer = HybridZeroOptimizer( - naive_optimizer, - grad_scal_cfg=grad_scal_cfg, - zero_cfg=zero_cfg, - param_bcast_sync_handler=param_bcast_sync_handler, - isp_communicator=isp_communicator, - ) - else: - optimizer = HybridZeroOptimizer_v2( - naive_optimizer, - grad_scal_cfg=grad_scal_cfg, - zero_cfg=zero_cfg, - param_bcast_sync_handler=param_bcast_sync_handler, - isp_communicator=isp_communicator, - ) - else: - optimizer = FSDPadaptOptimizer( - naive_optimizer, - grad_scal_cfg=grad_scal_cfg, - zero_cfg=zero_cfg, - ) - - beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler) - - lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler) - - return optimizer, beta2_scheduler, lr_scheduler - - -def get_scheduler_hooks(metric, zero_optim, isp_communicator_wrapper) -> List[SchedulerHook]: - scheduler_hooks: List[SchedulerHook] = [] - - if metric is not None: - scheduler_hooks.append( - SchedulerMetricHook( - metric=metric, - skip=( - gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and hasattr(gpc.config.model, "num_chunks") - and gpc.config.model.num_chunks > 1 - and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) - ), - ), - ) - - if isp_communicator_wrapper is not None: - for isp_communicator in isp_communicator_wrapper.isp_communicators: - if isp_communicator is not None and isp_communicator.overlap: - scheduler_hooks.append(ISPCommunicatorSchedulerHook(isp_communicator, zero_optim)) - - return scheduler_hooks - - -@llm_timeout(func_name="load_new_batch") -def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState): - """ - Load and return the new batch data based on training data loader. - - Args: - train_dl (torch.utils.data.DataLoader): Dataloader for training. - train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). - train_state (TrainState): Current training state. 
- - Returns: A batch data and the updated train_iter. - """ - - timer("batch-gen").start() - try: - batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor) - if hasattr(train_state, "batch_sampler_iter"): - next(train_state.batch_sampler_iter) - except StopIteration: - train_iter = iter(train_dl) - batch = next(train_iter) - train_state.num_consumed_samples_in_epoch = 0 - if hasattr(train_state, "batch_sampler"): - train_state.batch_sampler.batch_count = 0 - train_state.batch_sampler.num_consumed_samples_in_epoch = 0 - train_state.batch_sampler_iter = iter(train_state.batch_sampler) - next(train_state.batch_sampler_iter) - timer("batch-gen").stop() - - if batch[0].get("type_ids", None) is not None: - # if use_packed_dataset is False, we need to unpack type_ids - if not gpc.config.data.use_packed_dataset: - batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) - - return batch, train_iter - - -def initialize_llm_profile(profiling: bool = False, start_time: str = None): - """Initialize and return the profiler context manager instance.""" - - if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3} - trace_path = ( - f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_" - f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" - f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" - f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}" - ) - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - ) - llm_profile = torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(**schedule_config), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path), - record_shapes=True, - profile_memory=True, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) - logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!") - else: - llm_profile = torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=torch.profiler.schedule(**schedule_config), - on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), - with_stack=True, - with_modules=True, - profile_memory=True, - ) - logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!") - else: - llm_profile = DummyProfile() - - return llm_profile - - -@llm_timeout(func_name="record_current_batch_training_metrics") -def record_current_batch_training_metrics( - get_tflops_func, - logger, - writer, - success_update, - batch_count, - batch, - train_state, - optimizer, - beta2_scheduler, - engine, - start_time, - very_begining_time, - loss, - moe_loss, - grad_norm, - metric, -): - """ - Print some training metrics of current batch. 
- """ - - set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time())) - - timer.store_last_timers() - if success_update in (0, True): - train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA) - if gpc.is_no_pp_or_last_stage(): - acc_perplex = metric.get_metric() - - if success_update and gpc.is_rank_for_log(): - lr = optimizer.param_groups[0]["lr"] - if hasattr(engine.optimizer, "grad_scaler"): - scaler = engine.optimizer.grad_scaler._scale.item() - elif hasattr(engine.optimizer.optim, "grad_scaler"): - scaler = engine.optimizer.optim.grad_scaler._scale.item() - - num_tokens_in_batch = batch[1].nelement() - real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL)) - num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) - max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - time_cost = time.time() - start_time - tk_per_gpu = round( - num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL), - 4, - ) - tgs_statistic = train_state.tgs_statistic - tgs_statistic["sum_step"] += 1 - tgs_statistic["sum_tg"] += tk_per_gpu - tgs_statistic["total_time"] = time.time() - very_begining_time - tgs_statistic["sum_last_tg_10"] += tk_per_gpu - tgs_statistic["sum_last_time_10"] += time_cost - tgs_statistic["sum_last_tg_50"] += tk_per_gpu - tgs_statistic["sum_last_time_50"] += time_cost - tgs_statistic["SMA_tg_50"] += tk_per_gpu - tgs_statistic["SMA_time_50"] += time_cost - tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu) - tgs_statistic["SMA_time_50_list"].append(time_cost) - if tgs_statistic["sum_step"] > 50: - tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0] - tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0] - tgs_statistic["SMA_tg_50_list"].popleft() - tgs_statistic["SMA_time_50_list"].popleft() - - last_tgs_1 = round(tk_per_gpu / time_cost, 2) - tgs_statistic["sum_tgs"] += last_tgs_1 - - if tgs_statistic["sum_step"] % 10 == 0: - tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2) - tgs_statistic["sum_last_tg_10"] = 0 - tgs_statistic["sum_last_time_10"] = 0 - - if tgs_statistic["sum_step"] % 50 == 0: - tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2) - tgs_statistic["sum_last_tg_50"] = 0 - tgs_statistic["sum_last_time_50"] = 0 - - last_tgs_10 = tgs_statistic["last_tgs_10"] - last_tgs_50 = tgs_statistic["last_tgs_50"] - - tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["total_time"], 2) - tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2) - tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2) - - tflops = get_tflops_func(time_cost) - - tgs_origin = round( - num_tokens_in_batch - * gpc.get_world_size(ParallelMode.DATA) - / gpc.get_world_size(ParallelMode.GLOBAL) - / time_cost, - 2, - ) - - real_tgs = round( - real_num_tokens / time_cost, - 2, - ) - - infos = { - "tflops": tflops, - "step": batch_count, - "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), - "real_tgs": real_tgs, - "tgs (tokens/gpu/second)": tgs_origin, - "tgs/last_tgs_1": last_tgs_1, - "tgs/tgs_all": tgs_all, - "tgs/tgs_avg": tgs_avg, - "tgs/tgs_SMA": tgs_SMA, - 
"tgs/last_tgs_10": last_tgs_10, - "tgs/last_tgs_50": last_tgs_50, - "lr": lr, - "loss_scale": scaler, - "grad_norm": grad_norm, - } - if moe_loss is not None: - infos["moe_loss"] = moe_loss.item() - - infos["micro_num"] = len(batch[1]) - infos["num_consumed_tokens"] = train_state.num_consumed_tokens - infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches - infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples - infos["largest_length"] = max_length_in_batch # the longest input - infos["largest_batch"] = max_samples_in_batch # the batch with the most samples - infos["smallest_batch"] = min_samples_in_batch - infos["adam_beta2"] = beta2_scheduler.get_beta2() - - fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) - infos["fwd_bwd_time"] = fwd_bwd_time - bwd_time = round(timer("bwd").elapsed(), 2) - infos["bwd_time"] = bwd_time - - for key, value in acc_perplex.items(): - infos[key] = value - - line = "" - for key, value in infos.items(): - line += f"{key}={value} " - if isinstance(value, dict): - writer.add_scalars(key=key, value=value, step=train_state.step_count) - else: - writer.add_scalar(key=key, value=value, step=train_state.step_count) - - logger.info(line) - - # if loss spike occurs, send alert info to feishu - mm.monitor_loss_spike( - alert_address=gpc.config.monitor.alert.feishu_alert_address, - step_count=batch_count, - cur_step_loss=loss.item(), - ) - - -def inject_embed(model: nn.Module, inject=False, interactive=False) -> None: - def traverse(module): - for name, child in module.named_children(): - if isinstance(child, nn.Embedding) and not isinstance(child, Embedding1D): - msg = ( - f"To get parallel training enabled, module {name} of type {nn.Embedding.__name__} " - + f"is required to be replaced with {Embedding1D.__name__}." - ) - if inject: - help_msg = f"Do you want to replace {name}? (y/n)" - opt = timeout_input( - f"{msg}\n{help_msg}", - default="y", - timeout=60, - interactive=interactive, - ) - if opt in ["y", "yes"]: - child_new = Embedding1D( - num_embeddings=child.num_embeddings, - embedding_dim=child.embedding_dim, - padding_idx=child.padding_idx, - ).to(device=child.weight.device, dtype=child.weight.dtype) - setattr(module, name, child_new) - else: - if gpc.is_rank_for_log(): - logger.warning(f"Skip replacing {name}") - else: - if gpc.is_rank_for_log(): - logger.warning(msg) - else: - traverse(child) - - traverse(model) - - -def inject_linear(model: nn.Module, inject=False, interactive=False) -> None: - def traverse(module): - for name, child in module.named_children(): - if isinstance(child, nn.Linear) and not isinstance(child, ParallelLinearWithCommExt): - msg = ( - f"To get parallel training enabled, module {name} of type {nn.Linear.__name__} " - + f"is required to be replaced with {new_linear.__name__}." - ) - if inject: - help_msg = f"Do you want to replace {name}? 
(y/n)" - opt = timeout_input( - f"{msg}\n{help_msg}", - default="y", - timeout=60, - interactive=interactive, - ) - if opt in ["y", "yes"]: - child_new = new_linear( - name=LINEAR2NEWLINEAR_NAME_MAPPING.get(name, name), - in_features=child.in_features, - out_features=child.out_features, - bias=child.bias is not None, - ).to(device=child.weight.device, dtype=child.weight.dtype) - setattr(module, name, child_new) - else: - if gpc.is_rank_for_log(): - logger.warning(f"Skip replacing {name}") - else: - if gpc.is_rank_for_log(): - logger.warning(msg) - else: - traverse(child) - - traverse(model) - - -def inject_norm(model: nn.Module, inject=False, interactive=False) -> None: - def traverse(module): - for name, child in module.named_children(): - cls_name = type(child).__name__ - if "RMSNorm" in cls_name: - msg = ( - f"To re-use unified RMSNorm implementation, {cls_name} " - + f"is suggested to be replaced with {new_layer_norm.__name__}." - ) - if inject: - help_msg = f"Do you want to replace {name}? (y/n)" - opt = timeout_input( - f"{msg}\n{help_msg}", - default="y", - timeout=60, - interactive=interactive, - ) - if opt in ["y", "yes"]: - child_new = new_layer_norm( - norm_type="rmsnorm", - normalized_shape=child.weight.shape, - eps=child.variance_epsilon, - ).to(device=child.weight.device, dtype=child.weight.dtype) - setattr(module, name, child_new) - else: - if gpc.is_rank_for_log(): - logger.warning(f"Skip replacing {name}") - else: - if gpc.is_rank_for_log(): - logger.warning(msg) - else: - traverse(child) - - traverse(model) - - -def inject_config(model: nn.Module) -> None: - # Compatibility for Vision-Language Model - if hasattr(model.config, "text_config"): - llm_cfg = model.config.text_config - else: - llm_cfg = model.config - gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = llm_cfg.vocab_size - gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = llm_cfg.hidden_size - gpc.config.model.num_layers = gpc.config.NUM_LAYER = llm_cfg.num_hidden_layers - # Compatibility for Mamba - if hasattr(llm_cfg, "num_attention_heads"): - gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = llm_cfg.num_attention_heads - gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = llm_cfg.intermediate_size / llm_cfg.hidden_size - # For models that use GQA - if hasattr(llm_cfg, "num_key_value_heads"): - gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads - - -def _get_modules_to_materialize( - root_module: nn.Module, - ignored_modules: Set[nn.Module], -) -> List[nn.Module]: - # Run BFS to collect the modules to materialize via `reset_parameters()`, - # stopping at any module with FSDP already applied or at ignored modules. 
- modules_to_materialize: List[nn.Module] = [] - queue = collections.deque([root_module]) - visited_modules: Set[nn.Module] = {root_module} - while queue: - module = queue.popleft() - modules_to_materialize.append(module) - for child_module in module.children(): - if child_module not in visited_modules and child_module not in ignored_modules: - visited_modules.add(child_module) - queue.append(child_module) - return modules_to_materialize - - -def _materialize_meta_module( - root_module: nn.Module, - ignored_modules: Set[nn.Module], - device_id: Optional[torch.device], -) -> None: - # Run default meta device initialization - modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) - module = None - try: - # Assume that each module's `reset_parameters()` only initializes its - # own parameters and not those of its children - with torch.no_grad(): - for module in modules_to_materialize: - # As a contract to the user, only call `reset_parameters()` if - # the module has directly managed parameters/buffers - module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False)) - has_module_states = len(list(module_state_iter)) > 0 - if has_module_states: - module.to_empty(device=device_id, recurse=False) - module.reset_parameters() # type: ignore[operator] - except BaseException as e: - logger.warning( - "Unable to call `reset_parameters()` for module on meta " - f"device with error {str(e)}. Please ensure that your module of" - f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] - ) - raise e - - -def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): - if is_using_fsdp(): - assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallel." - wrap_cls = tuple( - LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", []) - ) - fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1") - fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") - - if fsdp_mode == "v1": - model = FSDP( - module=model, - process_group=gpc.get_group(ParallelMode.GLOBAL), - sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD - auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)), - sync_module_states=fsdp_init_method != "cuda", # sync model paramters - forward_prefetch=True, - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, - limit_all_gathers=True, - use_orig_params=True, - device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states - ) - # For FSDP v1, to get ckpt resuming work normally, we do dummy forward. - # This hack is needed due to FSDP v1 lazy initialization in model construction. 
- # FYI: https://github.com/pytorch/pytorch/issues/113496 - model = init_fsdp_v1(model, get_current_device()) - elif FSDP2_SUPPORTED and fsdp_mode == "v2": - fsdp_kwargs = { - "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True - } - for module in model.modules(): - if isinstance(module, wrap_cls): - fully_shard(module, **fsdp_kwargs) - fully_shard(model, **fsdp_kwargs) - if fsdp_init_method == "meta": - _materialize_meta_module(model, set(), get_current_device()) - elif fsdp_init_method == "cpu": - model.to(get_current_device()) - else: - raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}") - - if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False): - load_ckpt_info = gpc.config.ckpt.load_ckpt_info - load_ckpt_path = load_ckpt_info.get("path", None) - load_ckpt_content = load_ckpt_info.get("content", []) - if load_ckpt_path: - assert load_ckpt_content == ( - "model", - ), "If auto_resume=False and checkpoint path is given, only model can be loaded" - if DCP_SUPPORTED: - hf = gpc.config.hf - mod = LazyObject(hf.mod, hf.mod_cls) - mod = mod.build() - state_dict = mod.from_pretrained( - pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True - ).state_dict() - state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} - set_model_state_dict( - model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) - ) - del state_dict - internlm_accelerator.empty_cache() - else: - raise RuntimeError("DCP is not supported in this version of PyTorch.") - - return model - - -def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None: - """ - Inject model helper functions. - - Args: - model (Union[nn.Module, nn.ModuleList]): - For built-in models, it is nn.Module for no pp and nn.ModuleList for pp. - For injected models, it is nn.Module. - inject_info (Optional[Dict]): configurations for injected_models. 
- """ - # parse inject_info - if inject_info is not None: - inject = inject_info.get("inject", False) - interactive = inject_info.get("interactive", False) - modules = inject_info.get("modules", []) - reset_params = inject_info.get("reset_params", False) - extra_linear2newlinear = inject_info.get("extra_linear2newlinear", {}) - else: - inject = False - interactive = False - modules = [] - reset_params = False - extra_linear2newlinear = {} - - LINEAR2NEWLINEAR_NAME_MAPPING.update(extra_linear2newlinear) - - inject_funcs = { - "embed": inject_embed, - "linear": inject_linear, - "norm": inject_norm, - } - - # inject config - if inject: - inject_config(model) - - if not isinstance(model, nn.ModuleList): - model = [model] - for _chunk in model: - # Special case for pure dp mode: skip - if ( - isinstance(gpc.config.parallel["tensor"], dict) - and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name - and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) - ): - continue - # In-place replacement or check for modules: "embed", "linear", "norm" - # (1) If inject=True, in-place replacement - # (2) If inject=False, check - for mod in modules: - inject_funcs[mod](_chunk, inject, interactive) - # reset parameters if needed, model should have reset_parameters() method - if reset_params: - _chunk.reset_parameters() - for _chunk in model: - # If model is initialized on cpu, model should be moved to cuda device after injection - if not next(_chunk.parameters()).is_cuda: - _chunk.to(get_current_device()) - - # print injected model - if inject and gpc.is_rank_for_log(): - logger.info( - f"inject is enabled, please check the model carefully, " - f"if there are any problems, please report issue to us. 
" - f"The injected model is \n {model}" - ) diff --git a/internlm/train/utils.py b/internlm/train/utils.py deleted file mode 100644 index d1bf4fe90..000000000 --- a/internlm/train/utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Dict, Tuple - -import torch -from torch import nn - -from internlm.core.context.parallel_context import ParallelMode -from internlm.core.context.parallel_context import global_context as gpc -from internlm.core.naive_amp import unwrap_naive_amp -from internlm.model.modules.utils import is_moe_param -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -def split_params_into_different_groups_for_optimizer( - param_groups: Tuple[Dict], -) -> Tuple[Dict]: - """Split parameters into different groups for optimizer - - Args: - param_groups (Tuple[Dict]): The list of parameter groups to split - Input Example: - >>> ( - >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, - >>> ) - - Returns: - Tuple[Dict]: list of params groups for optimizer - Output Example: - >>> ( - >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx}, - >>> {'name': 'embed_head', 'params': [tensor], 'weight_decay' :xxx}, - >>> {'name': 'fp32', 'params': [tensor], 'weight_decay' :xxx}, - >>> ) - """ - - if isinstance(param_groups, tuple): - param_groups = list(param_groups) # Tuple cannot be modified - elif isinstance(param_groups, dict): - param_groups = [param_groups] - elif not isinstance(param_groups, list): - raise ValueError(f"Unknown param group type of {type(param_groups)}") - - new_groups = {} - # create new groups for fp32 parameter group - new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1} - - if gpc.config.model.get("num_experts", 1) > 1: - for key in gpc.expert_parallel_group_names: - new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": ParallelMode.EXPERT_DATA} - - for pgroup in param_groups: - # copy attribute from origin group, we assume the input param_groups only - # have one group, so the attribute will not be copyed multiple times. 
- for ori_key in pgroup.keys(): - if ori_key not in ("name", "params"): - for _, group in new_groups.items(): - group[ori_key] = pgroup[ori_key] - # assign param - origin_params = [] - for param in pgroup["params"]: - # moe param means MoE is enabled - if is_moe_param(param): - new_groups[param.group_name]["params"].append(param) - elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: - new_groups["fp32"]["params"].append(param) - else: - origin_params.append(param) - - # default param group, which is the first group in the param groups - pgroup["params"] = origin_params - pgroup["optimizer_mode"] = ParallelMode.ZERO1 - - # param groups may contain empty groups, such as fp32 - param_groups.extend(new_groups.values()) - - return tuple(param_groups) - - -def create_param_groups(model, weight_decay): - parameters = { - "params": [param for param in model.parameters() if param.requires_grad], - "name": "default", - "weight_decay": weight_decay, - } - return split_params_into_different_groups_for_optimizer(parameters) - - -def map_param_block(model): - for _chunk in unwrap_naive_amp(model): - for name, children in _chunk.named_children(): - if isinstance(children, nn.ModuleList): - for idx, block in enumerate(children): - block_name = name + f"_{idx}" - for param in block.parameters(): - setattr(param, "block_name", block_name) - else: - for param in children.parameters(): - setattr(param, "block_name", name) - - -def timeout_input(printout, default, timeout=None, interactive=True): - if not interactive: - return default - import select - import sys - - if gpc.is_rank_for_log(): - logger.info(printout) - - i, _, _ = select.select([sys.stdin], [], [], timeout) - if i: - msg = sys.stdin.readline().strip() - return default if len(msg) == 0 else msg - else: - return default diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 56ebcfbe6..c444e456a 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import argparse import bisect import inspect import os @@ -15,7 +16,6 @@ import numpy as np import torch -import internlm from internlm.accelerator import AcceleratorType, get_accelerator from internlm.utils.logger import get_logger @@ -24,8 +24,39 @@ internlm_accelerator = get_accelerator() +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser. 
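For orientation, a minimal sketch of how the argument parser added to internlm/utils/common.py is meant to be consumed by a launch script; the script name, config path, and port value below are illustrative only, not taken from this patch:

    # illustrative usage sketch of get_default_parser (hypothetical script and paths)
    from internlm.utils.common import get_default_parser

    parser = get_default_parser()
    # e.g. python some_train_script.py --config path/to/config.py --launcher torch --port 29500
    args = parser.parse_args()
    print(args.config, args.launcher, args.port, args.backend)
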
+ """ + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument( + "--launcher", + type=str, + default="slurm", + choices=["slurm", "torch"], + help="launcher for launching distributed environment", + ) + parser.add_argument("--host", type=str, help="the master address for distributed training") + parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training") + parser.add_argument("--world_size", type=int, help="world size for distributed training") + parser.add_argument("--rank", type=int, help="rank for the default process group") + parser.add_argument("--local_rank", type=int, help="local rank on the node") + parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") + parser.add_argument("--seed", type=int, default=1024) + parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") + parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.") + parser.add_argument( + "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology." + ) + return parser + + def parse_args(): - parser = internlm.get_default_parser() + parser = get_default_parser() args = parser.parse_args() return args @@ -318,3 +349,7 @@ def __setitem__(self, key, value): mapping[key] = value return self.maps[0][key] = value + + +def set_env_var(key, value): + os.environ[str(key)] = str(value) diff --git a/internlm/utils/config.py b/internlm/utils/config.py new file mode 100644 index 000000000..7d54d0ca8 --- /dev/null +++ b/internlm/utils/config.py @@ -0,0 +1,103 @@ +import inspect +import sys +from importlib.machinery import SourceFileLoader +from pathlib import Path + + +class Config(dict): + """This is a wrapper class for dict objects so that values of which can be + accessed as attributes. + + Args: + config (dict): The dict object to be wrapped. + """ + + def __init__(self, config: dict = None): # pylint: disable=W0231 + if config is not None: + for k, v in config.items(): + self._add_item(k, v) + + def __missing__(self, key): + raise KeyError(key) + + def __getattr__(self, key): + try: + value = super().__getitem__(key) + return value + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + super().__setitem__(key, value) + + def _add_item(self, key, value): + if isinstance(value, dict): + self.__setattr__(key, Config(value)) + else: + self.__setattr__(key, value) + + def update(self, config): + assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." + for k, v in config.items(): + self._add_item(k, v) + return self + + @staticmethod + def from_file(filename: str): + """Reads a python file and constructs a corresponding :class:`Config` object. + + Args: + filename (str): Name of the file to construct the return object. + + Returns: + :class:`Config`: A :class:`Config` object constructed with information in the file. 
+ + Raises: + AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file + """ + + # check config path + if isinstance(filename, str): + filepath = Path(filename).absolute() + elif isinstance(filename, Path): + filepath = filename.absolute() + + assert filepath.exists(), f"{filename} is not found, please check your configuration path" + + # check extension + extension = filepath.suffix + assert extension == ".py", "only .py files are supported" + + # import the config as module + remove_path = False + if filepath.parent not in sys.path: + sys.path.insert(0, (filepath)) + remove_path = True + + module_name = filepath.stem + source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath)) + module = source_file.load_module() # pylint: disable=W4902,E1120,W1505 + + # load into config + config = Config() + + for k, v in module.__dict__.items(): + if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): + continue + else: + config._add_item(k, v) + + # remove module + del sys.modules[module_name] + if remove_path: + sys.path.pop(0) + + return config + + +def get_config_value(config, key, defalut): + try: + value = config[key] + except KeyError: + value = defalut + return value diff --git a/internlm/utils/lazy.py b/internlm/utils/lazy.py index e67c63aa2..e8dc8d860 100644 --- a/internlm/utils/lazy.py +++ b/internlm/utils/lazy.py @@ -1,3 +1,4 @@ +# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py # Copyright (c) OpenMMLab. All rights reserved. import abc import importlib @@ -43,7 +44,7 @@ class LazyObject: During parsing process, the syntax like: Examples: - >>> import torch.nn as nn + >>> from torch import nn >>> from mmdet.models import RetinaNet >>> import mmcls.models >>> import mmcls.datasets @@ -52,7 +53,7 @@ class LazyObject: Will be parsed as: Examples: - >>> # import torch.nn as nn + >>> # from torch import nn >>> nn = lazyObject('torch.nn') >>> # from mmdet.models import RetinaNet >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py index 5b09f9d5a..d3720ca21 100644 --- a/internlm/utils/timeout.py +++ b/internlm/utils/timeout.py @@ -39,7 +39,7 @@ def __exit__(self, error_type, value, traceback): timeout_threshold_dict = { - "initialize_distributed_env": 240, + "init_distributed": 240, "nopp_forward_backward_step": 360, "initialize_model_and_parallel_communicator": 60, "initialize_optimizer": 60, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index a545f766c..621bc74b3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,20 +1,20 @@ -transformers +transformers<4.47.0 sentencepiece datasets numpy tqdm einops -psutil -packaging -pre-commit -ninja -gputil -pytest boto3 botocore -torch-scatter pyecharts py-libnuma pynvml +psutil +gputil tensorboard --f https://data.pyg.org/whl/torch-2.1.0+cu118.html +ninja +packaging +pre-commit +pylint +pytest +image diff --git a/setup.py b/setup.py index f37599543..2dbe90f91 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,50 @@ import os -import re -import sys -import subprocess -from setuptools import setup, find_packages -from setuptools.command.install import install +from typing import List + +from setuptools import find_packages, setup pwd = os.path.dirname(__file__) + def readme(): - with open(os.path.join(pwd, 'README.md')) as f: + with open(os.path.join(pwd, "README.md")) as f: content = f.read() return content + def get_version(): - with 
open(os.path.join(pwd, 'version.txt'), 'r') as f: + with open(os.path.join(pwd, "version.txt"), encoding="utf-8") as f: content = f.read() return content -def has_nvcc(): - try: - subprocess.run(['nvcc', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return True - except (subprocess.CalledProcessError, FileNotFoundError): - return False - -def fetch_requirements(path): - with open(path, 'r') as fd: - return [r.strip() for r in fd.readlines() if 'torch-scatter' not in r and not r.startswith('-f ')] - -if has_nvcc(): - install_requires = [ - fetch_requirements('requirements/runtime.txt'), - 'rotary_emb', - 'xentropy', - ] -else: - install_requires = [ - fetch_requirements('requirements/runtime.txt'), - ] + +def get_requires() -> List[str]: + with open(os.path.join("requirements", "runtime.txt"), encoding="utf-8") as f: + file_content = f.read() + lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] + return lines + + +extra_require = { + "torch": ["torch>=2.1.0"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3"], +} setup( - name='InternEvo', + name="InternEvo", version=get_version(), - description='an open-sourced lightweight training framework aims to support model pre-training without the need for extensive dependencies', + description="Lightweight training framework for LLM", + author="InternEvo team", + license="Apache 2.0 License", long_description=readme(), - long_description_content_type='text/markdown', - packages=find_packages(), - install_requires=install_requires, + long_description_content_type="text/markdown", + packages=find_packages(exclude=["tests"]), + install_requires=get_requires(), + extras_require=extra_require, classifiers=[ - 'Programming Language :: Python :: 3.10', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', + "Programming Language :: Python :: 3.10", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", ], ) diff --git a/tests/common_fixture.py b/tests/common_fixture.py index e5a8b9aa1..fbd3763e9 100644 --- a/tests/common_fixture.py +++ b/tests/common_fixture.py @@ -5,12 +5,11 @@ import numpy as np import torch -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import Config from internlm.data.utils import unpack_type_ids -from internlm.initialize.launch import args_sanity_check +from internlm.initialize import initialize_launcher +from internlm.utils.config import Config internlm_accelerator = get_accelerator() @@ -119,9 +118,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) - args_sanity_check() + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl") def seed_all(seed, cuda_deterministic=False): diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index efa5d7b71..3b41374dd 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -5,9 +5,9 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import Config from 
internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.utils.common import get_current_device +from internlm.utils.config import Config from tests.test_core.utils import ( MlpModel, MyLoss, diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py index 5ccaccaf3..1436e2988 100644 --- a/tests/test_core/utils.py +++ b/tests/test_core/utils.py @@ -5,7 +5,6 @@ from torch import nn from torch.testing import assert_close -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -17,8 +16,9 @@ NonPipelineScheduler, PipelineScheduler, ) -from internlm.model.metrics import SchedulerMetricHook -from internlm.train import initialize_optimizer +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.model.model_ops.metrics import SchedulerMetricHook from internlm.utils.common import get_current_device internlm_accelerator = get_accelerator() @@ -155,8 +155,7 @@ def build_environment(rank, world_size, config): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "33333" internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl") def loose_close(a, b, dtype: torch.dtype = torch.float32): diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 7600b7637..cf4400c0f 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -6,21 +6,18 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc - -# from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.eval.evaluation import ( - switch_evaluation_mode, - switch_evaluation_pipeline_scheduler, -) -from internlm.train import load_new_batch +from internlm.eval import switch_evaluation_mode, switch_evaluation_pipeline_scheduler +from internlm.core.trainer import load_new_batch + +# from internlm.core.context import ParallelMode +from internlm.utils.config import Config -# from internlm.core.context.parallel_context import global_context as gpc +# from internlm.core.context import global_context as gpc from tests.test_core.utils import build_environment, init_model_and_optim micro_bszs = [1, 2] diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py index 14741b494..1d67b100d 100644 --- a/tests/test_infer/test_generate.py +++ b/tests/test_infer/test_generate.py @@ -5,8 +5,10 @@ from sentencepiece import SentencePieceProcessor from internlm.apis.inference import SequenceGenerator, batch_tokenize -from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.train import initialize_model_and_parallel_communicator +from internlm.initialize import initialize_launcher # noqa: E402 +from internlm.initialize.initialize_model import ( + initialize_model_and_parallel_communicator, +) def set_seed(seed: int = 1024): @@ -36,7 +38,7 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""): sequence_parallel=0, ), ) - 
initialize_distributed_env(evo_cfg, master_port=23574, args_check=False) + initialize_launcher(evo_cfg, distributed_port=23574, args_check=False) tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121 diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index c3149dda3..4b6e0967d 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -3,23 +3,23 @@ import pytest from sentencepiece import SentencePieceProcessor -import internlm # noqa: E402 from internlm.apis.inference import SequenceGenerator, batch_tokenize from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import build_train_loader_with_data_type # noqa: E402 -from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.model.losses import InternLoss # noqa: E402 -from internlm.train import ( # noqa: E402 - get_scheduler_hooks, +from internlm.initialize import initialize_launcher # noqa: E402 +from internlm.initialize.initialize_model import ( # noqa: E402 initialize_model_and_parallel_communicator, - initialize_optimizer, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss # noqa: E402 +from internlm.core.trainer import get_scheduler_hooks # noqa: E402 def setup_generator(config, tokenizer): - initialize_distributed_env(config=config) + initialize_launcher(config=config) model, isp_communicator = initialize_model_and_parallel_communicator() @@ -45,7 +45,7 @@ def setup_generator(config, tokenizer): ckpt_manager.try_resume_training(train_state) # initialize trainer - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_model/test_embedding.py b/tests/test_model/test_embedding.py index d8b58f552..b7252e376 100644 --- a/tests/test_model/test_embedding.py +++ b/tests/test_model/test_embedding.py @@ -3,7 +3,7 @@ import pytest import torch -from internlm.model.modules.embedding import Embedding1D +from internlm.model.model_ops.modules.embedding import Embedding1D from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port from tests.test_model.test_model_internlm import build_environment, seed_all diff --git a/tests/test_model/test_feed_forward.py b/tests/test_model/test_feed_forward.py index 311f30d7e..c55e55716 100644 --- a/tests/test_model/test_feed_forward.py +++ b/tests/test_model/test_feed_forward.py @@ -1,7 +1,7 @@ import pytest import torch -from internlm.model.modules.mlp import new_feed_forward, split_fused_mlp_weight +from internlm.model.model_ops.modules.mlp import new_feed_forward, split_fused_mlp_weight from internlm.utils.common import get_current_device SEQ_LEN = 64 diff --git a/tests/test_model/test_fused_precision/test_fused_precision.py b/tests/test_model/test_fused_precision/test_fused_precision.py index d0b79aaef..98d3511c3 100644 --- a/tests/test_model/test_fused_precision/test_fused_precision.py +++ b/tests/test_model/test_fused_precision/test_fused_precision.py @@ -6,9 +6,11 @@ from torch import nn from internlm.core.naive_amp import NaiveAMPModel, 
set_fp32_attr_to_module -from internlm.model.modeling_internlm import InternLM1Decoder -from internlm.train.pipeline import initialize_parallel_communicator -from internlm.train.utils import create_param_groups +from internlm.initialize.initialize_communicator import initialize_parallel_communicator +from internlm.model.model_implementations.transformers.modeling_internlm import ( + InternLM1Decoder, +) +from internlm.initialize.initialize_optimizer import create_param_groups from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port from tests.test_model.test_model_internlm import build_environment, seed_all diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index e2655d291..eaeff0ebf 100644 --- a/tests/test_model/test_model_internlm.py +++ b/tests/test_model/test_model_internlm.py @@ -6,25 +6,27 @@ import torch from torch import nn -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context.parallel_context import Config -from internlm.core.context.parallel_context import global_context as gpc -from internlm.core.parallel.comm.tensor import ( +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import ( HeadTensorParallelCommunicator, LinearRole, TensorParallelCommunicator, ) from internlm.core.parallel.comm.utils import gather_forward_split_backward -from internlm.model.modeling_internlm import InternLM1Decoder -from internlm.model.modules.linear import ( +from internlm.initialize import initialize_launcher +from internlm.model.model_implementations.transformers.modeling_internlm import ( + InternLM1Decoder, +) +from internlm.model.model_ops.modules.linear import ( ColumnParallelLinear, RowParallelLinear, ScaleColumnParallelLinear, new_linear, ) from internlm.utils.common import get_current_device +from internlm.utils.config import Config from tests.common_fixture import find_free_port internlm_accelerator = get_accelerator() @@ -83,8 +85,7 @@ def build_environment(rank, world_size, free_port): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = free_port internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl") def seed_all(seed, cuda_deterministic=False): diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py index 83861b365..2e32d8b1b 100644 --- a/tests/test_model/test_norm.py +++ b/tests/test_model/test_norm.py @@ -3,7 +3,7 @@ import pytest import torch -from internlm.model.modules.norm import new_layer_norm +from internlm.model.model_ops.modules.norm import new_layer_norm from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port from tests.test_model.test_model_internlm import build_environment, seed_all diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py index a2a8b91b8..96c11cde3 100644 --- a/tests/test_model/test_npu_ops/test_flash_attention.py +++ b/tests/test_model/test_npu_ops/test_flash_attention.py @@ -12,11 +12,14 @@ from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config from internlm.core.context import global_context as gpc -from internlm.model.ops.attention import 
SelfAttention -from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn +from internlm.model.model_ops.ops.attention import SelfAttention +from internlm.model.model_ops.ops.utils import ( + pack_output_after_attn, + unpack_qkv_before_attn, +) from internlm.utils.common import get_current_device, set_random_seed +from internlm.utils.config import Config HEAD_NUM = 32 HIDDEN_SZIE = 4096 @@ -139,7 +142,7 @@ def npu_transform(B, S, N_KV, dtype): def deeplink_fwd_transform(B, S, N_KV, dtype): from deeplink_ext.internevo_ops import FlashSelfAttention - from internlm.model.modules.multi_head_attention import CrossAttention + from internlm.model.model_ops.modules.multi_head_attention import CrossAttention set_random_seed(1024) softmax_scale = 1 / math.sqrt(HEAD_DIM) diff --git a/tests/test_model/test_npu_ops/test_npu_rmsnorm.py b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py index adeb37c00..74116655f 100644 --- a/tests/test_model/test_npu_ops/test_npu_rmsnorm.py +++ b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py @@ -2,8 +2,8 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.model.ops.norm import _RMSNorm as RMSNormTorch -from internlm.model.ops.norm import _RMSNormNPU as RMSNormNPU +from internlm.model.model_ops.ops.norm import _RMSNorm as RMSNormTorch +from internlm.model.model_ops.ops.norm import _RMSNormNPU as RMSNormNPU from internlm.utils.common import get_current_device internlm_accelerator = get_accelerator() diff --git a/tests/test_model/test_npu_ops/test_rotary_embed.py b/tests/test_model/test_npu_ops/test_rotary_embed.py index 8fca38ce2..71f5d4312 100644 --- a/tests/test_model/test_npu_ops/test_rotary_embed.py +++ b/tests/test_model/test_npu_ops/test_rotary_embed.py @@ -3,7 +3,7 @@ from torch import nn from internlm.accelerator import get_accelerator -from internlm.model.ops.rotary_emb import ( +from internlm.model.model_ops.ops.rotary_emb import ( ApplyRotaryEmb, rotary_emb_in_rotate_half_style, ) diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py index ca470ffc9..617035c76 100644 --- a/tests/test_solver/test_optimizer.py +++ b/tests/test_solver/test_optimizer.py @@ -9,12 +9,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close -import internlm from internlm.accelerator import get_accelerator -from internlm.core.context.parallel_context import Config, ParallelMode -from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler +from internlm.core.context import ParallelMode +from internlm.core.parallel.comm import ParamAsyncBcastHandler +from internlm.initialize import initialize_launcher from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device +from internlm.utils.config import Config internlm_accelerator = get_accelerator() @@ -95,8 +96,7 @@ def build_environment(rank, world_size): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "12345" internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl") def loose_close(a, b, dtype: torch.dtype = torch.float32): diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index ab81dbeed..3795f59bc 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ 
b/tests/test_training/test_forward_output_no_fa.py @@ -7,21 +7,21 @@ import pytest import torch -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import Config from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import InternLoss -from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import ( +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_model import ( initialize_model_and_parallel_communicator, - initialize_optimizer, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss +from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook from internlm.utils.common import get_current_device +from internlm.utils.config import Config from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -133,8 +133,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() - internlm.launch_from_torch(config=config, seed=1024) - args_sanity_check() + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl") def seed_all(seed, cuda_deterministic=False): @@ -198,7 +197,7 @@ def train_check_output(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index f9516c279..6e705ac11 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -1,6 +1,7 @@ import multiprocessing as mp from internlm.accelerator import get_accelerator +from internlm.initialize.initialize_optimizer import initialize_optimizer backup_ForkingPickler = mp.reduction.ForkingPickler backup_dump = mp.reduction.dump @@ -14,7 +15,6 @@ import torch # noqa: E402 #pylint: disable=wrong-import-position import torch.distributed as dist # noqa: E402 #pylint: disable=wrong-import-position -import internlm # noqa: E402 #pylint: disable=wrong-import-position from internlm.checkpoint import ( # noqa: E402 #pylint: disable=wrong-import-position CheckpointManager, ) @@ -24,35 +24,39 @@ from internlm.core.context import ( # noqa: E402 #pylint: disable=wrong-import-position global_context as gpc, ) -from internlm.core.context.parallel_context import ( # noqa: E402 #pylint: disable=wrong-import-position - Config, -) from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position - TrainState, Trainer, + TrainState, ) from internlm.data import ( # noqa: E402 #pylint: disable=wrong-import-position build_train_loader_with_data_type, ) -from internlm.initialize.launch import ( # noqa: E402 #pylint: disable=wrong-import-position - args_sanity_check, +from internlm.initialize import ( # noqa: E402 #pylint: disable=wrong-import-position + initialize_launcher +) +from internlm.initialize.initialize_model import ( # noqa: E402 #pylint: disable=wrong-import-position + 
initialize_model_and_parallel_communicator, ) -from internlm.model.losses import ( # noqa: E402 #pylint: disable=wrong-import-position +from internlm.initialize import ( # noqa: E402 #pylint: disable=wrong-import-position + initialize_trainer, +) +from internlm.model.model_ops.losses import ( # noqa: E402 #pylint: disable=wrong-import-position InternLoss, ) -from internlm.model.metrics import ( # noqa: E402 #pylint: disable=wrong-import-position +from internlm.model.model_ops.metrics import ( # noqa: E402 #pylint: disable=wrong-import-position AccPerplex, SchedulerMetricHook, ) -from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position - initialize_model_and_parallel_communicator, - initialize_optimizer, +from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position load_new_batch, ) from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position get_current_device, launch_time, ) +from internlm.utils.config import ( # noqa: E402 #pylint: disable=wrong-import-position + Config, +) from internlm.utils.logger import ( # noqa: E402 #pylint: disable=wrong-import-position get_logger, ) @@ -173,9 +177,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) - args_sanity_check() + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl") def seed_all(seed, cuda_deterministic=False): @@ -265,7 +267,7 @@ def train_model(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index cdee0b18b..fb67421b9 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -5,22 +5,22 @@ import torch import torch.distributed as dist -import internlm from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint import CheckpointManager -from internlm.core.context import Config, ParallelMode +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.trainer import Trainer, TrainState +from internlm.core.trainer import Trainer, TrainState, get_scheduler_hooks from internlm.data import build_train_loader_with_data_type -from internlm.initialize import initialize_distributed_env -from internlm.model.losses import InternLoss -from internlm.train import ( - get_scheduler_hooks, +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_model import ( initialize_model_and_parallel_communicator, - initialize_optimizer, - load_new_batch, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss +from internlm.core.trainer import load_new_batch from internlm.utils.common import BatchSkipper, launch_time +from internlm.utils.config import Config from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.megatron_timers import megatron_timer as timer @@ -129,7 +129,7 @@ def train( config.model.parallel_output = False config.model.checkpoint = True - initialize_distributed_env(config=config, launcher=launcher) + 
initialize_launcher(config=config, launcher=launcher) assert hasattr(gpc, "config") and gpc.config is not None # check parallel config @@ -200,7 +200,7 @@ def train( metric = None # initialize trainer - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index 0b0493bb2..e0715d67d 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -2,19 +2,19 @@ import pytest -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import InternLoss -from internlm.model.metrics import AccPerplex -from internlm.train import ( - get_scheduler_hooks, +from internlm.initialize.initialize_model import ( initialize_model_and_parallel_communicator, - initialize_optimizer, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss +from internlm.model.model_ops.metrics import AccPerplex +from internlm.core.trainer import get_scheduler_hooks from internlm.utils.logger import get_logger from tests.common_fixture import ( build_environment, @@ -50,7 +50,7 @@ def train_check(args): # set seed seed_all(1024) - # initialize model and isp communicator + # initialize model and isp communicator model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function @@ -67,7 +67,7 @@ def train_check(args): dataset_types=dataset_types, ) - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 1306da69b..c6c1be04e 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -5,19 +5,19 @@ import pytest import torch -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import InternLoss -from internlm.model.metrics import AccPerplex -from internlm.train import ( - get_scheduler_hooks, +from internlm.initialize.initialize_model import ( initialize_model_and_parallel_communicator, - initialize_optimizer, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss +from internlm.model.model_ops.metrics import AccPerplex +from internlm.core.trainer import get_scheduler_hooks from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from tests.common_fixture import ( @@ -87,7 +87,7 @@ def train_check_norm_weight(args): dataset_types=dataset_types, ) - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git 
a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 84b79d9f0..fe0ea0d5f 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -9,25 +9,25 @@ import torch.distributed as dist from tqdm import tqdm -import internlm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import Config from internlm.core.trainer import Trainer from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.eval.evaluation import switch_evaluation_mode -from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import InternLoss -from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import ( +from internlm.eval import switch_evaluation_mode +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_model import ( initialize_model_and_parallel_communicator, - initialize_optimizer, ) +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize import initialize_trainer +from internlm.model.model_ops.losses import InternLoss +from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook from internlm.utils.common import get_current_device +from internlm.utils.config import Config from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -136,9 +136,7 @@ def build_environment(rank, world_size, config): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "33333" internlm_accelerator.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) - args_sanity_check() + initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl") def seed_all(seed, cuda_deterministic=False): @@ -302,7 +300,7 @@ def exam_loss(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 623f0ccec..6b161ae30 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -12,11 +12,13 @@ import torch import torch.distributed as dist +from internlm.initialize.initialize_optimizer import initialize_optimizer +from internlm.initialize.initialize_profiler import initialize_llm_profile + script_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(script_dir, "../../")) sys.path.append(project_root) -import internlm # noqa: E402 from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 @@ -25,21 +27,21 @@ build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.eval.evaluation import evaluate_on_val_dls # noqa: E402 -from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.model.losses import InternLoss # noqa: E402 -from internlm.model.metrics import AccPerplex, SchedulerMetricHook # noqa: E402 -from internlm.monitor import ( # noqa: E402 - initialize_monitor_manager, - send_alert_message, -) -from 
internlm.monitor.monitor import monitor_manager as mm # noqa: E402 -from internlm.train import ( # noqa: E402 - initialize_llm_profile, +from internlm.eval import evaluate_on_val_dls # noqa: E402 +from internlm.initialize import initialize_launcher # noqa: E402 +from internlm.initialize.initialize_model import ( # noqa: E402 initialize_model_and_parallel_communicator, - initialize_optimizer, - record_current_batch_training_metrics, ) +from internlm.initialize import initialize_trainer # noqa: E402 +from internlm.model.model_ops.losses import InternLoss # noqa: E402 +from internlm.model.model_ops.metrics import ( # noqa: E402 + AccPerplex, + SchedulerMetricHook, +) +from internlm.monitor import initialize_monitor_manager # noqa: E402 +from internlm.monitor import monitor_manager as mm # noqa: E402 +from internlm.monitor import send_alert_message # noqa: E402 +from internlm.core.trainer import record_current_batch_training_metrics # noqa: E402 from internlm.utils.common import ( # noqa: E402 BatchSkipper, get_current_device, @@ -115,7 +117,7 @@ def main(args): current_time = objs[0] # initialize model - model , _ = initialize_model_and_parallel_communicator() + model, _ = initialize_model_and_parallel_communicator() with open(args.config, "r") as f: config_lines = f.readlines() @@ -180,7 +182,7 @@ def main(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, @@ -361,7 +363,7 @@ def main(args): hostname = socket.gethostname() # initialize distributed environment - initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None # initialize monitor manager context diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index d3405122d..0a3a00d59 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -6,13 +6,13 @@ import torch from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import Config from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.builder import create_model -from internlm.model.registry import register_model_initializer -from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer -from internlm.train.utils import create_param_groups +from internlm.model.model_implementations.builder import create_model +from internlm.model.model_implementations.registry import register_model_initializer +from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.initialize.initialize_optimizer import create_param_groups from internlm.utils.common import SingletonMeta +from internlm.utils.config import Config OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) OSS_IP = os.environ.get("OSS_IP", None) @@ -153,21 +153,21 @@ def reset_singletons(): def reset_seed(): - from internlm.core.context.random import _SEED_MANAGER + from internlm.core.context import _SEED_MANAGER _SEED_MANAGER.reset() @pytest.fixture(scope="module") def init_dist_and_model(rank=0, world_size=1): - from internlm.initialize import initialize_distributed_env + from internlm.initialize import initialize_launcher os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = 
"127.0.0.1" os.environ["MASTER_PORT"] = "12377" - initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + initialize_launcher(config=init_config, launcher="torch", distributed_port=12377, args_check=False) # setup print("set up", flush=True) diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 5fe8b3c49..83b645dbe 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -10,10 +10,10 @@ import torch.distributed as dist from internlm.checkpoint import CheckpointManager -from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState -from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import SingletonMeta +from internlm.utils.config import Config from internlm.utils.storage_manager import wait_async_upload_finish from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import ASYNC_TMP_FOLDER, @@ -28,38 +28,6 @@ # (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY) step_info_list = [(8, 4, 2), (3, 4, 2), (1, 6, 3)] ckpt_config_list = [ - # Old interface format - dict( - enable_save_ckpt=True, - save_ckpt_folder=BOTO_SAVE_PATH, - load_optimizer=True, - checkpoint_every=0, - async_upload=True, - async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None, - oss_snapshot_freq=0, - stop_file_path=None, - load_model_only_folder=None, - load_given_ckpt=False, - load_ckpt_folder=None, - is_old_api=True, - ), - # Old interface format - dict( - enable_save_ckpt=True, - save_ckpt_folder=LOCAL_SAVE_PATH, - load_optimizer=True, - checkpoint_every=0, - async_upload=False, - async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), - oss_snapshot_freq=0, - stop_file_path=None, - load_model_only_folder=None, - load_given_ckpt=False, - load_ckpt_folder=None, - is_old_api=True, - ), # New interface format dict( enable_save_ckpt=True, @@ -201,8 +169,8 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr @pytest.mark.parametrize("step_info", step_info_list) @pytest.mark.parametrize("ckpt_config", ckpt_config_list) def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import - from internlm.core.context import global_context as gpc from internlm.checkpoint.checkpoint_manager import CheckpointLoadMask + from internlm.core.context import global_context as gpc ckpt_config = Config(ckpt_config) total_step, checkpoint_every, oss_snapshot_freq = step_info @@ -297,9 +265,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: def query_quit_file(rank, world_size=2): - from internlm.core.context import global_context as gpc - from internlm.initialize import initialize_distributed_env from internlm.checkpoint.checkpoint_manager import CheckpointSaveType + from internlm.core.context import global_context as gpc + from internlm.initialize import initialize_launcher ckpt_config = Config( dict( @@ -325,7 +293,7 @@ def query_quit_file(rank, world_size=2): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "12376" - initialize_distributed_env(config=init_config, launcher="torch", master_port=12376, args_check=False) + initialize_launcher(config=init_config, launcher="torch", 
distributed_port=12376, args_check=False) train_state = TrainState(init_config, None) ckpt_mm = CheckpointManager(ckpt_config, model=None, optimizer=None) if rank == 0: diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py index 9454a8369..57021e5ce 100644 --- a/tests/test_utils/test_storage_manager.py +++ b/tests/test_utils/test_storage_manager.py @@ -3,8 +3,7 @@ import pytest import torch -from internlm.core.context.parallel_context import Config -from internlm.initialize.launch import get_config_value +from internlm.utils.config import Config, get_config_value from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import ALI_SAVE_PATH, BOTO_SAVE_PATH, diff --git a/tests/test_utils/test_timeout.py b/tests/test_utils/test_timeout.py index 49a49d27e..4f9cd47a8 100644 --- a/tests/test_utils/test_timeout.py +++ b/tests/test_utils/test_timeout.py @@ -65,14 +65,14 @@ def local_timeout(rank, _): def gpc_timeout(rank, world_size): - from internlm.initialize import initialize_distributed_env + from internlm.initialize import initialize_launcher os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "12377" - initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False) + initialize_launcher(config=init_config, launcher="torch", distributed_port=12377, args_check=False) try: nccl_timeout_func(rank) diff --git a/tools/load_internlm2_model.py b/tools/load_internlm2_model.py index 4b639003e..aa3dcb636 100644 --- a/tools/load_internlm2_model.py +++ b/tools/load_internlm2_model.py @@ -10,8 +10,10 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.initialize.launch import initialize_distributed_env -from internlm.train import initialize_model_and_parallel_communicator +from internlm.initialize import initialize_launcher +from internlm.initialize.initialize_model import ( + initialize_model_and_parallel_communicator, +) from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -180,7 +182,7 @@ def initialize_internlm_model( if gpc.is_rank_for_log(): logger.info(f"model_config: {model_config}.") - initialize_distributed_env( + initialize_launcher( config=dict( model_type=model_type, model=model_config, @@ -193,7 +195,7 @@ def initialize_internlm_model( ), launcher="torch" if use_torchrun_starter() else "slurm", seed=seed, - master_port=23574, + distributed_port=23574, args_check=False, ) # Directly get the origin model without NativeAMP wrapper. 
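
All of the test and tool hunks above apply the same call-site migration: the old pair of `internlm.launch_from_torch(...)` / `initialize_distributed_env(..., master_port=...)` plus a separate `args_sanity_check()` call is replaced by the single `initialize_launcher(...)` entry point, and `initialize_trainer` is now imported from `internlm.initialize` rather than invoked as `internlm.initialize_trainer`. A minimal sketch of the new entry point follows, assuming only the keyword names that appear in these diffs; the config path and port values are illustrative, not part of the patch:

    from internlm.core.context import global_context as gpc
    from internlm.initialize import initialize_launcher

    # single entry point: distributed launch + seeding + optional config sanity check
    initialize_launcher(
        config="configs/7B_sft.py",   # dict or config-file path, as in the updated tests
        launcher="torch",             # "torch" (torchrun, MASTER_ADDR/MASTER_PORT env) or "slurm"
        distributed_port=12377,       # replaces the former master_port argument
        seed=1024,
        args_check=True,              # replaces the separate args_sanity_check() call
        dist_backend="nccl",
    )
    assert hasattr(gpc, "config") and gpc.config is not None
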
diff --git a/tools/moe_group_ckpt_converter.py b/tools/moe_group_ckpt_converter.py index d3fefb7c7..e07d6a273 100644 --- a/tools/moe_group_ckpt_converter.py +++ b/tools/moe_group_ckpt_converter.py @@ -8,7 +8,6 @@ from tqdm import tqdm sys.path.append(".") -import internlm # noqa: E402,F401 # pylint: disable=W0611,C0413 moe_str_prefix = None weight_key_suffix = ".weight" diff --git a/version.txt b/version.txt index be14282b7..c52db9804 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.5.3 +0.5.3 \ No newline at end of file From 2a1781793d9ae6612718f4d727d98a09ac685a5a Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 24 Feb 2025 16:45:28 +0800 Subject: [PATCH 03/32] fix ci --- internlm/checkpoint/utils.py | 64 ------------------------------------ 1 file changed, 64 deletions(-) diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py index 0b81f2ef2..72036cac6 100644 --- a/internlm/checkpoint/utils.py +++ b/internlm/checkpoint/utils.py @@ -47,67 +47,3 @@ def process_load_info(load_info): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") return load_content_str, load_ckpt_folder, load_content - - -def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP: - """ - Initialize Fully Sharded Data Parallel (FSDP) for the model. - This function is needed to properly initialize FSDP when resuming from a checkpoint. - It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. - - References: - https://github.com/pytorch/pytorch/issues/113496 - https://github.com/huggingface/transformers/pull/34032 - https://github.com/huggingface/transformers/issues/31892 - - Args: - model: The model to initialize with FSDP. - device: The device to run the model on. - - Returns: - The initialized FSDP model. - """ - model.train() - with torch.no_grad(): - # generate dummy packed sequence - seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz - input_ids = [1] * seq_len - label = input_ids[1:] + [-100] - cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len)) - - input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) - label = torch.tensor(label, device=device).unsqueeze(0) - indexes = torch.tensor( - list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])), - device=device, - ).unsqueeze(0) - cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0) - - data = { - "input_ids": input_ids, - "cu_seqlens": cu_seqlens, - "indexes": indexes, - "max_seqlen": seq_len, - } - - data_fns = [] - - # default data process function - if gpc.config.data.use_packed_dataset: - data_fns.append(packed_data_normalizer) - else: - data_fns.append(unpack_data) - - # support sequence parallel for isp - if is_using_isp(): - data_fns.append(split_data_for_sequence_parallel) - - # generate dummy_input - _data, _label = data, label - for fn in data_fns: - _data, _label = fn(_data, _label) - dummy_input = _data - - # run a forward pass with dummy_input to initialize FSDP - _ = model(**dummy_input) - return model From 9ef40eecd566530c0d5cb4f1a0ebed154b4513a5 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 24 Feb 2025 16:46:50 +0800 Subject: [PATCH 04/32] fix ci --- internlm/checkpoint/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py index 72036cac6..401bd54ec 100644 --- a/internlm/checkpoint/utils.py +++ b/internlm/checkpoint/utils.py @@ -2,10 +2,7 @@ # -*- encoding: utf-8 -*- from internlm.core.context 
import global_context as gpc -from internlm.core.parallel.shard import split_data_for_sequence_parallel -from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp logger = get_logger(__file__) From 99b3555336f0b38cbc05fdcbc778d6bfcfd67dde Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Mon, 24 Feb 2025 17:18:11 +0800 Subject: [PATCH 05/32] ljx adapt npu --- internlm/core/context/parallel_context.py | 7 +++++-- internlm/core/trainer_builder.py | 5 ++++- setup.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index fbe2b6247..c68fd9e99 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist -from internlm.accelerator import get_accelerator +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.utils.common import SingletonMeta from internlm.utils.config import Config from internlm.utils.logger import get_logger @@ -309,7 +309,10 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, use_cpu (bool): whether to set up cpu process group. """ # initialize the default process group - init_method = f"tcp://[{host}]:{port}" + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + init_method = f"tcp://[{host}]:{port}" + else: + init_method = f"tcp://{host}:{port}" dist.init_process_group( rank=rank, world_size=world_size, diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 532da9494..cec7137e0 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -8,6 +8,7 @@ import torch.distributed as dist from torch.utils.data import DataLoader +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -48,6 +49,7 @@ # global llm logger logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() class TrainerBuilder(Trainer): @@ -114,7 +116,8 @@ def __init__( criterion = self._initialize_criterion() # initialize cpu offload manager for selective checkpoint - initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) # initialize train state train_state = get_train_state(train_dl) diff --git a/setup.py b/setup.py index 2dbe90f91..5ba3feb38 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def get_requires() -> List[str]: extra_require = { "torch": ["torch>=2.1.0"], - "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "numpy==1.26.4", "scipy", "decorator"], } setup( From 55e605570216a52c1b957d5620d7931398fe24c0 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 24 Feb 2025 18:55:44 +0800 Subject: [PATCH 06/32] update npu adapt --- internlm/core/context/parallel_context.py | 7 ++----- internlm/core/parallel/comm/attn_offload.py | 5 ++++- internlm/core/trainer_builder.py | 5 +---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/internlm/core/context/parallel_context.py 
b/internlm/core/context/parallel_context.py index c68fd9e99..7e83129c8 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist -from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.accelerator import get_accelerator from internlm.utils.common import SingletonMeta from internlm.utils.config import Config from internlm.utils.logger import get_logger @@ -309,10 +309,7 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, use_cpu (bool): whether to set up cpu process group. """ # initialize the default process group - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - init_method = f"tcp://[{host}]:{port}" - else: - init_method = f"tcp://{host}:{port}" + init_method = f"tcp://{host}:{port}" dist.init_process_group( rank=rank, world_size=world_size, diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py index da23f3ae8..02f1cd15d 100644 --- a/internlm/core/parallel/comm/attn_offload.py +++ b/internlm/core/parallel/comm/attn_offload.py @@ -1,8 +1,10 @@ import torch +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.utils.common import get_current_device global_attn_offload = None +internlm_accelerator = get_accelerator() class AttnOffloadManager: @@ -117,7 +119,8 @@ def preload_fa_output_with_layer(self, layer_idx): def initialize_offload_manager(enable_cpu_offload: bool = False): global global_attn_offload if global_attn_offload is None: - global_attn_offload = AttnOffloadManager(enable_cpu_offload) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + global_attn_offload = AttnOffloadManager(enable_cpu_offload) return global_attn_offload diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index cec7137e0..532da9494 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -8,7 +8,6 @@ import torch.distributed as dist from torch.utils.data import DataLoader -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -49,7 +48,6 @@ # global llm logger logger = logging.getLogger(__file__) -internlm_accelerator = get_accelerator() class TrainerBuilder(Trainer): @@ -116,8 +114,7 @@ def __init__( criterion = self._initialize_criterion() # initialize cpu offload manager for selective checkpoint - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) # initialize train state train_state = get_train_state(train_dl) From 016712979d3ca20b44ebd24634edd6aa1899caad Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 25 Feb 2025 11:01:59 +0800 Subject: [PATCH 07/32] fix pylint --- .../model/model_implementations/transformers/modeling_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/internlm/model/model_implementations/transformers/modeling_moe.py b/internlm/model/model_implementations/transformers/modeling_moe.py index d58aa9fa3..54fc1cb5c 100644 --- a/internlm/model/model_implementations/transformers/modeling_moe.py +++ b/internlm/model/model_implementations/transformers/modeling_moe.py @@ 
-9,7 +9,6 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context from internlm.model.model_implementations.transformers.base_model import ( BaseTransformerModel, From 73c3ce1a3c863d27e5a04c6ae51fe2b2f7fe16dd Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 25 Feb 2025 11:12:39 +0800 Subject: [PATCH 08/32] update setup --- requirements/runtime.txt | 2 ++ requirements/torch.txt | 4 ---- setup.py | 8 +++++++- 3 files changed, 9 insertions(+), 5 deletions(-) delete mode 100644 requirements/torch.txt diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 621bc74b3..419fa22cb 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,6 +2,8 @@ transformers<4.47.0 sentencepiece datasets numpy +scipy +decorator tqdm einops boto3 diff --git a/requirements/torch.txt b/requirements/torch.txt deleted file mode 100644 index c9a04b3d8..000000000 --- a/requirements/torch.txt +++ /dev/null @@ -1,4 +0,0 @@ ---extra-index-url https://download.pytorch.org/whl/cu118 -torch==2.1.0+cu118 -torchvision==0.16.0+cu118 -torchaudio==2.1.0+cu118 diff --git a/setup.py b/setup.py index 5ba3feb38..673b86577 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import sys from typing import List from setuptools import find_packages, setup @@ -27,9 +28,14 @@ def get_requires() -> List[str]: extra_require = { "torch": ["torch>=2.1.0"], - "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "numpy==1.26.4", "scipy", "decorator"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "numpy==1.26.4"], } + +if sys.platform.startswith("linux"): + extra_require["torch"].append("flash-attn>=2.6.3") + + setup( name="InternEvo", version=get_version(), From 70865cee747d835909e118fb26ff03f686c83780 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 3 Mar 2025 11:26:16 +0800 Subject: [PATCH 09/32] remove unused settings and temporarily remove other model_implementations --- ci_scripts/train/ci_7B_sft.py | 2 - configs/1.8B_MoE16_sft.py | 2 - configs/57B_qwen2_MoE.py | 226 ------ configs/7B_MoE4_sft.py | 2 - configs/7B_baichuan2.py | 223 ------ configs/7B_gemma.py | 230 ------ configs/7B_internlm2.py | 2 - configs/7B_isp_sft.py | 2 - configs/7B_llama2.py | 4 +- configs/7B_qwen2.py | 230 ------ configs/8B_internlm3.py | 2 - configs/8x22B_mixtral.py | 227 ------ configs/8x7B_mixtral.py | 227 ------ configs/_base_/models/internlm2_1B.py | 2 - configs/_base_/models/internlm2_20B.py | 2 - configs/_base_/models/internlm2_7B.py | 2 - configs/_base_/models/internlm_20B.py | 2 - configs/_base_/models/internlm_7B.py | 2 - configs/demo_llava.py | 191 ----- doc/code-docs/source/example/20B_demo.rst | 2 - doc/code-docs/source/example/7B_demo.rst | 2 - doc/code-docs/source/initialize.rst | 2 - doc/code-docs/source/mixed_precision.rst | 2 - doc/en/usage.md | 6 - doc/usage.md | 4 - internlm/checkpoint/load_funcs.py | 4 - internlm/initialize/initialize_model.py | 8 +- .../model/model_implementations/builder.py | 2 - .../model/model_implementations/registry.py | 18 - .../transformers/modeling_baichuan2.py | 639 --------------- .../transformers/modeling_gemma.py | 752 ------------------ .../transformers/modeling_internlm.py | 118 --- .../transformers/modeling_llava.py | 248 ------ .../transformers/modeling_mixtral.py | 434 ---------- .../transformers/modeling_qwen2.py | 752 ------------------ 
.../transformers/modeling_qwen2_moe.py | 561 ------------- internlm/model/model_ops/llava/__init__.py | 0 .../model/model_ops/llava/clip_builder.py | 13 - .../model/model_ops/llava/clip_encoder.py | 82 -- .../model_ops/llava/projector_builder.py | 48 -- tests/common_fixture.py | 2 - tests/test_infer/test_generate.py | 1 - tests/test_model/test_model_internlm.py | 2 - tests/test_training/7B_check_acc.py | 2 - tests/test_training/7B_check_init.py | 2 - .../test_forward_output_no_fa.py | 2 - tests/test_training/test_load_ckpt_loss.py | 2 - .../test_swap_nb_loss_and_gradnorm.py | 2 - tests/test_utils/common_fixture.py | 2 - tools/README.md | 2 - tools/load_internlm2_model.py | 2 - web_demo_internlm.py | 4 - 52 files changed, 2 insertions(+), 5300 deletions(-) delete mode 100644 configs/57B_qwen2_MoE.py delete mode 100644 configs/7B_baichuan2.py delete mode 100644 configs/7B_gemma.py delete mode 100644 configs/7B_qwen2.py delete mode 100644 configs/8x22B_mixtral.py delete mode 100644 configs/8x7B_mixtral.py delete mode 100644 configs/demo_llava.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_baichuan2.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_gemma.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_llava.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_mixtral.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_qwen2.py delete mode 100644 internlm/model/model_implementations/transformers/modeling_qwen2_moe.py delete mode 100644 internlm/model/model_ops/llava/__init__.py delete mode 100644 internlm/model/model_ops/llava/clip_builder.py delete mode 100644 internlm/model/model_ops/llava/clip_encoder.py delete mode 100644 internlm/model/model_ops/llava/projector_builder.py diff --git a/ci_scripts/train/ci_7B_sft.py b/ci_scripts/train/ci_7B_sft.py index fea45e124..591faf36c 100644 --- a/ci_scripts/train/ci_7B_sft.py +++ b/ci_scripts/train/ci_7B_sft.py @@ -101,14 +101,12 @@ model = dict( checkpoint=False, num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py index eca10b045..a8a58dc6f 100644 --- a/configs/1.8B_MoE16_sft.py +++ b/configs/1.8B_MoE16_sft.py @@ -136,14 +136,12 @@ model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=False, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py deleted file mode 100644 index 27f63cc1d..000000000 --- a/configs/57B_qwen2_MoE.py +++ /dev/null @@ -1,226 +0,0 @@ -JOB_NAME = "57b_qwen2_moe" -model_type = "QWEN2MOE" -DO_ALERT = False - -SEQ_LEN = 4096 -HIDDEN_SIZE = 3584 -NUM_ATTENTION_HEAD = 28 -NUM_KV_ATTENTION_HEAD = 4 -MLP_RATIO = 5 / 7 -NUM_LAYER = 28 -VOCAB_SIZE = 151936 - -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" -# Ckpt 
folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=True, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None # "/path/to/dataset" -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=2, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=50, - pack_sample_into_one=False, - total_steps=50000, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=False, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, - moe_loss_coeff=0.001, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] - num_attention_heads=NUM_ATTENTION_HEAD, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - max_position_embeddings=131072, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - norm_type="rmsnorm", - layer_norm_epsilon=1e-6, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. - # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, - use_sliding_window=False, - rope_base=1000000, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. - moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless" - num_experts=64, - num_shared_experts=8, - top_k=8, -) -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. 
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. - fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. - isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -expert parallel (dict): - 1. size: int - * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size - to be the number of experts to make sure each device has one expert. - * if size == 1, all experts are placed in each device, running as dp-only. - * if size > 1, all experts are placed in k devices and each device has n/k experts, where n is the total - number of experts and k = size. -expert weight parallel (dict): - 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), - expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) \ No newline at end of file diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 74ebbcbb6..4b494d9f5 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -149,14 +149,12 @@ model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py deleted file mode 100644 index 9957d6819..000000000 --- a/configs/7B_baichuan2.py +++ /dev/null @@ -1,223 +0,0 @@ -JOB_NAME = "7b_baichuan2_train" -model_type = "BAICHUAN2" -DO_ALERT = False - -VOCAB_SIZE = 125696 -SEQ_LEN = 2048 -HIDDEN_SIZE = 4096 -NUM_ATTENTION_HEAD = 32 -MLP_RATIO = 8 / 3 -NUM_LAYER = 32 - - -MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx" -# Ckpt folder format: -# 
fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=False, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=1, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=0, - pack_sample_into_one=False, - total_steps=20, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=True, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, - num_chunks=1, - num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - no_bias=True, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", - norm_type="rmsnorm", - layer_norm_epsilon=1e-6, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. - # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, -) - -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. - fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. 
- isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) - -# metric_dtype can be "fp32" or other string -# only when set to "fp32" will use fp32 to calc in metrics -# metric_dtype = "fp32" - -generation = dict( - ckpt_folder="/path/to/saved/ckpt", - output_folder="/path/to/save/generation", - batch_size=1, - eos_id=[2, 0], - bos_id=1, - max_length=100, - do_sample=True, - temperature=1.0, - top_k=50, - top_p=1.0, - repetition_penalty=1, - length_penalty=1.0, -) diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py deleted file mode 100644 index 643bcbdbf..000000000 --- a/configs/7B_gemma.py +++ /dev/null @@ -1,230 +0,0 @@ -JOB_NAME = "7b_gemma_train" -model_type = "GEMMA" -DO_ALERT = False - -VOCAB_SIZE = 256000 -SEQ_LEN = 2048 -HIDDEN_SIZE = 3072 -NUM_ATTENTION_HEAD = 16 -NUM_KV_ATTENTION_HEAD = 16 -HEAD_DIM = 256 -MLP_RATIO = 8 -NUM_LAYER = 28 - - -MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/xxxx" -# Ckpt folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. 
- auto_resume=False, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=1, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=0, - pack_sample_into_one=False, - total_steps=20, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. - rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=True, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, - num_chunks=1, - num_attention_heads=NUM_ATTENTION_HEAD, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - max_position_embeddings=8192, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - no_bias=True, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", - add_unit_offset=True, - norm_type="rmsnorm", - layer_norm_epsilon=1e-6, - head_dim=HEAD_DIM, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. 
- # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, - use_swiglu=False, -) - -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. - fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. - isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. 
-""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) - -# metric_dtype can be "fp32" or other string -# only when set to "fp32" will use fp32 to calc in metrics -# metric_dtype = "fp32" - -generation = dict( - ckpt_folder="/path/to/saved/ckpt", - output_folder="/path/to/save/generation", - batch_size=1, - eos_id=[2, 0], - bos_id=1, - max_length=100, - do_sample=True, - temperature=1.0, - top_k=50, - top_p=1.0, - repetition_penalty=1, - length_penalty=1.0, -) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 3c7bb9f4f..2126d7470 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -142,7 +142,6 @@ checkpoint=False, num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -150,7 +149,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 9b53d50e4..158e6868e 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -163,7 +163,6 @@ checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -171,7 +170,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py index 7783abaf7..0161d78e0 100644 --- a/configs/7B_llama2.py +++ b/configs/7B_llama2.py @@ -130,7 +130,6 @@ checkpoint=False, num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -138,7 +137,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, @@ -152,7 +150,7 @@ # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] 
qk_interleaved=False, - mlp_layer_fusion=True, + mlp_layer_fusion=False, enable_qkv_fusion=True, ) diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py deleted file mode 100644 index 3622e12f1..000000000 --- a/configs/7B_qwen2.py +++ /dev/null @@ -1,230 +0,0 @@ -JOB_NAME = "7b_qwen2_train" -model_type = "QWEN2" -DO_ALERT = False - -VOCAB_SIZE = 152064 -SEQ_LEN = 2048 -HIDDEN_SIZE = 3584 -NUM_ATTENTION_HEAD = 28 -NUM_KV_ATTENTION_HEAD = 4 -MLP_RATIO = 5.25 -NUM_LAYER = 28 - - -MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/" -# Ckpt folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=False, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=1, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=0, - pack_sample_into_one=False, - total_steps=20, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
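The rampup_batch_size string described in the data comments ("start increment interval", e.g. "192 24 8") can be read with the small helper below. It is a hedged sketch with a hypothetical function name, not the scheduler InternEvo actually uses:

```python
# "192 24 8": micro_num starts at 192 and grows by 24 every 8 steps.
def rampup_micro_num(rampup: str, step: int) -> int:
    start, increment, interval = map(int, rampup.split())
    return start + (step // interval) * increment

assert rampup_micro_num("192 24 8", 0) == 192
assert rampup_micro_num("192 24 8", 16) == 240
```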
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=True, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, - num_chunks=1, - num_attention_heads=NUM_ATTENTION_HEAD, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - qkv_bias=True, - o_bias=False, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", - norm_type="rmsnorm", - layer_norm_epsilon=1e-6, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. - # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, - rope_base=1000000, - use_sliding_window=False, - sliding_window=32768, - max_window_layers=28, -) - -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. 
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. - isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) - -# metric_dtype can be "fp32" or other string -# only when set to "fp32" will use fp32 to calc in metrics -# metric_dtype = "fp32" - -generation = dict( - ckpt_folder="/path/to/saved/ckpt", - output_folder="/path/to/save/generation", - batch_size=1, - eos_id=[2, 0], - bos_id=1, - max_length=100, - do_sample=True, - temperature=1.0, - top_k=50, - top_p=1.0, - repetition_penalty=1, - length_penalty=1.0, -) diff --git a/configs/8B_internlm3.py b/configs/8B_internlm3.py index acb04d446..9f5840c05 100644 --- a/configs/8B_internlm3.py +++ b/configs/8B_internlm3.py @@ -153,7 +153,6 @@ num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -161,7 +160,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py deleted file mode 100644 index f1f1b6e60..000000000 --- a/configs/8x22B_mixtral.py +++ /dev/null @@ -1,227 +0,0 @@ -JOB_NAME = "22b_moe_mixtral" -model_type = "MIXTRALMOE" -DO_ALERT = False - -SEQ_LEN = 4096 -HIDDEN_SIZE = 6144 -NUM_ATTENTION_HEAD = 48 -NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 8 / 3 -NUM_LAYER = 56 -VOCAB_SIZE = 32000 - -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" -# Ckpt folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. 
the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=True, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None # "/path/to/dataset" -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=2, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=50, - pack_sample_into_one=False, - total_steps=50000, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
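The interaction between auto_resume and load_ckpt_info documented in the ckpt comments above can be summarized roughly as below. This is a hedged sketch of the described behaviour, not InternEvo's checkpoint manager: latest_checkpoint() is a hypothetical helper, the "local:" prefix handling is omitted, and the branch taken when no saved checkpoint exists is an assumption:

```python
import os
from typing import Optional


def latest_checkpoint(folder: str) -> Optional[str]:
    """Hypothetical helper: newest numeric step directory under folder, or None."""
    steps = [d for d in (os.listdir(folder) if os.path.isdir(folder) else []) if d.isdigit()]
    return os.path.join(folder, max(steps, key=int)) if steps else None


def choose_ckpt_to_load(auto_resume: bool, save_ckpt_folder: str, load_ckpt_info):
    if auto_resume:
        latest = latest_checkpoint(save_ckpt_folder)
        if latest is not None:        # resume the full training state from the newest save
            return dict(path=latest, content=("all",), ckpt_type="internevo")
        return None                   # assumption: nothing saved yet, start without loading
    # auto_resume=False: honor load_ckpt_info (initialize from another model),
    # or train from scratch when it is None.
    return load_ckpt_info
```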
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=False, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, - moe_loss_coeff=0.001, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] - num_attention_heads=NUM_ATTENTION_HEAD, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - max_position_embeddings=65536, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - qkv_bias=False, - o_bias=False, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - norm_type="rmsnorm", - layer_norm_epsilon=1e-5, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. - # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, - use_sliding_window=False, - rope_base=1000000, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. - moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless" - num_experts=8, - top_k=2, -) -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. 
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. - fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. - isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -expert parallel (dict): - 1. size: int - * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size - to be the number of experts to make sure each device has one expert. - * if size == 1, all experts are placed in each device, running as dp-only. - * if size > 1, all experts are placed in k devices and each device has n/k experts, where n is the total - number of experts and k = size. -expert weight parallel (dict): - 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), - expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py deleted file mode 100644 index 6db43f9c6..000000000 --- a/configs/8x7B_mixtral.py +++ /dev/null @@ -1,227 +0,0 @@ -JOB_NAME = "7b_moe_mixtral" -model_type = "MIXTRALMOE" -DO_ALERT = False - -SEQ_LEN = 4096 -HIDDEN_SIZE = 4096 -NUM_ATTENTION_HEAD = 32 -NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 3.5 -NUM_LAYER = 32 -VOCAB_SIZE = 32000 - -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" -# Ckpt folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", - # 'load_ckpt_info' setting guide: - # 1. 
the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=True, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None # "/path/to/dataset" -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - # packed_length = micro_bsz * SEQ_LEN - micro_bsz=2, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=50, - pack_sample_into_one=False, - total_steps=50000, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
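As a quick sanity check of the data settings above (the comments state packed_length = micro_bsz * SEQ_LEN, with micro_num micro-batches per optimizer step), the per-step token count works out as below. The dp_world_size value is hypothetical and the global-throughput formula is an assumption for illustration only:

```python
SEQ_LEN, micro_bsz, micro_num = 4096, 2, 4
dp_world_size = 8                                   # hypothetical data-parallel size

packed_length = micro_bsz * SEQ_LEN                 # tokens in one packed micro-batch
tokens_per_rank_per_step = micro_num * packed_length
tokens_per_step = tokens_per_rank_per_step * dp_world_size

assert packed_length == 8192
assert tokens_per_step == 262_144
```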
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=False, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, - moe_loss_coeff=0.02, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] - num_attention_heads=NUM_ATTENTION_HEAD, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - max_position_embeddings=32768, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - qkv_bias=False, - o_bias=False, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - norm_type="rmsnorm", - layer_norm_epsilon=1e-5, - use_flash_attn=True, - # Whether the odd and even columns of the query and key in the model are normally interleaved. - # If it's True, the model's odd and even columns are normally ordered; if it's False, - # it means that the model has prematurely concatenated all odd columns and even columns in front - # and back, in order to improve the RoPE's computational efficiency. - # Example: - # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] - # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] - qk_interleaved=False, - use_sliding_window=False, - rope_base=1000000, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. - moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless" - num_experts=8, - top_k=2, -) -""" -zero1 parallel (dict): - 1. size: int - * if size <= 0, the size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. - * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. 
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -tensor parallel (dict): - 1. size: int, the size of tensor parallel. - 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], - defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. - msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. - fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. - isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, - defaults to False. -weight parallel (dict): - 1. size: int, the size of weight parallel. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -expert parallel (dict): - 1. size: int - * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size - to be the number of experts to make sure each device has one expert. - * if size == 1, all experts are placed in each device, running as dp-only. - * if size > 1, all experts are placed in k devices and each device has n/k experts, where n is the total - number of experts and k = size. -expert weight parallel (dict): - 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. - 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. -""" -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), - expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index cc3f186ad..f4cfef8aa 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -14,7 +14,6 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. checkpoint=0.2, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - embed_split_hidden=True, num_layers=NUM_LAYER, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE, @@ -26,7 +25,6 @@ multiple_of=MULTIPLE_OF, norm_type="rmsnorm", qk_interleaved=False, - apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, rope_base=1000000, diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index dc461c0da..f0fea954e 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -13,7 +13,6 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
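The expert-parallel sizing rule quoted in the docstring above (ep size defaults to the dp size, capped at the number of experts) is essentially a one-liner. Hypothetical helper for illustration only, not InternEvo code:

```python
def resolve_expert_parallel_size(size: int, dp_size: int, num_experts: int) -> int:
    if size <= 0:
        # default to the dp size, but never exceed the number of experts
        return min(dp_size, num_experts)
    # size == 1 keeps all experts on every rank (dp-only);
    # size > 1 spreads num_experts over `size` ranks, num_experts // size experts each.
    return size

assert resolve_expert_parallel_size(-1, dp_size=16, num_experts=8) == 8
assert resolve_expert_parallel_size(1, dp_size=16, num_experts=8) == 1
```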
checkpoint=1.0, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - embed_split_hidden=True, num_layers=NUM_LAYER, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE, @@ -24,7 +23,6 @@ mlp_ratio=MLP_RATIO, norm_type="rmsnorm", qk_interleaved=False, - apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, rope_base=1000000, diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index cbdb03cb1..06b27693b 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -13,7 +13,6 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. checkpoint=0.2, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - embed_split_hidden=True, num_layers=NUM_LAYER, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE, @@ -24,7 +23,6 @@ mlp_ratio=MLP_RATIO, norm_type="rmsnorm", qk_interleaved=True, - apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, rope_base=1000000, diff --git a/configs/_base_/models/internlm_20B.py b/configs/_base_/models/internlm_20B.py index 26f4ff7f8..2f7ff0c8c 100644 --- a/configs/_base_/models/internlm_20B.py +++ b/configs/_base_/models/internlm_20B.py @@ -12,7 +12,6 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - embed_split_hidden=True, num_layers=NUM_LAYER, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE, @@ -21,7 +20,6 @@ num_attention_heads=NUM_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - apply_post_layer_norm=False, layer_norm_epsilon=1e-5, ) diff --git a/configs/_base_/models/internlm_7B.py b/configs/_base_/models/internlm_7B.py index 8dde6e4e4..4b63c7ded 100644 --- a/configs/_base_/models/internlm_7B.py +++ b/configs/_base_/models/internlm_7B.py @@ -12,7 +12,6 @@ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
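The fractional checkpoint setting used throughout these model configs (True/False/[0-1]) selects how many leading layers use activation recomputation, consistent with the int(num_layers * checkpoint) computation that appears in the modeling code later in this patch. The layer count below is just an example value:

```python
num_layers = 48            # example value, not tied to a specific config
checkpoint = 0.2           # recompute activations for ~20% of the layers
checkpoint_layer_num = int(num_layers * checkpoint)

assert checkpoint_layer_num == 9   # the first 9 layers run with activation checkpointing
```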
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" - embed_split_hidden=True, num_layers=NUM_LAYER, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE, @@ -21,7 +20,6 @@ num_attention_heads=NUM_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - apply_post_layer_norm=False, layer_norm_epsilon=1e-5, ) diff --git a/configs/demo_llava.py b/configs/demo_llava.py deleted file mode 100644 index e138e886a..000000000 --- a/configs/demo_llava.py +++ /dev/null @@ -1,191 +0,0 @@ -JOB_NAME = "llava_train" -model_type = "LLAVA" -DO_ALERT = False - -VOCAB_SIZE = 32000 -SEQ_LEN = 2048 -HIDDEN_SIZE = 4096 -NUM_ATTENTION_HEAD = 32 -NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 3.5 -NUM_LAYER = 32 - - -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" -# Ckpt folder format: -# fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" - -# boto3 Ckpt folder format: -# import os -# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint -# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" -CHECKPOINT_EVERY = 50 -ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. - save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering - # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) - # with an automatic restart mechanism upon training reboot. - # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint - # path specified in `load_ckpt_info` by default. - # If you want to initialize your model weights from another model, you must set `auto_resume` to False. - # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=False, - checkpoint_every=CHECKPOINT_EVERY, - async_upload=True, # async ckpt upload. (only work for boto3 ckpt) - async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. - oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. -) - -TRAIN_FOLDER = None -VALID_FOLDER = None # "/path/to/dataset" -data = dict( - is_multimodal=True, - seq_len=SEQ_LEN, - # micro_num means the number of micro_batch contained in one gradient update - micro_num=4, - packed_length=SEQ_LEN, - micro_bsz=1, - # defaults to the value of micro_num - valid_micro_num=4, - # defaults to 0, means disable evaluate - valid_every=0, - pack_sample_into_one=False, - total_steps=200, - skip_batches="", - # rampup_batch_size (str): A string with three space-separated integers representing the - # starting batch size, the increment, and the number of steps between - # each increment. For example, "192 24 8" means that the batch size (micro_num) - # starts at 192 and increases by 24 every 8 steps. Defaults to None. - # (IMPORTANT): The interval step size is 'micro_bsz'. 
- rampup_batch_size="", - # Datasets with less than 50 rows will be discarded - min_length=50, - train_folder=TRAIN_FOLDER, - valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=200, - diag_outlier_ratio=1.1, - image_size=336, - patch_size=14, -) - -grad_scaler = dict( - fp16=dict( - # the initial loss scale, defaults to 2**16 - initial_scale=2**16, - # the minimum loss scale, defaults to None - min_scale=1, - # the number of steps to increase loss scale when no overflow occurs - growth_interval=1000, - ), - # the multiplication factor for increasing loss scale, defaults to 2 - growth_factor=2, - # the multiplication factor for decreasing loss scale, defaults to 0.5 - backoff_factor=0.5, - # the maximum loss scale, defaults to None - max_scale=2**24, - # the number of overflows before decreasing loss scale, defaults to 2 - hysteresis=2, -) - -hybrid_zero_optimizer = dict( - # Enable low_level_optimzer overlap_communication - overlap_sync_grad=True, - overlap_sync_param=False, - # bucket size for nccl communication params - reduce_bucket_size=512 * 1024 * 1024, - # grad clipping - clip_grad_norm=1.0, -) - -loss = dict( - label_smoothing=0, -) - -adam = dict( - lr=1e-4, - adam_beta1=0.9, - adam_beta2=0.95, - adam_beta2_c=0, - adam_eps=1e-8, - weight_decay=0.01, -) - -lr_scheduler = dict( - total_steps=data["total_steps"], - init_steps=0, # optimizer_warmup_step - warmup_ratio=0.01, - eta_min=1e-5, - last_epoch=-1, -) - -beta2_scheduler = dict( - init_beta2=adam["adam_beta2"], - c=adam["adam_beta2_c"], - cur_iter=-1, -) - -use_fp32_norm = False -model = dict( - checkpoint=False, - num_chunks=1, - num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, - vocab_size=VOCAB_SIZE, - embed_grad_scale=1, - parallel_output=True, - hidden_size=HIDDEN_SIZE, - num_layers=NUM_LAYER, - no_bias=True, - mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, - dtype="torch.bfloat16", - norm_type="rmsnorm", - layer_norm_epsilon=1e-5, - num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, - use_flash_attn=True, - image_token_id=200000, - vit_cfg=dict( - mm_projector_type="mlp2x_gelu", - mm_use_im_patch_token=True, - mm_use_im_start_end=True, - mm_vision_select_feature="patch", - mm_vision_select_layer=-2, - mm_vision_tower="openai/clip-vit-large-patch14-336", - ), - vision_proj_cfg=dict( - mm_projector_type="mlp2x_gelu", - mm_hidden_size=1024, # vit hidden_size - hidden_size=HIDDEN_SIZE, # llm hidden_size - ), -) - -parallel = dict( - zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), -) - -cudnn_deterministic = False -cudnn_benchmark = False - -monitor = dict( - # feishu alert configs - alert=dict( - enable_feishu_alert=DO_ALERT, - feishu_alert_address=None, # feishu webhook to send alert message - light_monitor_address=None, # light_monitor address to send heartbeat - alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", - ), - tensorboard=dict( - queue_max_length=10, - ), -) - -# metric_dtype can be "fp32" or other string -# only when set to "fp32" will use fp32 to calc in metrics -# metric_dtype = "fp32" diff --git a/doc/code-docs/source/example/20B_demo.rst b/doc/code-docs/source/example/20B_demo.rst index 0fd0d0221..232d810b2 100644 --- a/doc/code-docs/source/example/20B_demo.rst +++ b/doc/code-docs/source/example/20B_demo.rst @@ -123,14 +123,12 @@ model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] 
num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/doc/code-docs/source/example/7B_demo.rst b/doc/code-docs/source/example/7B_demo.rst index 67df4261e..92f9b0307 100644 --- a/doc/code-docs/source/example/7B_demo.rst +++ b/doc/code-docs/source/example/7B_demo.rst @@ -123,14 +123,12 @@ model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index 9b7ee3b3c..4c938c30b 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -58,14 +58,12 @@ InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制 model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/doc/code-docs/source/mixed_precision.rst b/doc/code-docs/source/mixed_precision.rst index bbada7f77..774c620f7 100644 --- a/doc/code-docs/source/mixed_precision.rst +++ b/doc/code-docs/source/mixed_precision.rst @@ -63,14 +63,12 @@ InternEvo支持使用TF32训练模型,允许用户在config文件中将 ``dtyp model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.tf32", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/doc/en/usage.md b/doc/en/usage.md index f8ae268a1..17cf88e98 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -229,14 +229,12 @@ beta2_scheduler = dict( model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, @@ -359,14 +357,12 @@ MLP_RATIO = 8 / 3 model = dict( checkpoint=False, # The proportion of layers for 
activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, @@ -455,14 +451,12 @@ MLP_RATIO = 8 / 3 model = dict( checkpoint=False, # 进行重计算的模型层数比例,可选值为 True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/doc/usage.md b/doc/usage.md index cba2b4be2..b28144cca 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -238,14 +238,12 @@ use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, @@ -398,14 +396,12 @@ MLP_RATIO = 8 / 3 model = dict( checkpoint=False, # 进行重计算的模型层数比例,可选值为 True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index 5b9ad74de..28f4e06c8 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -1,8 +1,5 @@ # Copyright (c) InternLM. All rights reserved. 
-from internlm.model.model_implementations.transformers.modeling_internlm import ( - InternLM1, -) from internlm.model.model_implementations.transformers.modeling_internlm2 import ( InternLM2, ) @@ -13,6 +10,5 @@ LOAD_FUNC_DICT = { "llama": Llama2.load_llama_pretrained_weights, - "internlm_test": InternLM1.load_internlm_with_dynamic_parallel_size, "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size, } diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py index 9e8c46342..352666dcc 100644 --- a/internlm/initialize/initialize_model.py +++ b/internlm/initialize/initialize_model.py @@ -29,7 +29,6 @@ ScaleColumnParallelLinear, ) from internlm.model.model_ops.moe import Experts, MoE -from internlm.model.model_ops.moe.moe import Qwen2MoE from internlm.model.model_ops.ops.norm import RMSNorm from internlm.utils.parallel import ( is_replica_expert_data_parallel_parameter, @@ -100,7 +99,7 @@ def _check_module(name, module): for param in module.parameters(): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - if isinstance(module, (MoE, Qwen2MoE)): + if isinstance(module, MoE): for param in module.moe_layer.gate.parameters(): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) if hasattr(module, "coefficient"): @@ -140,11 +139,6 @@ def _check_module(name, module): elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): setattr(param, IS_WEIGHT_ZERO_PARALLEL, True) - # for vit and vit project - if "vision_tower" in name.lower() or "vision_proj" in name.lower(): - for param in module.parameters(): - setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - for _chunk in unwrap_naive_amp(model): if not is_using_fsdp(): # set param parallel attribute diff --git a/internlm/model/model_implementations/builder.py b/internlm/model/model_implementations/builder.py index 168a00df5..8b3113faa 100644 --- a/internlm/model/model_implementations/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -39,8 +39,6 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: # TODO: fix use_flash_attn parameter config kwargs.pop("use_flash_attn", False) - kwargs.pop("apply_post_layer_norm") - kwargs.pop("embed_split_hidden", True) kwargs["checkpoint"] = float(kwargs.get("checkpoint", False)) kwargs["device"] = get_current_device() diff --git a/internlm/model/model_implementations/registry.py b/internlm/model/model_implementations/registry.py index a7857e7f8..6a21a79ff 100644 --- a/internlm/model/model_implementations/registry.py +++ b/internlm/model/model_implementations/registry.py @@ -4,10 +4,6 @@ from typing import Callable -from internlm.model.model_implementations.transformers.modeling_baichuan2 import ( - Baichuan2, -) -from internlm.model.model_implementations.transformers.modeling_gemma import Gemma from internlm.model.model_implementations.transformers.modeling_internlm import ( InternLM1, ) @@ -15,15 +11,7 @@ InternLM2, ) from internlm.model.model_implementations.transformers.modeling_llama import Llama2 -from internlm.model.model_implementations.transformers.modeling_llava import Llava -from internlm.model.model_implementations.transformers.modeling_mixtral import ( - MixtralMoE, -) from internlm.model.model_implementations.transformers.modeling_moe import Internlm1MoE -from internlm.model.model_implementations.transformers.modeling_qwen2 import Qwen2 -from internlm.model.model_implementations.transformers.modeling_qwen2_moe import ( - Qwen2Moe, -) from internlm.utils.common import SingletonMeta from internlm.utils.utils import 
ModelType @@ -99,9 +87,3 @@ def register_model_initializer() -> None: model_initializer.register_module(ModelType.INTERNLM3.name, InternLM2) model_initializer.register_module(ModelType.LLAMA2.name, Llama2) model_initializer.register_module(ModelType.INTERNLM_MoE.name, Internlm1MoE) - model_initializer.register_module(ModelType.LLAVA.name, Llava) - model_initializer.register_module(ModelType.QWEN2.name, Qwen2) - model_initializer.register_module(ModelType.BAICHUAN2.name, Baichuan2) - model_initializer.register_module(ModelType.GEMMA.name, Gemma) - model_initializer.register_module(ModelType.QWEN2MOE.name, Qwen2Moe) - model_initializer.register_module(ModelType.MIXTRALMOE.name, MixtralMoE) diff --git a/internlm/model/model_implementations/transformers/modeling_baichuan2.py b/internlm/model/model_implementations/transformers/modeling_baichuan2.py deleted file mode 100644 index 09bde6c3e..000000000 --- a/internlm/model/model_implementations/transformers/modeling_baichuan2.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright (c) InternLM. All rights reserved. -import math -import os -from typing import Optional - -import torch -from einops import rearrange -from torch import nn -from tqdm import tqdm - -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.utils import ( - normal_, - scaled_init_method_normal, - scaled_init_method_uniform, - uniform_, -) -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.mha import MHA -from internlm.model.model_ops.modules.mlp import new_feed_forward -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.model.model_ops.utils import ( - convert_attn_args_to_kwargs, - convert_attn_kwargs_to_args, -) -from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.utils.logger import get_logger -from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) - -internlm_accelerator = get_accelerator() -logger = get_logger(__file__) - - -class Baichuan2Decoder(nn.Module): - """ - 1D Packed Flash Llama Layer. - - Args: - hidden_size (int): The hidden size of model. 768 by default. - num_attention_heads (int): The number of attention heads. 12 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. - dtype (torch.dtype): Type of data. torch.float by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_idx (int): The index of current layer. 0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default, - attn_other_init_std (float): std used to init attn_other weight. 
0.0015 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.006 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. - """ - - def __init__( - self, - hidden_size: int = 768, - num_attention_heads: int = 12, - mlp_ratio: int = 4, - attn_drop_rate: float = 0, - drop_rate: float = 0.0, - dtype: torch.dtype = torch.float, - layer_norm_epsilon: float = 1e-6, - checkpoint: bool = False, - layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - residual_in_fp32: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm: bool = False, - fused_dropout_add_ln: bool = True, - no_bias: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - attn_wqkv_init_std: float = 0.006, - attn_other_init_std: float = 0.0015, - ffn_uplayer_init_std: float = 0.006, - ffn_other_init_std: float = 0.0015, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - max_position_embeddings: int = 2048, - ): - super().__init__() - self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. - self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False - self.layer_idx = layer_idx - self.prenorm = not apply_post_layer_norm - assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" - self.fused_dropout_add_ln = fused_dropout_add_ln - self.attn_wqkv_init_std = attn_wqkv_init_std - self.attn_other_init_std = attn_other_init_std - self.ffn_uplayer_init_std = ffn_uplayer_init_std - self.ffn_other_init_std = ffn_other_init_std - - head_dim = hidden_size // num_attention_heads - - self.attention = MHA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - max_position_embeddings=max_position_embeddings, - bias=not no_bias, - dropout=attn_drop_rate, - softmax_scale=1 / math.sqrt(head_dim), - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rope_base=rope_base, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - enable_qkv_fusion=True, - out_bias=False, - ) - - self.dropout1 = nn.Dropout(drop_rate) - self.dropout2 = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.feed_forward = new_feed_forward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - # TODO: to support more activation functions - activation_type="swiglu" if use_swiglu else "gelu", - ) - - self.use_swiglu = use_swiglu - self.use_scaled_init = use_scaled_init - self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm - self.return_residual = False - - if init_type == "normal": - self.init_func = normal_ - self.scaled_init_func = scaled_init_method_normal - else: - 
self.init_func = uniform_ - self.scaled_init_func = scaled_init_method_uniform - - self.reset_parameters() - - def reset_parameters(self): - with torch.no_grad(): - for name, param in self.attention.named_parameters(): - if param.ndim == 1: - param.data.zero_() - elif "wq" in name or "wk" in name or "wv" in name: - self.init_func(std=self.attn_wqkv_init_std)(param.data) - elif self.use_scaled_init: # wo - self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.attn_other_init_std)(param.data) - - for name, param in self.feed_forward.named_parameters(): - if self.use_swiglu: - if self.use_scaled_init and "w2" in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - # candidate: w1, w3, fused_w1_w3 - self.init_func( - std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std - )(param.data) - else: - if self.use_scaled_init and "fc1" not in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( - param.data - ) - - def forward(self, hidden_states, residual=None, **kwargs): - if self.checkpoint and self.training: - args = convert_attn_kwargs_to_args(kwargs) - return activation_checkpoint(self._forward, False, hidden_states, residual, *args) - else: - return self._forward(hidden_states, residual, **kwargs) - - def _forward(self, hidden_states, residual, *args, **kwargs): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Attn/MLP(LN(residual)) - cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 - indexes: the length of index is same as hidden states, which stand for the current position - """ - if self.prenorm: - - def _dropout_and_norm_attn(_residual, _hidden_states): - _dropped = self.dropout1(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.attention(hidden_states, **mixer_kwargs) - - if not isinstance(self.feed_forward, nn.Identity): - if not self.fused_dropout_add_ln: - - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint( - _dropout_and_norm_ffn, False, residual, hidden_states - ) - else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - hidden_states = self.feed_forward(hidden_states) - - return hidden_states + residual - else: - assert residual is None - - mixer_out = self.attention(hidden_states, 
**kwargs) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( - dtype=self.attention_norm.weight.dtype - ) - if not isinstance(self.feed_forward, nn.Identity): - mlp_out = self.feed_forward(hidden_states) - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( - dtype=self.ffn_norm.weight.dtype - ) - return hidden_states - - -class Baichuan2(BaseTransformerModel): - """ - 1D Packed Flash Llama. - - Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - embedding_init_std (float): std used to init embedding weight. 0.0052 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.006 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.0052 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
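The prenorm path above follows one recipe for both sublayers: dropout the incoming hidden states, add them onto the running residual, normalize, then run attention or the feed-forward, recomputing only the cheap dropout+norm preamble when dropout_selective_checkpoint is set. A framework-free sketch of that recipe (function and argument names here are illustrative, not part of this repository):

    import torch
    from torch.utils.checkpoint import checkpoint

    def prenorm_sublayer(hidden_states, residual, dropout, norm, sublayer,
                         residual_in_fp32=False, selective_ckpt=False):
        # dropout -> residual add -> norm, mirroring _dropout_and_norm_attn/_ffn above
        def _dropout_and_norm(_residual, _hidden_states):
            dropped = dropout(_hidden_states)
            _residual = dropped if _residual is None else dropped + _residual
            return _residual, norm(_residual.to(norm.weight.dtype))

        if selective_ckpt:
            # store no activations for the preamble; recompute it during backward
            residual, hidden_states = checkpoint(_dropout_and_norm, residual, hidden_states,
                                                 use_reentrant=False)
        else:
            residual, hidden_states = _dropout_and_norm(residual, hidden_states)

        if residual_in_fp32:
            residual = residual.to(torch.float32)
        return sublayer(hidden_states), residual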
- """ - - def __init__( - self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - mlp_ratio: int = 4, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - checkpoint: float = 1.0, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm=False, - no_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - embedding_init_std: float = 0.0052, - attn_wqkv_init_std: float = 0.006, - attn_other_init_std: float = 0.0015, - ffn_uplayer_init_std: float = 0.006, - ffn_other_init_std: float = 0.0015, - out_head_init_std: float = 0.0052, - init_type: str = "normal", - norm_head: bool = False, - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - ): - super().__init__() - - checkpoint_layer_num = int(num_layers * checkpoint) - self.embed_grad_scale = embed_grad_scale - self.parallel_output = parallel_output - - if first: - self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - - for _, param in self.tok_embeddings.named_parameters(): - if init_type == "normal": - normal_(std=embedding_init_std)(param) - else: - uniform_(std=embedding_init_std)(param) - - self.layers = nn.ModuleList( - [ - Baichuan2Decoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - residual_in_fp32=residual_in_fp32, - device=device, - apply_post_layer_norm=apply_post_layer_norm, - fused_dropout_add_ln=False, - no_bias=no_bias, - norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - for lid in range(num_layers) - ] - ) - - if last: - if not apply_post_layer_norm: - self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.output = new_linear( - name="output", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - norm_head=norm_head, - ) - - for _, param in self.output.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - def forward(self, hidden_states=None, input_ids=None, **kwargs): - # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, 
"tok_embeddings") and input_ids is not None: - hidden_states = self.tok_embeddings(input_ids) - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - - for _, block in enumerate(self.layers): - hidden_states = block(hidden_states, residual=None, **kwargs) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) - if hasattr(self, "output"): - hidden_states = self.output(hidden_states) - - return hidden_states - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - fns = get_fns(folder) - model_fns = [ - os.path.join(folder, fn) - for fn in fns - if (fn.endswith(".bin") and fn.startswith("pytorch_model")) - or (fn.endswith(".safetensors") and fn.startswith("model")) - ] - model_fns.sort() - - state_dict = {} - for model_fn in model_fns: - state_dict.update(llm_load(model_fn, map_location="cpu")) - - tp_size = gpc.get_world_size(ParallelMode.TENSOR) - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - wp_size = gpc.get_world_size(ParallelMode.WEIGHT) - wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) - tp_mode = gpc.config.parallel.tensor["mode"] - split_size = wp_size if tp_mode == "isp" else tp_size - local_rank = wp_rank if tp_mode == "isp" else tp_rank - row_dim = 0 if tp_mode == "isp" else 1 - if gpc.config.model.get("embed_split_hidden", True): - embed_concat_dim = 1 - else: - embed_concat_dim = 0 - - new_state_dict = {} - - # embedding - if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): - new_state_dict["tok_embeddings.weight"] = torch.chunk( - state_dict.pop("model.embed_tokens.weight"), - split_size, - dim=embed_concat_dim, - )[local_rank] - - for idx, i in enumerate(range(model.first_layer, model.last_layer)): - layer_ids = i - - # attn - state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # ffn - state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # attn norm - state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.input_layernorm.weight" - ) - # ffn norm - state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.post_attention_layernorm.weight" - ) - - # replace value within decoder layer - for name in list(state_dict.keys()): - if name.startswith(f"layers.{i}"): - new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) - - # output - if gpc.is_last_rank(ParallelMode.PIPELINE): - new_state_dict["output.weight"] = torch.chunk( - 
state_dict.pop("lm_head.weight"), - split_size, - dim=0, - )[local_rank] - new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") - - missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - - internlm_accelerator.empty_cache() - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - def permute(qkv, num_heads, num_kv_heads, head_dim, qk_interleaved=False): - if not qk_interleaved: - return qkv - q_per_kv = num_heads // num_kv_heads - qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim) - q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :] - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - qkv = torch.cat((q, k, v), dim=2) - qkv = rearrange(qkv, "o g n i -> o (g n i)").T - return qkv - - model_config = gpc.config.model - tp_mode = gpc.config.parallel.tensor["mode"] - row_dim = 0 if tp_mode == "isp" else 1 - if model_config["embed_split_hidden"]: - embed_concat_dim = 1 - else: - embed_concat_dim = 0 - - # load states - states, num_shards = Baichuan2.load_sharded_states(src) - - # convert state_dict - state_dict = {} - embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] - for layer_i in tqdm(range(model_config["num_layers"])): - # attn norm, ffn norm - state_dict.update( - { - f"model.layers.{layer_i}.input_layernorm.weight": states[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - ) - # attn - state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute( - torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0), - num_heads=model_config["num_attention_heads"], - # num_kv_attention_heads equals to num_attention_heads in MHA - num_kv_heads=model_config["num_attention_heads"], - head_dim=model_config["hidden_size"] // model_config["num_attention_heads"], - qk_interleaved=model_config.get("qk_interleaved", False), - ) - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim - ) - # ffn - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 - ) - # embedding, output - for embedding_key in embedding_key_list: - if embedding_key in states[0]: - break - if embedding_key is None: - raise KeyError("Cannot find embedding key!") - state_dict.update( - { - "model.norm.weight": states[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat( - [states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim - ), - "lm_head.weight": 
torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), - }, - ) - - # save state_dict to hf format - shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) - for shard_file, shard in shards.items(): - llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) - if index is not None: - llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/model_implementations/transformers/modeling_gemma.py b/internlm/model/model_implementations/transformers/modeling_gemma.py deleted file mode 100644 index 5e8bd0a6d..000000000 --- a/internlm/model/model_implementations/transformers/modeling_gemma.py +++ /dev/null @@ -1,752 +0,0 @@ -# Copyright (c) InternLM. All rights reserved. -import math -import os -from typing import Optional - -import torch -from torch import nn -from tqdm import tqdm - -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.utils import ( - normal_, - scaled_init_method_normal, - scaled_init_method_uniform, - uniform_, -) -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.mha import GQA -from internlm.model.model_ops.modules.mlp import new_feed_forward -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.model.model_ops.utils import ( - convert_attn_args_to_kwargs, - convert_attn_kwargs_to_args, -) -from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.utils.logger import get_logger -from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) - -try: - from flash_attn.modules.mlp import ParallelFusedMLP -except ImportError: - pass - -internlm_accelerator = get_accelerator() -logger = get_logger(__file__) - - -class GemmaDecoder(nn.Module): - """ - 1D Packed Flash Llama Layer. - - Args: - hidden_size (int): The hidden size of model. 768 by default. - num_attention_heads (int): The number of attention heads. 12 by default. - head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. - dtype (torch.dtype): Type of data. torch.float by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_idx (int): The index of current layer. 0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. - use_glu (bool): Whether to use glu. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 
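The add_unit_offset flag documented above denotes the Gemma-style RMSNorm variant in which the normalized input is scaled by (1 + weight) rather than by weight alone, so a zero-initialized weight starts out as an identity scale. A minimal sketch of that behavior (this class is illustrative; the repository's real norm is built through new_layer_norm):

    import torch
    from torch import nn

    class UnitOffsetRMSNorm(nn.Module):
        """RMSNorm that optionally scales by (1 + weight), as add_unit_offset suggests."""
        def __init__(self, hidden_size: int, eps: float = 1e-6, add_unit_offset: bool = False):
            super().__init__()
            self.add_unit_offset = add_unit_offset
            # with the unit offset, a zero weight corresponds to an identity scale
            init = torch.zeros(hidden_size) if add_unit_offset else torch.ones(hidden_size)
            self.weight = nn.Parameter(init)
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            scale = (1.0 + self.weight) if self.add_unit_offset else self.weight
            return normed * scale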
0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. - """ - - def __init__( - self, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - head_dim: int = None, - mlp_ratio: int = 4, - attn_drop_rate: float = 0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - layer_norm_epsilon: float = 1e-6, - checkpoint: bool = False, - layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - residual_in_fp32: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm: bool = False, - fused_dropout_add_ln: bool = True, - no_bias: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - add_unit_offset: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_glu: bool = True, - use_swiglu: bool = True, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - tp_mode: str = "mtp", - ): - super().__init__() - self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
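num_kv_attention_heads and head_dim in the signature above describe grouped-query attention: fewer key/value heads than query heads, each K/V head shared by a group of query heads, with scores scaled by 1/sqrt(head_dim). A toy single-device illustration of that sharing (not the repository's GQA class):

    import math
    import torch

    def toy_grouped_query_attention(q, k, v):
        """q: (B, H, T, D); k, v: (B, H_kv, T, D) with H % H_kv == 0."""
        h, h_kv = q.shape[1], k.shape[1]
        k = k.repeat_interleave(h // h_kv, dim=1)   # each K/V head serves H // H_kv query heads
        v = v.repeat_interleave(h // h_kv, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        causal = torch.triu(torch.ones(q.shape[2], q.shape[2], dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(causal, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v

    out = toy_grouped_query_attention(
        torch.randn(1, 8, 4, 16), torch.randn(1, 2, 4, 16), torch.randn(1, 2, 4, 16)
    )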
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False - self.layer_idx = layer_idx - self.prenorm = not apply_post_layer_norm - assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" - self.fused_dropout_add_ln = fused_dropout_add_ln - self.attn_wqkv_init_std = attn_wqkv_init_std - self.attn_other_init_std = attn_other_init_std - self.ffn_uplayer_init_std = ffn_uplayer_init_std - self.ffn_other_init_std = ffn_other_init_std - - if not head_dim: - head_dim = hidden_size // num_attention_heads - - self.attention = GQA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - num_kv_heads=num_kv_attention_heads, - head_dim=head_dim, - dropout=attn_drop_rate, - max_position_embeddings=max_position_embeddings, - softmax_scale=1 / math.sqrt(head_dim), - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - bias=not no_bias, - rope_base=rope_base, - enable_qkv_fusion=False, - ) - - self.dropout1 = nn.Dropout(drop_rate) - self.dropout2 = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm( - norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset - ) - self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) - - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR - - if use_glu: - self.feed_forward = new_feed_forward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - activation_type="swiglu" if use_swiglu else "gelu", - ) - else: - self.feed_forward = ParallelFusedMLP( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - activation="gelu_approx", - process_group=gpc.get_group(parallel_mode), - bias1=False, - bias2=False, - sequence_parallel=sequence_parallel, - checkpoint_lvl=0, - heuristic="auto", - device=device, - dtype=dtype, - ) - - self.use_glu = use_glu - self.use_swiglu = use_swiglu - self.use_scaled_init = use_scaled_init - self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm - self.return_residual = False - - if init_type == "normal": - self.init_func = normal_ - self.scaled_init_func = scaled_init_method_normal - else: - self.init_func = uniform_ - self.scaled_init_func = scaled_init_method_uniform - - self.reset_parameters() - - def reset_parameters(self): - with torch.no_grad(): - for name, param in self.attention.named_parameters(): - if param.ndim == 1: - param.data.zero_() - elif "wq" in name or "wk" in name or "wv" in name: - self.init_func(std=self.attn_wqkv_init_std)(param.data) - elif self.use_scaled_init: # wo - self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.attn_other_init_std)(param.data) - - for name, param in self.feed_forward.named_parameters(): - if self.use_glu: - if self.use_scaled_init and "w2" in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func( - std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std - )(param.data) - else: - if self.use_scaled_init and "fc1" not in name: - 
self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( - param.data - ) - - def forward(self, hidden_states, residual=None, **kwargs): - if self.checkpoint and self.training: - args = convert_attn_kwargs_to_args(kwargs) - return activation_checkpoint(self._forward, False, hidden_states, residual, *args) - else: - return self._forward(hidden_states, residual, **kwargs) - - def _forward(self, hidden_states, residual, *args, **kwargs): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Attn/MLP(LN(residual)) - cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 - indexes: the length of index is same as hidden states, which stand for the current position - """ - if self.prenorm: - - def _dropout_and_norm_attn(_residual, _hidden_states): - _dropped = self.dropout1(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.attention(hidden_states, **mixer_kwargs) - - if not isinstance(self.feed_forward, nn.Identity): - if not self.fused_dropout_add_ln: - - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint( - _dropout_and_norm_ffn, False, residual, hidden_states - ) - else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - hidden_states = self.feed_forward(hidden_states) - - return hidden_states + residual - else: - assert residual is None - - mixer_out = self.attention(hidden_states, **kwargs) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( - dtype=self.attention_norm.weight.dtype - ) - if not isinstance(self.feed_forward, nn.Identity): - mlp_out = self.feed_forward(hidden_states) - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( - dtype=self.ffn_norm.weight.dtype - ) - return hidden_states - - -class Gemma(BaseTransformerModel): - """ - 1D Packed Flash Llama. - - Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. - head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. - vocab_size (int): The size of vocabulary. 50304 by default. 
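reset_parameters above draws the residual-branch output projections (wo, w2) from a depth-scaled distribution via scaled_init_func, while the up projections keep their plain std. A sketch of the conventional Megatron-style scaling this usually denotes (the exact formula inside the repository's scaled_init_method_normal is assumed here, not quoted):

    import math
    import torch

    def scaled_init_method_normal_sketch(sigma: float, num_layers: int):
        """Assumed form: shrink the std of residual-output weights by sqrt(2 * num_layers)."""
        std = sigma / math.sqrt(2.0 * num_layers)
        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.normal_(tensor, mean=0.0, std=std)
        return init_

    # e.g. layer_idx = 3 -> num_layers = 4, so std = sigma / sqrt(8)
    w = torch.empty(16, 16)
    scaled_init_method_normal_sketch(sigma=0.02, num_layers=4)(w)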
- mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. - use_glu (bool): Whether to use glu. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
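embed_grad_scale above implements the GLM-130B embedding-gradient shrink: the forward value of the token embeddings is unchanged, but only a fraction of the gradient flows back into the embedding table. A self-contained sketch of the trick:

    import torch

    def shrink_embedding_grad(hidden_states: torch.Tensor, scale: float = 0.1) -> torch.Tensor:
        # identical forward value; backward gradient into hidden_states is multiplied by `scale`
        return scale * hidden_states + (1.0 - scale) * hidden_states.detach()

    x = torch.ones(2, 3, requires_grad=True)
    shrink_embedding_grad(x, scale=0.1).sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, 0.1))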
- """ - - def __init__( - self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - head_dim: int = None, - vocab_size: int = 50304, - mlp_ratio: int = 4, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - checkpoint: float = 1.0, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm=False, - no_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - add_unit_offset: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_glu: bool = True, - use_swiglu: bool = False, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - extra_pred_tokens: int = 0, - rope_base: int = 10000, - norm_head: bool = False, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - ): - super().__init__() - - checkpoint_layer_num = int(num_layers * checkpoint) - self.hidden_size = hidden_size - self.embed_grad_scale = embed_grad_scale - self.parallel_output = parallel_output - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if first: - self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - for _, param in self.embed_tokens.named_parameters(): - if init_type == "normal": - normal_(std=embedding_init_std)(param) - else: - uniform_(std=embedding_init_std)(param) - - self.layers = nn.ModuleList( - [ - GemmaDecoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads, - head_dim=head_dim, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - residual_in_fp32=residual_in_fp32, - device=device, - apply_post_layer_norm=apply_post_layer_norm, - fused_dropout_add_ln=False, - no_bias=no_bias, - norm_type=norm_type, - add_unit_offset=add_unit_offset, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_glu=use_glu, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - tp_mode=self.tp_mode, - ) - for lid in range(num_layers) - ] - ) - - if last: - if not apply_post_layer_norm: - self.norm = new_layer_norm( - norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset - ) - - self.output = new_linear( - name="output", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if 
is_reward else vocab_size, - bias=False, - device=device, - is_reward=is_reward, - dtype=dtype, - weight_scale=embed_grad_scale, - norm_head=norm_head, - ) - for _, param in self.output.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - if extra_pred_tokens > 0: - self.extra_pred_tokens = extra_pred_tokens - assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" - self.extra_outputs = nn.ModuleList( - [ - new_linear( - name="output", - in_features=hidden_size, - out_features=vocab_size, - bias=False, - device=device, - is_reward=is_reward, - dtype=dtype, - weight_scale=embed_grad_scale, - norm_head=norm_head, - ) - for _ in range(self.extra_pred_tokens) - ] - ) - for _, param in self.extra_outputs.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - def forward(self, hidden_states=None, input_ids=None, **kwargs): - # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "embed_tokens"): - hidden_states = self.embed_tokens(input_ids) - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - hidden_states = hidden_states * (self.hidden_size**0.5) - - for _, block in enumerate(self.layers): - hidden_states = block(hidden_states, residual=None, **kwargs) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) - if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: - extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] - else: - extra_hidden_states_list = None - if hasattr(self, "output"): - hidden_states = self.output(hidden_states) - - if extra_hidden_states_list is not None: - return (hidden_states, extra_hidden_states_list) - - return hidden_states - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - fns = get_fns(folder) - model_fns = [ - os.path.join(folder, fn) - for fn in fns - if (fn.endswith(".bin") and fn.startswith("pytorch_model")) - or (fn.endswith(".safetensors") and fn.startswith("model")) - ] - model_fns.sort() - - state_dict = {} - for model_fn in model_fns: - state_dict.update(llm_load(model_fn, map_location="cpu")) - - tp_size = gpc.get_world_size(ParallelMode.TENSOR) - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - wp_size = gpc.get_world_size(ParallelMode.WEIGHT) - wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) - tp_mode = gpc.config.parallel.tensor["mode"] - split_size = wp_size if tp_mode == "isp" else tp_size - local_rank = wp_rank if tp_mode == "isp" else tp_rank - row_dim = 0 if tp_mode == "isp" else 1 - if gpc.config.model.get("embed_split_hidden", True): - embed_concat_dim = 1 - else: - embed_concat_dim = 0 - - new_state_dict = {} - - # embedding - if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): - new_state_dict["embed_tokens.weight"] = torch.chunk( - state_dict.get("model.embed_tokens.weight"), - split_size, - dim=embed_concat_dim, - )[local_rank] - - for idx, i in enumerate(range(model.first_layer, model.last_layer)): - layer_ids = i - - # 
attn - state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # ffn - state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # attn norm - state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.input_layernorm.weight" - ) - # ffn norm - state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.post_attention_layernorm.weight" - ) - - # replace value within decoder layer - for name in list(state_dict.keys()): - if name.startswith(f"layers.{i}"): - new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) - - # output - if gpc.is_last_rank(ParallelMode.PIPELINE): - if "lm_head.weight" in state_dict: - new_state_dict["output.weight"] = torch.chunk( - state_dict.pop("lm_head.weight"), # we do not tie lm head with embedding - split_size, - dim=0, - )[local_rank] - state_dict.pop("model.embed_tokens.weight") - else: - new_state_dict["output.weight"] = torch.chunk( - # gemma model ties lm head with embedding in transformers implementation - state_dict.pop("model.embed_tokens.weight"), - split_size, - dim=0, - )[local_rank] - new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") - - missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - - internlm_accelerator.empty_cache() - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - model_config = gpc.config.model - tp_mode = gpc.config.parallel.tensor["mode"] - row_dim = 0 if tp_mode == "isp" else 1 - - # load states - states, num_shards = Gemma.load_sharded_states(src) - - # convert state_dict - state_dict = {} - embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] - for layer_i in tqdm(range(model_config["num_layers"])): - # attn norm, mlp norm - state_dict.update( - { - f"model.layers.{layer_i}.input_layernorm.weight": states[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - ) - # attn wqkv weight and bias - 
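The loader above chunks every HuggingFace weight along its parallel dimension and keeps only the local rank's slice (dim 0 for the column-parallel q/k/v and gate/up projections, dim 1 for the row-parallel o_proj/down_proj under tensor parallelism, dim 0 under ISP), while the converter concatenates the shards back. A toy round trip showing the two operations are inverses:

    import torch

    def shard_for_tp(weight: torch.Tensor, split_size: int, local_rank: int, dim: int) -> torch.Tensor:
        """Keep only this rank's slice, mirroring the torch.chunk(...)[local_rank] calls above."""
        return torch.chunk(weight, split_size, dim=dim)[local_rank]

    def merge_from_tp(shards, dim: int) -> torch.Tensor:
        """Inverse step used when exporting back to a HuggingFace-style checkpoint."""
        return torch.cat(list(shards), dim=dim)

    w = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    col_shards = [shard_for_tp(w, split_size=2, local_rank=r, dim=0) for r in range(2)]
    assert torch.equal(merge_from_tp(col_shards, dim=0), w)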
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], - dim=0, - ) - # attn wo weight - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim - ) - - # mlp - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 - ) - - # embedding, head - for embedding_key in embedding_key_list: - if embedding_key in states[0]: - break - if embedding_key is None: - raise KeyError("Cannot find embedding key!") - if model_config["embed_split_hidden"]: - embed_concat_dim = 1 - tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] - else: - embed_concat_dim = 0 - _, size_1 = states[0][embedding_key].shape - embdim_pertp = size_1 // num_shards - tok_emb_list = [ - torch.concat( - [ - states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] - for tp in range(num_shards) - ], - dim=0, - ) - for local_rank in range(num_shards) - ] - state_dict.update( - { - "model.norm.weight": states[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), - "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), - }, - ) - - # save state_dict to hf format - shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) - for shard_file, shard in shards.items(): - llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) - if index is not None: - # Save the index as well - llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/model_implementations/transformers/modeling_internlm.py b/internlm/model/model_implementations/transformers/modeling_internlm.py index 6a29124d7..dfcd2eec8 100644 --- a/internlm/model/model_implementations/transformers/modeling_internlm.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm.py @@ -522,124 +522,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: internlm_accelerator.empty_cache() - @staticmethod - def load_internlm_with_dynamic_parallel_size(folder: str, model: nn.Module): - - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - fns = get_fns(folder) - model_fns = [] - for fn in fns: - # filter with `_t` is for avoiding conflict with model_config.py - if fn.startswith("model_t") and not fn.endswith("md5"): - model_fns.append(fn) - - old_tp, old_pp = -1, -1 - for fn in model_fns: - _, tp, pp = 
os.path.splitext(fn)[0].split("_") - old_tp = max(old_tp, int(tp[2:]) + 1) - old_pp = max(old_pp, int(pp[2:]) + 1) - - assert old_tp > 0 and old_pp > 0, f"ckpt with tp:{old_tp} and pp:{old_pp} is illegal" - - tp = gpc.get_world_size(ParallelMode.TENSOR) - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - assert old_tp % tp == 0 or tp % old_tp == 0, ( - f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in " - f"checkpoint and {tp} in current config" - ) - - correspond_tps = [] - - if old_tp <= tp: - correspond_tps.append(tp_rank // (tp // old_tp)) - ratio = tp // old_tp - rank = tp_rank % ratio - else: - for i in range(old_tp // tp): - correspond_tps.append(tp_rank * (old_tp // tp) + i) - rank = 0 - ratio = 1 - - current_states = {} - - pp = gpc.get_world_size(ParallelMode.PIPELINE) - - assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary" - - old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1) - - for idx, parts in enumerate(old_pp_partition): - start, end = parts[0] - if model.last_layer <= start or model.first_layer >= end: - continue - - tmp_states = {} - - for correspond_tp in correspond_tps: - model_name = f"model_tp{correspond_tp}_pp{idx}.pt" - states = llm_load(os.path.join(folder, model_name), map_location="cpu") - for i in range(start, end): - if i >= model.last_layer: - break - if i < model.first_layer: - continue - for name in list(states.keys()): - if f".{i-start}." in name: - to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.") - if "norm" in name: - tmp_states[to_name] = [states.pop(name)] - elif any(x in name for x in ("out_proj", "w2")): - if "bias" not in name: - tmp_states[to_name] = tmp_states.get(to_name, []) - tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=-1)[rank]) - else: - tmp_states[to_name] = [states.pop(name)] - elif any(x in name for x in ("w1", "w3")): - tmp_states[to_name] = tmp_states.get(to_name, []) - tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) - elif any(x in name for x in ("Wqkv",)): - tmp_states[to_name] = tmp_states.get(to_name, []) - _wqkv = states.pop(name).chunk(3, dim=0) - _wq_splits = _wqkv[0].chunk(ratio, dim=0) - _wk_splits = _wqkv[1].chunk(ratio, dim=0) - _wv_splits = _wqkv[2].chunk(ratio, dim=0) - new_wqkv = torch.concat([_wq_splits[rank], _wk_splits[rank], _wv_splits[rank]], dim=0) - tmp_states[to_name].append(new_wqkv) - else: - raise KeyError(f"Unknown key {name}.") - - if "embedding.weight" in states and model.first_layer == 0: - tmp_states["embedding.weight"] = tmp_states.get("embedding.weight", []) - tmp_states["embedding.weight"].append(states["embedding.weight"].chunk(ratio, dim=1)[rank]) - if "head.weight" in states and model.last_layer == gpc.config.model.num_layers: - tmp_states["norm.weight"] = [states["norm.weight"]] - tmp_states["head.weight"] = tmp_states.get("head.weight", []) - tmp_states["head.weight"].append(states["head.weight"].chunk(ratio, dim=0)[rank]) - - states = {} - - for name in list(tmp_states.keys()): - data = tmp_states.pop(name) - if len(data) == 1: - current_states[name] = data[0] - else: - current_states[name] = torch.concat( - data, dim=1 if name == "embedding.weight" or any(x in name for x in ("out_proj", "w2")) else 0 - ) - - missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else 
gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - - internlm_accelerator.empty_cache() @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: diff --git a/internlm/model/model_implementations/transformers/modeling_llava.py b/internlm/model/model_implementations/transformers/modeling_llava.py deleted file mode 100644 index 614578d43..000000000 --- a/internlm/model/model_implementations/transformers/modeling_llava.py +++ /dev/null @@ -1,248 +0,0 @@ -from typing import Optional - -import torch -from torch import nn - -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import set_output_attr_to_module -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.modeling_llama import ( - Llama2Decoder, -) -from internlm.model.model_implementations.transformers.utils import normal_, uniform_ -from internlm.model.model_ops.llava.clip_builder import build_vision_tower -from internlm.model.model_ops.llava.projector_builder import build_vision_projector -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -class Llava(BaseTransformerModel): - """ - 1D Packed Flash Llava. - - Args: - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - num_attention_heads (int): The number of attention head. 32 by default. - num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32. - vocab_size (int): The size of vocabulary. 50304 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 
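load_internlm_with_dynamic_parallel_size above cannot simply chunk a fused Wqkv as a whole when the tensor-parallel degree grows by some ratio: q, k and v must each be split separately and the rank-local pieces re-fused, which is what its Wqkv branch does. A compact restatement of that step (helper name is illustrative):

    import torch

    def reshard_fused_wqkv(wqkv: torch.Tensor, ratio: int, rank: int) -> torch.Tensor:
        """wqkv stacks [q; k; v] along dim 0; split each part by `ratio` and keep this rank's slice."""
        wq, wk, wv = wqkv.chunk(3, dim=0)
        return torch.cat(
            [wq.chunk(ratio, dim=0)[rank], wk.chunk(ratio, dim=0)[rank], wv.chunk(ratio, dim=0)[rank]],
            dim=0,
        )

    # doubling TP from 1 to 2: ratio = 2, each new rank keeps half of every q/k/v block
    full = torch.arange(12, dtype=torch.float32).reshape(6, 2)   # toy [q; k; v], 2 rows each
    half = reshard_fused_wqkv(full, ratio=2, rank=0)             # rows 0, 2, 4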
0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - image_token_id (int): image token id. 200000 by default. - vit_cfg (dict): The config of vision tower. None by default. - vision_proj_cfg (dict): The config of vision projector. None by default. - """ - - def __init__( - self, - num_layers: int = 48, - hidden_size: int = 2048, - num_attention_heads: int = 32, - num_kv_attention_heads: int = 32, - vocab_size: int = 50304, - mlp_ratio: int = 4, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - device: Optional[torch.device] = None, - apply_post_layer_norm=False, - no_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - image_token_id: int = 200000, - vit_cfg=None, - vision_proj_cfg=None, - ): - super().__init__() - - checkpoint_layer_num = num_layers * checkpoint - - self.dtype = dtype - self.image_token_id = image_token_id - self.embed_grad_scale = embed_grad_scale - self.parallel_output = parallel_output - - if first: - self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - - for _, param in self.tok_embeddings.named_parameters(): - if init_type == "normal": - normal_(std=embedding_init_std)(param) - else: - uniform_(std=embedding_init_std)(param) - - self.layers = nn.ModuleList( - [ - Llama2Decoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - residual_in_fp32=residual_in_fp32, - device=device, - apply_post_layer_norm=apply_post_layer_norm, - fused_dropout_add_ln=False, - no_bias=no_bias, - norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - for lid in range(num_layers) - ] - ) - - if last: - if not apply_post_layer_norm: - 
self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.output = new_linear( - name="output", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - set_output_attr_to_module(self.output) - for _, param in self.output.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - if first: - assert vit_cfg is not None - self.vit = build_vision_tower(vit_cfg) - self.vit.requires_grad_(False) - - assert vision_proj_cfg is not None - self.vision_proj = build_vision_projector(vision_proj_cfg) - # self.vision_proj.requires_grad_(False) - - def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs): - xs = [] - pure_text = False - images = [] if images is None else images - - if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"): - # vit - if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update - images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] - pure_text = True - - for image in images: - assert len(image) > 0 - if len(image) == 0: - x = [] - else: - assert not isinstance(image, list), image - x = image.to(torch.cuda.current_device()).to(self.dtype) - x = self.vit(x) - x = self.vision_proj(x) - xs.append(x) - - # tok embeddings - org_ids = input_ids.clone() - input_ids[input_ids == self.image_token_id] = 0 - hidden_states = self.tok_embeddings(input_ids).clone() - - if pure_text and len(xs) > 0: - hidden_states = hidden_states + 0 * xs[0].sum() - else: - for i in range(len(xs)): - hidden_states[i, org_ids[i] == self.image_token_id] = (xs[i].reshape((-1, xs[i].shape[-1]))).to( - hidden_states.dtype - ) - - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - - for _, block in enumerate(self.layers): - hidden_states = block(hidden_states, residual=None, **kwargs) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.float()) - - if hasattr(self, "output"): - hidden_states = self.output(hidden_states) - - return hidden_states - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError diff --git a/internlm/model/model_implementations/transformers/modeling_mixtral.py b/internlm/model/model_implementations/transformers/modeling_mixtral.py deleted file mode 100644 index 340da871e..000000000 --- a/internlm/model/model_implementations/transformers/modeling_mixtral.py +++ /dev/null @@ -1,434 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from typing import Optional - -import torch -from torch import nn - -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.utils import ( - normal_, - scaled_init_method_normal, -) -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.mha import SWA -from 
internlm.model.model_ops.modules.mlp import new_feed_forward -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.model.model_ops.moe.moe import MoE -from internlm.model.model_ops.utils import ( - convert_attn_args_to_kwargs, - convert_attn_kwargs_to_args, -) -from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -class MixtralMoEDecoder(nn.Module): - """ - InternLM1 MoE Decoder Layer. - - Args: - hidden_size (int): The hidden size of model. 768 by default. - num_attention_heads (int): The number of attention heads. 12 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. - max_position_embeddings (int): The maximum position embeddings. 2048 by default. - dtype (torch.dtype): Type of data. torch.float by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_idx (int): The index of current layer. 0 by default. - use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. - use_scaled_init (bool): Whether to use scaled initialization for weights. - use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. - multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. - """ - - def __init__( - self, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - mlp_ratio: int = 4, - attn_drop_rate: float = 0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - layer_norm_epsilon: float = 1e-6, - checkpoint: bool = False, - layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - residual_in_fp32: bool = False, - device: Optional[torch.device] = None, - qkv_bias=True, - o_bias=False, - norm_type: str = "rmsnorm", - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - sliding_window: int = None, - qk_interleaved: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - num_experts: int = 1, - top_k: int = 1, - num_shared_experts: int = 0, - moe_layer_kwargs: dict = None, - ): - super().__init__() - self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False - self.layer_idx = layer_idx - - head_dim = hidden_size // num_attention_heads - softmax_scale = 1 / math.sqrt(head_dim) - - self.mixer = SWA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - num_kv_heads=num_kv_attention_heads, - dropout=attn_drop_rate, - max_position_embeddings=max_position_embeddings, - softmax_scale=softmax_scale, - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - qkv_bias=qkv_bias, - o_bias=o_bias, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - use_sliding_window=use_sliding_window, - sliding_window=sliding_window, - ) - - self.dropout1 = nn.Dropout(drop_rate) - self.dropout2 = nn.Dropout(drop_rate) - self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.num_experts = num_experts - if num_experts <= 1: # dense, not MoE - self.mlp = new_feed_forward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - # TODO: to support more activation functions - activation_type="swiglu" if use_swiglu else "gelu", - ) - else: - # replace mlp by MoE module. The expert in MoE is a FeedForward module. - # mlp_cls = get_mlp_cls(self.tp_mode) - self.mlp = MoE( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - num_experts=num_experts, - top_k=top_k, - num_shared_experts=num_shared_experts, - moe_layer_kwargs=moe_layer_kwargs, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - # TODO: to support more activation functions - activation_type="swiglu" if use_swiglu else "gelu", - ) - - self.use_swiglu = use_swiglu - self.use_scaled_init = use_scaled_init - self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm - self.return_residual = False - self.reset_parameters() # TODO: check this should be changed when moe is added - - def reset_parameters(self): - with torch.no_grad(): - for name, param in self.mixer.named_parameters(): - if param.ndim == 1: - param.data.zero_() - elif "wqkv" in name: - normal_(std=0.006)(param.data) - elif self.use_scaled_init: - scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) - else: - normal_(std=0.0015)(param.data) - - for name, param in self.mlp.named_parameters(): - if param.ndim == 1 and "bias" in name: - param.data.zero_() - elif self.use_swiglu: - if self.use_scaled_init and "w2" in name: - scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) - else: - # candidate: w1, w3, fused_w1_w3 - normal_(std=0.006 if "w1" in name or "w3" in name else 0.0015)(param.data) - else: - if self.use_scaled_init and "fc1" not in name: - scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) - else: - normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) - - def forward(self, hidden_states, **kwargs): - if self.checkpoint and self.training: - # TODO: check whether this will be affected by moe - # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. 
- args = convert_attn_kwargs_to_args(kwargs) - return activation_checkpoint(self._forward, False, hidden_states, *args) - else: - return self._forward(hidden_states, **kwargs) - - def _forward(self, hidden_states, *args, **kwargs): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Attn/MLP(LN(residual)) - cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 - indexes: the length of index is same as hidden states, which stand for the current position - """ - - def _dropout_and_norm_attn(_hidden_states): - _dropped = self.dropout1(_hidden_states) - _residual = _dropped - _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype)) - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_attn(hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.mixer(hidden_states, **mixer_kwargs) - - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype)) - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - # MLP. - if self.num_experts <= 1: # dense mlp output - hidden_states = self.mlp(hidden_states) - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) - else: # MoE output - hidden_states, moe_loss, _ = self.mlp(hidden_states) - - return hidden_states + residual, moe_loss - - -class MixtralMoE(BaseTransformerModel): - """ - InternLM1 MoE. - - Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - max_position_embeddings (int): The maximum position embeddings. 2048 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 0.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. 
- device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. - use_scaled_init (bool): Whether to use scaled initialization for weights. - use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. - moe_type (str): determine which moe impl will be used, default is GShardMoE - mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. - multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. - """ - - def __init__( - self, - num_layers: int = 48, - hidden_size: int = 2048, - num_attention_heads: int = 32, - num_kv_attention_heads: int = 12, - vocab_size: int = 50304, - mlp_ratio: float = 4.0, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - checkpoint: float = 0.0, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - device: Optional[torch.device] = None, - qkv_bias=True, - o_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - max_window_layers: int = 0, - sliding_window: int = None, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - moe_type: str = None, # pylint: disable=W0613 - num_experts: bool = 1, - top_k: int = 1, - num_shared_experts: int = 0, - moe_layer_kwargs: dict = None, - ): - super().__init__() - - checkpoint_layer_num = int(num_layers * checkpoint) - - if first: - self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - - for _, param in self.embedding.named_parameters(): - normal_(std=0.0052)(param) - self.embed_grad_scale = embed_grad_scale - self.blocks = nn.ModuleList( - [ - MixtralMoEDecoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - max_position_embeddings=max_position_embeddings, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - residual_in_fp32=residual_in_fp32, - device=device, - qkv_bias=qkv_bias, - o_bias=o_bias, - norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - 
use_sliding_window=use_sliding_window and lid >= max_window_layers, - sliding_window=sliding_window, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - num_experts=num_experts, - top_k=top_k, - num_shared_experts=num_shared_experts, - moe_layer_kwargs=moe_layer_kwargs, - ) - for lid in range(num_layers) - ] - ) - if last: - self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.head = new_linear( - name="head", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - for _, param in self.head.named_parameters(): - normal_(std=0.0052)(param) - - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, input_ids=None, **kwargs): - # attention_mask: compute attention on the places where the value is 1 - # old condition may fail when use shared embedding - if gpc.is_pipeline_first_stage() and input_ids is not None: - hidden_states = self.embedding(input_ids) - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - - moe_losses = [] - for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block(hidden_states, **kwargs) - moe_losses.append(mos_loss) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.float()) - if hasattr(self, "head"): - hidden_states = self.head(hidden_states) - - return hidden_states, moe_losses - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError diff --git a/internlm/model/model_implementations/transformers/modeling_qwen2.py b/internlm/model/model_implementations/transformers/modeling_qwen2.py deleted file mode 100644 index b1d18e634..000000000 --- a/internlm/model/model_implementations/transformers/modeling_qwen2.py +++ /dev/null @@ -1,752 +0,0 @@ -# Copyright (c) InternLM. All rights reserved. 
-import math -import os -from typing import Optional - -import torch -from torch import nn -from tqdm import tqdm - -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.utils import ( - normal_, - scaled_init_method_normal, - scaled_init_method_uniform, - uniform_, -) -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.mha import SWA -from internlm.model.model_ops.modules.mlp import new_feed_forward -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.model.model_ops.utils import ( - convert_attn_args_to_kwargs, - convert_attn_kwargs_to_args, -) -from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.utils.logger import get_logger -from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) - -internlm_accelerator = get_accelerator() -logger = get_logger(__file__) - - -class Qwen2Decoder(nn.Module): - """ - 1D Packed Flash Qwen Layer. - - Args: - hidden_size (int): The hidden size of model. 768 by default. - num_attention_heads (int): The number of attention heads. 12 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. - dtype (torch.dtype): Type of data. torch.float by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_idx (int): The index of current layer. 0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
- """ - - def __init__( - self, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - mlp_ratio: int = 4, - attn_drop_rate: float = 0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - layer_norm_epsilon: float = 1e-6, - checkpoint: bool = False, - layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - residual_in_fp32: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm: bool = False, - fused_dropout_add_ln: bool = True, - qkv_bias=True, - o_bias=False, - mlp_bias=False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - init_type: str = "normal", - rope_type: str = "normal", - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - sliding_window: int = None, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - scale_attn_weights: bool = False, # Qwen1 - use_logn_attn: bool = False, # Qwen1 - ): - super().__init__() - self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. - self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False - self.layer_idx = layer_idx - self.prenorm = not apply_post_layer_norm - assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" - self.fused_dropout_add_ln = fused_dropout_add_ln - self.attn_wqkv_init_std = attn_wqkv_init_std - self.attn_other_init_std = attn_other_init_std - self.ffn_uplayer_init_std = ffn_uplayer_init_std - self.ffn_other_init_std = ffn_other_init_std - - head_dim = hidden_size // num_attention_heads - - if scale_attn_weights: - softmax_scale = None - else: - softmax_scale = 1 / math.sqrt(head_dim) - self.attention = SWA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - num_kv_heads=num_kv_attention_heads, - dropout=attn_drop_rate, - max_position_embeddings=max_position_embeddings, - softmax_scale=softmax_scale, - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - qkv_bias=qkv_bias, - o_bias=o_bias, - rope_type=rope_type, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - use_sliding_window=use_sliding_window, - sliding_window=sliding_window, - use_logn_attn=use_logn_attn, - ) - - self.dropout1 = nn.Dropout(drop_rate) - self.dropout2 = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.feed_forward = new_feed_forward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - bias=mlp_bias, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - activation_type="swiglu" if use_swiglu else "gelu", - ) - - self.use_swiglu = use_swiglu - self.use_scaled_init = use_scaled_init - self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm - self.return_residual = False - - if init_type == "normal": - self.init_func = normal_ - self.scaled_init_func = scaled_init_method_normal - else: - 
self.init_func = uniform_ - self.scaled_init_func = scaled_init_method_uniform - - self.reset_parameters() - - def reset_parameters(self): - with torch.no_grad(): - for name, param in self.attention.named_parameters(): - if param.ndim == 1: - param.data.zero_() - elif "wq" in name or "wk" in name or "wv" in name: - self.init_func(std=self.attn_wqkv_init_std)(param.data) - elif self.use_scaled_init: # wo - self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.attn_other_init_std)(param.data) - - for name, param in self.feed_forward.named_parameters(): - if self.use_swiglu: - if self.use_scaled_init and "w2" in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - # candidate: w1, w3, fused_w1_w3 - self.init_func( - std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std - )(param.data) - else: - if self.use_scaled_init and "fc1" not in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( - param.data - ) - - def forward(self, hidden_states, residual=None, **kwargs): - if self.checkpoint and self.training: - args = convert_attn_kwargs_to_args(kwargs) - return activation_checkpoint(self._forward, False, hidden_states, residual, *args) - else: - return self._forward(hidden_states, residual, **kwargs) - - def _forward(self, hidden_states, residual, *args, **kwargs): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Attn/MLP(LN(residual)) - cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 - indexes: the length of index is same as hidden states, which stand for the current position - """ - if self.prenorm: - - def _dropout_and_norm_attn(_residual, _hidden_states): - _dropped = self.dropout1(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.attention(hidden_states, **mixer_kwargs) - - if not isinstance(self.feed_forward, nn.Identity): - if not self.fused_dropout_add_ln: - - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint( - _dropout_and_norm_ffn, False, residual, hidden_states - ) - else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - hidden_states = self.feed_forward(hidden_states) - - return hidden_states + residual - else: - assert residual is None - - mixer_out = self.attention(hidden_states, 
**kwargs) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( - dtype=self.attention_norm.weight.dtype - ) - if not isinstance(self.feed_forward, nn.Identity): - mlp_out = self.feed_forward(hidden_states) - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( - dtype=self.ffn_norm.weight.dtype - ) - return hidden_states - - -class Qwen2(BaseTransformerModel): - """ - 1D Packed Flash Qwen. - - Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
- """ - - def __init__( - self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - vocab_size: int = 50304, - mlp_ratio: int = 4, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - checkpoint: float = 1.0, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm=False, - qkv_bias=True, - o_bias=False, - mlp_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - extra_pred_tokens: int = 0, - rope_type: str = "normal", - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - max_window_layers: int = 0, - sliding_window: int = None, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - scale_attn_weights: bool = False, # Qwen1 - use_logn_attn: bool = False, # Qwen1 - ): - super().__init__() - - self.embed_grad_scale = embed_grad_scale - - checkpoint_layer_num = int(num_layers * checkpoint) - - if first: - self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - for _, param in self.embed_tokens.named_parameters(): - if init_type == "normal": - normal_(std=embedding_init_std)(param) - else: - uniform_(std=embedding_init_std)(param) - - self.layers = nn.ModuleList( - [ - Qwen2Decoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - residual_in_fp32=residual_in_fp32, - device=device, - apply_post_layer_norm=apply_post_layer_norm, - fused_dropout_add_ln=False, - qkv_bias=qkv_bias, - o_bias=o_bias, - mlp_bias=mlp_bias, - norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - init_type=init_type, - rope_type=rope_type, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - use_sliding_window=use_sliding_window and lid >= max_window_layers, - sliding_window=sliding_window, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - max_position_embeddings=max_position_embeddings, - scale_attn_weights=scale_attn_weights, - use_logn_attn=use_logn_attn, - ) - for lid in range(num_layers) - ] - ) - - if last: - if not apply_post_layer_norm: - self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.output = new_linear( - 
name="output", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - - for _, param in self.output.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - if extra_pred_tokens > 0: - self.extra_pred_tokens = extra_pred_tokens - assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" - self.extra_outputs = nn.ModuleList( - [ - new_linear( - name="output", - in_features=hidden_size, - out_features=vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - for _ in range(self.extra_pred_tokens) - ] - ) - for _, param in self.extra_outputs.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, input_ids=None, **kwargs): - # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "embed_tokens"): - hidden_states = self.embed_tokens(input_ids) - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - - for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - **kwargs, - ) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) - if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: - extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] - else: - extra_hidden_states_list = None - if hasattr(self, "output"): - hidden_states = self.output(hidden_states) - - if extra_hidden_states_list is not None: - return (hidden_states, extra_hidden_states_list) - - return hidden_states - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - fns = get_fns(folder) - model_fns = [ - os.path.join(folder, fn) - for fn in fns - if (fn.endswith(".bin") and fn.startswith("pytorch_model")) - or (fn.endswith(".safetensors") and fn.startswith("model")) - ] - model_fns.sort() - - state_dict = {} - for model_fn in model_fns: - state_dict.update(llm_load(model_fn, map_location="cpu")) - - tp_size = gpc.get_world_size(ParallelMode.TENSOR) - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - wp_size = gpc.get_world_size(ParallelMode.WEIGHT) - wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) - tp_mode = gpc.config.parallel.tensor["mode"] - split_size = wp_size if tp_mode == "isp" else tp_size - local_rank = wp_rank if tp_mode == "isp" else tp_rank - row_dim = 0 if tp_mode == "isp" else 1 - if gpc.config.model.get("embed_split_hidden", True): - embed_concat_dim = 1 - else: - embed_concat_dim = 0 - - new_state_dict = {} - - # embedding - if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): - new_state_dict["embed_tokens.weight"] = torch.chunk( - state_dict.pop("model.embed_tokens.weight"), - split_size, - dim=embed_concat_dim, - )[local_rank] - - for idx, i in 
enumerate(range(model.first_layer, model.last_layer)): - layer_ids = i - - # attn - state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wq.bias"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.bias"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wk.bias"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.bias"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wv.bias"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.bias"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # ffn - state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), - split_size, - dim=0, - )[local_rank] - state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), - split_size, - dim=row_dim, - )[local_rank] - - # attn norm - state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.input_layernorm.weight" - ) - # ffn norm - state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( - f"model.layers.{layer_ids}.post_attention_layernorm.weight" - ) - - # replace value within decoder layer - for name in list(state_dict.keys()): - if name.startswith(f"layers.{i}"): - new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) - - # output - if gpc.is_last_rank(ParallelMode.PIPELINE): - new_state_dict["output.weight"] = torch.chunk( - state_dict.pop("lm_head.weight"), - split_size, - dim=0, - )[local_rank] - new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") - - missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - - internlm_accelerator.empty_cache() - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - model_config = gpc.config.model - tp_mode = gpc.config.parallel.tensor["mode"] - row_dim = 0 if tp_mode == "isp" else 1 - - # load states - states, num_shards = Qwen2.load_sharded_states(src) - - # convert state_dict - state_dict = {} - embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] - for layer_i in tqdm(range(model_config["num_layers"])): - # attn norm, mlp norm - state_dict.update( - { - f"model.layers.{layer_i}.input_layernorm.weight": 
states[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - ) - # attn wqkv weight and bias - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wq.bias"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wk.bias"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], - dim=0, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wv.bias"] for i in range(num_shards)], - dim=0, - ) - # attn wo weight - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim - ) - - # mlp - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 - ) - - # embedding, head - for embedding_key in embedding_key_list: - if embedding_key in states[0]: - break - if embedding_key is None: - raise KeyError("Cannot find embedding key!") - if model_config["embed_split_hidden"]: - embed_concat_dim = 1 - tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] - else: - embed_concat_dim = 0 - _, size_1 = states[0][embedding_key].shape - embdim_pertp = size_1 // num_shards - tok_emb_list = [ - torch.concat( - [ - states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] - for tp in range(num_shards) - ], - dim=0, - ) - for local_rank in range(num_shards) - ] - state_dict.update( - { - "model.norm.weight": states[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), - "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), - }, - ) - - # save state_dict to hf format - shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) - for shard_file, shard in shards.items(): - llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) - if index is not None: - # Save the index as well - llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/model_implementations/transformers/modeling_qwen2_moe.py b/internlm/model/model_implementations/transformers/modeling_qwen2_moe.py deleted file mode 100644 index ec3978bb1..000000000 --- 
a/internlm/model/model_implementations/transformers/modeling_qwen2_moe.py +++ /dev/null @@ -1,561 +0,0 @@ -# Copyright (c) InternLM. All rights reserved. -import math -from typing import Optional - -import torch -from torch import nn - -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.model.model_implementations.transformers.base_model import ( - BaseTransformerModel, -) -from internlm.model.model_implementations.transformers.utils import ( - normal_, - scaled_init_method_normal, - scaled_init_method_uniform, - uniform_, -) -from internlm.model.model_ops.modules.embedding import Embedding1D -from internlm.model.model_ops.modules.linear import new_linear -from internlm.model.model_ops.modules.mha import SWA -from internlm.model.model_ops.modules.mlp import new_feed_forward -from internlm.model.model_ops.modules.norm import new_layer_norm -from internlm.model.model_ops.moe.moe import Qwen2MoE -from internlm.model.model_ops.utils import ( - convert_attn_args_to_kwargs, - convert_attn_kwargs_to_args, -) -from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.utils.logger import get_logger - -internlm_accelerator = get_accelerator() -logger = get_logger(__file__) - - -class Qwen2MoeDecoder(nn.Module): - """ - 1D Packed Flash Qwen Layer. - - Args: - hidden_size (int): The hidden size of model. 768 by default. - num_attention_heads (int): The number of attention heads. 12 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. - dtype (torch.dtype): Type of data. torch.float by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - layer_idx (int): The index of current layer. 0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. 
- """ - - def __init__( - self, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - mlp_ratio: int = 4, - attn_drop_rate: float = 0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - layer_norm_epsilon: float = 1e-6, - checkpoint: bool = False, - layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - residual_in_fp32: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm: bool = False, - fused_dropout_add_ln: bool = True, - qkv_bias=True, - o_bias=False, - mlp_bias=False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - init_type: str = "normal", - rope_type: str = "normal", - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - sliding_window: int = None, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - scale_attn_weights: bool = False, # Qwen1 - use_logn_attn: bool = False, # Qwen1 - num_experts: int = 1, - top_k: int = 1, - num_shared_experts: int = 0, - moe_layer_kwargs: dict = None, - ): - super().__init__() - self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. - self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False - self.layer_idx = layer_idx - self.prenorm = not apply_post_layer_norm - assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" - self.fused_dropout_add_ln = fused_dropout_add_ln - self.attn_wqkv_init_std = attn_wqkv_init_std - self.attn_other_init_std = attn_other_init_std - self.ffn_uplayer_init_std = ffn_uplayer_init_std - self.ffn_other_init_std = ffn_other_init_std - - head_dim = hidden_size // num_attention_heads - - if scale_attn_weights: - softmax_scale = None - else: - softmax_scale = 1 / math.sqrt(head_dim) - self.attention = SWA( - embed_dim=hidden_size, - num_heads=num_attention_heads, - num_kv_heads=num_kv_attention_heads, - dropout=attn_drop_rate, - max_position_embeddings=max_position_embeddings, - softmax_scale=softmax_scale, - causal=True, - layer_idx=layer_idx, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - rotary_emb_dim=head_dim, - rotary_emb_scale_base=0, - device=device, - dtype=dtype, - qk_interleaved=qk_interleaved, - qkv_bias=qkv_bias, - o_bias=o_bias, - rope_type=rope_type, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - use_sliding_window=use_sliding_window, - sliding_window=sliding_window, - use_logn_attn=use_logn_attn, - ) - - self.dropout1 = nn.Dropout(drop_rate) - self.dropout2 = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.num_experts = num_experts - if num_experts <= 1: # dense, not MoE - self.feed_forward = new_feed_forward( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - bias=mlp_bias, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - activation_type="swiglu" if use_swiglu else "gelu", - ) - else: - # replace mlp by MoE module. The expert in MoE is a FeedForward module. 
- # mlp_cls = get_mlp_cls(self.tp_mode) - self.feed_forward = Qwen2MoE( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - num_experts=num_experts, - top_k=top_k, - num_shared_experts=num_shared_experts, - moe_layer_kwargs=moe_layer_kwargs, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - activation_type="swiglu" if use_swiglu else "gelu", - ) - - self.use_swiglu = use_swiglu - self.use_scaled_init = use_scaled_init - self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm - self.return_residual = False - - if init_type == "normal": - self.init_func = normal_ - self.scaled_init_func = scaled_init_method_normal - else: - self.init_func = uniform_ - self.scaled_init_func = scaled_init_method_uniform - - self.reset_parameters() - - def reset_parameters(self): - with torch.no_grad(): - for name, param in self.attention.named_parameters(): - if param.ndim == 1: - param.data.zero_() - elif "wq" in name or "wk" in name or "wv" in name: - self.init_func(std=self.attn_wqkv_init_std)(param.data) - elif self.use_scaled_init: # wo - self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.attn_other_init_std)(param.data) - - for name, param in self.feed_forward.named_parameters(): - if self.use_swiglu: - if self.use_scaled_init and "w2" in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - # candidate: w1, w3, fused_w1_w3 - self.init_func( - std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std - )(param.data) - else: - if self.use_scaled_init and "fc1" not in name: - self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) - else: - self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( - param.data - ) - - def forward(self, hidden_states, residual=None, **kwargs): - if self.checkpoint and self.training: - args = convert_attn_kwargs_to_args(kwargs) - return activation_checkpoint(self._forward, False, hidden_states, residual, *args) - else: - return self._forward(hidden_states, residual, **kwargs) - - def _forward(self, hidden_states, residual, *args, **kwargs): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Attn/MLP(LN(residual)) - cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 - indexes: the length of index is same as hidden states, which stand for the current position - """ - if self.prenorm: - - def _dropout_and_norm_attn(_residual, _hidden_states): - _dropped = self.dropout1(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) - else: - residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.attention(hidden_states, **mixer_kwargs) - - if not isinstance(self.feed_forward, nn.Identity): - if not self.fused_dropout_add_ln: - - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) - - return _residual, _hidden_states - - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint( - _dropout_and_norm_ffn, False, residual, hidden_states - ) - else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) - - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - if self.num_experts <= 1: # dense mlp output - hidden_states = self.feed_forward(hidden_states) - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) - else: # MoE output - hidden_states, moe_loss, _ = self.feed_forward(hidden_states) - - return hidden_states + residual, moe_loss - else: - assert residual is None - - mixer_out = self.attention(hidden_states, **kwargs) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( - dtype=self.attention_norm.weight.dtype - ) - if not isinstance(self.feed_forward, nn.Identity): - if self.num_experts <= 1: # dense mlp output - mlp_out = self.feed_forward(hidden_states) - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) - else: # MoE output - mlp_out, moe_loss, _ = self.feed_forward(hidden_states) - - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( - dtype=self.ffn_norm.weight.dtype - ) - return hidden_states, moe_loss - - -class Qwen2Moe(BaseTransformerModel): - """ - 1D Packed Flash Qwen. - - Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - mlp_ratio (int): The ratio of MLP layers. 4 by default. - attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. - drop_rate (float): The dropout rate of input hidden state. 0.0 by default. - dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. 
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. - first (bool): Whether input embedding layer or not. False by default. - last (bool): Whether output embedding layer or not. False by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - start_layer_idx (int): The index of start layer in the pipeline. 0 by default. - device (Optional[Union[str, torch.device]]): The device will be used. None by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. - """ - - def __init__( - self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, - vocab_size: int = 50304, - mlp_ratio: int = 4, - attn_drop_rate: float = 0.0, - drop_rate: float = 0.0, - max_position_embeddings: int = 2048, - dtype: torch.dtype = torch.float, - checkpoint: float = 1.0, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False, - embed_grad_scale: float = 0.1, - parallel_output: bool = True, - start_layer_idx: int = 0, - use_dynamic_ntk_rope: bool = False, - device: Optional[torch.device] = None, - apply_post_layer_norm=False, - qkv_bias=True, - o_bias=False, - mlp_bias=False, - residual_in_fp32: bool = False, - norm_type: str = "rmsnorm", - qk_interleaved: bool = False, - is_reward: bool = False, - dropout_selective_checkpoint: bool = True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - extra_pred_tokens: int = 0, - rope_type: str = "normal", - rope_base: int = 10000, - rope_scaling_factor: float = 1.0, - use_sliding_window: bool = False, - max_window_layers: int = 0, - sliding_window: int = None, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - scale_attn_weights: bool = False, # Qwen1 - use_logn_attn: bool = False, # Qwen1 - moe_type: str = None, # pylint: disable=W0613 - num_experts: bool = 1, - top_k: int = 1, - num_shared_experts: int = 0, - moe_layer_kwargs: dict = None, - ): - super().__init__() - - self.embed_grad_scale = embed_grad_scale - - checkpoint_layer_num = int(num_layers * checkpoint) - - if first: - self.embed_tokens = 
Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - for _, param in self.embed_tokens.named_parameters(): - if init_type == "normal": - normal_(std=embedding_init_std)(param) - else: - uniform_(std=embedding_init_std)(param) - - self.layers = nn.ModuleList( - [ - Qwen2MoeDecoder( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads, - mlp_ratio=mlp_ratio, - attn_drop_rate=attn_drop_rate, - drop_rate=drop_rate, - dtype=dtype, - layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, - layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - residual_in_fp32=residual_in_fp32, - device=device, - apply_post_layer_norm=apply_post_layer_norm, - fused_dropout_add_ln=False, - qkv_bias=qkv_bias, - o_bias=o_bias, - mlp_bias=mlp_bias, - norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - qk_interleaved=qk_interleaved, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - init_type=init_type, - rope_type=rope_type, - rope_base=rope_base, - rope_scaling_factor=rope_scaling_factor, - use_sliding_window=use_sliding_window and lid >= max_window_layers, - sliding_window=sliding_window, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - max_position_embeddings=max_position_embeddings, - scale_attn_weights=scale_attn_weights, - use_logn_attn=use_logn_attn, - num_experts=num_experts, - top_k=top_k, - num_shared_experts=num_shared_experts, - moe_layer_kwargs=moe_layer_kwargs, - ) - for lid in range(num_layers) - ] - ) - - if last: - if not apply_post_layer_norm: - self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - - self.output = new_linear( - name="output", - in_features=hidden_size, - out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - - for _, param in self.output.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - if extra_pred_tokens > 0: - self.extra_pred_tokens = extra_pred_tokens - assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" - self.extra_outputs = nn.ModuleList( - [ - new_linear( - name="output", - in_features=hidden_size, - out_features=vocab_size, - bias=False, - device=device, - dtype=dtype, - is_reward=is_reward, - weight_scale=embed_grad_scale, - ) - for _ in range(self.extra_pred_tokens) - ] - ) - for _, param in self.extra_outputs.named_parameters(): - if init_type == "normal": - normal_(std=out_head_init_std)(param) - else: - uniform_(std=out_head_init_std)(param) - - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, input_ids=None, **kwargs): - # attention_mask: compute attention on the places where the value is 1 - # old condition may fail when use shared embedding - if gpc.is_pipeline_first_stage() and input_ids is not None: - hidden_states = self.embed_tokens(input_ids) - if self.embed_grad_scale != 1: - hidden_states = ( - self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() - ) - - moe_losses = [] - for _, 
block in enumerate(self.layers): - hidden_states, moe_loss = block( - hidden_states, - residual=None, - **kwargs, - ) - moe_losses.append(moe_loss) - - if hasattr(self, "norm"): - hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) - if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: - extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] - else: - extra_hidden_states_list = None - if hasattr(self, "output"): - hidden_states = self.output(hidden_states) - - if extra_hidden_states_list is not None: - return (hidden_states, extra_hidden_states_list), moe_losses - - return hidden_states, moe_losses - - @staticmethod - def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError - - @staticmethod - def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError diff --git a/internlm/model/model_ops/llava/__init__.py b/internlm/model/model_ops/llava/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/internlm/model/model_ops/llava/clip_builder.py b/internlm/model/model_ops/llava/clip_builder.py deleted file mode 100644 index 78cc3fa0e..000000000 --- a/internlm/model/model_ops/llava/clip_builder.py +++ /dev/null @@ -1,13 +0,0 @@ -import os - -from .clip_encoder import CLIPVisionTower - - -def build_vision_tower(vision_tower_cfg, **kwargs): - vision_tower = vision_tower_cfg.get("mm_vision_tower", None) - is_absolute_path_exists = os.path.exists(vision_tower) - if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): - model = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) - return model - - raise ValueError(f"Unknown vision tower: {vision_tower}") diff --git a/internlm/model/model_ops/llava/clip_encoder.py b/internlm/model/model_ops/llava/clip_encoder.py deleted file mode 100644 index e1d982f72..000000000 --- a/internlm/model/model_ops/llava/clip_encoder.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from torch import nn - -from transformers import CLIPVisionConfig, CLIPVisionModel - - -class CLIPVisionTower(nn.Module): # pylint: disable=C0115 - def __init__(self, vision_tower, args, delay_load=False): - super().__init__() - - self.is_loaded = False - - self.vision_tower_name = vision_tower - self.select_layer = args.get("mm_vision_select_layer", -2) - self.select_feature = args.get("mm_vision_select_feature", "patch") - - if not delay_load: - self.load_model() - self.image_size = self.config.image_size - else: - self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) - - def load_model(self): - self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) - self.vision_tower.requires_grad_(False) - - self.is_loaded = True - - def feature_select(self, image_forward_outs): - image_features = image_forward_outs.hidden_states[self.select_layer] - if self.select_feature == "patch": - image_features = image_features[:, 1:] - elif self.select_feature == "cls_patch": - pass - else: - raise ValueError(f"Unexpected select feature: {self.select_feature}") - return image_features - - @torch.no_grad() - def forward(self, images): - if isinstance(images, list): - image_features = [] - for image in images: - image_forward_out = self.vision_tower( - image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True - ) - image_feature = self.feature_select(image_forward_out).to(image.dtype) - image_features.append(image_feature) - else: - 
image_forward_outs = self.vision_tower( - images.to(device=self.device, dtype=self.dtype), output_hidden_states=True - ) - image_features = self.feature_select(image_forward_outs).to(images.dtype) - - return image_features - - @property - def dummy_feature(self): - return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) - - @property - def dtype(self): - return self.vision_tower.dtype - - @property - def device(self): - return self.vision_tower.device - - @property - def config(self): - if self.is_loaded: - return self.vision_tower.config - else: - return self.cfg_only - - @property - def hidden_size(self): - return self.config.hidden_size - - @property - def num_patches(self): - return (self.config.image_size // self.config.patch_size) ** 2 diff --git a/internlm/model/model_ops/llava/projector_builder.py b/internlm/model/model_ops/llava/projector_builder.py deleted file mode 100644 index 2b1a701e3..000000000 --- a/internlm/model/model_ops/llava/projector_builder.py +++ /dev/null @@ -1,48 +0,0 @@ -import re - -from torch import nn - - -class IdentityMap(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x - - @property - def config(self): - return {"mm_projector_type": "identity"} - - -class SimpleResBlock(nn.Module): - def __init__(self, channels): - super().__init__() - self.pre_norm = nn.LayerNorm(channels) - - self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) - - def forward(self, x): - x = self.pre_norm(x) - return x + self.proj(x) - - -def build_vision_projector(config): - projector_type = config.get("mm_projector_type", "linear") - - if projector_type == "linear": - return nn.Linear(config.get("mm_hidden_size", 1024), config.get("hidden_size", 4096)) - - mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) - if mlp_gelu_match: - mlp_depth = int(mlp_gelu_match.group(1)) - modules = [nn.Linear(config.get("mm_hidden_size", 1024), config.get("hidden_size", 4096))] - for _ in range(1, mlp_depth): - modules.append(nn.GELU()) - modules.append(nn.Linear(config.get("hidden_size", 4096), config.get("hidden_size", 4096))) - return nn.Sequential(*modules) - - if projector_type == "identity": - return IdentityMap() - - raise ValueError(f"Unknown projector type: {projector_type}") diff --git a/tests/common_fixture.py b/tests/common_fixture.py index fbd3763e9..0d7fd95dc 100644 --- a/tests/common_fixture.py +++ b/tests/common_fixture.py @@ -40,14 +40,12 @@ model=dict( checkpoint=False, num_attention_heads=32, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=True, hidden_size=4096, num_layers=32, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py index 1d67b100d..ad8f36a3a 100644 --- a/tests/test_infer/test_generate.py +++ b/tests/test_infer/test_generate.py @@ -25,7 +25,6 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""): model_cfg = os.path.join(path, "model_config.pt") model_wt = os.path.join(path, "model_tp0_pp0.pt") model_config = torch.load(model_cfg) - model_config["apply_post_layer_norm"] = False if model_config.get("adapt_hf") is not None: model_config.pop("adapt_hf") evo_cfg = dict( diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index eaeff0ebf..084702d59 100644 --- a/tests/test_model/test_model_internlm.py +++ 
b/tests/test_model/test_model_internlm.py @@ -54,14 +54,12 @@ model=dict( checkpoint=False, num_attention_heads=2, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=True, hidden_size=1024, num_layers=2, mlp_ratio=1, - apply_post_layer_norm=False, dtype=torch.bfloat16, norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_training/7B_check_acc.py b/tests/test_training/7B_check_acc.py index 70b612c1d..eb8d32705 100644 --- a/tests/test_training/7B_check_acc.py +++ b/tests/test_training/7B_check_acc.py @@ -128,7 +128,6 @@ checkpoint=False, num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -136,7 +135,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_training/7B_check_init.py b/tests/test_training/7B_check_init.py index 27794dd02..9097a47c6 100644 --- a/tests/test_training/7B_check_init.py +++ b/tests/test_training/7B_check_init.py @@ -133,7 +133,6 @@ checkpoint=False, num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, @@ -141,7 +140,6 @@ num_layers=NUM_LAYER, no_bias=True, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 3795f59bc..f76cdcd52 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -54,14 +54,12 @@ model=dict( checkpoint=True, num_attention_heads=32, - embed_split_hidden=True, vocab_size=92544, embed_grad_scale=1, parallel_output=False, hidden_size=4096, num_layers=32, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index 6e705ac11..7a554622e 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -97,14 +97,12 @@ model=dict( checkpoint=False, num_attention_heads=16, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=True, hidden_size=4096, num_layers=16, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index fe0ea0d5f..5780cf609 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -63,14 +63,12 @@ model=dict( checkpoint=False, num_attention_heads=16, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=True, hidden_size=4096, num_layers=16, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 0a3a00d59..d0fe61b13 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -67,14 +67,12 @@ model=dict( checkpoint=False, num_attention_heads=2, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=True, 
hidden_size=1024, num_layers=2, mlp_ratio=1, - apply_post_layer_norm=False, dtype=torch.bfloat16, norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/tools/README.md b/tools/README.md index a24040cae..9861ce39f 100644 --- a/tools/README.md +++ b/tools/README.md @@ -160,7 +160,6 @@ LLaMA 7B推理的例子: num_chunks=1, checkpoint=0.2, dtype="torch.bfloat16", - embed_split_hidden=True, num_layers=32, hidden_size=4096, vocab_size=32000, @@ -171,7 +170,6 @@ LLaMA 7B推理的例子: mlp_ratio=2.675, use_flash_attn=True, norm_type="rmsnorm", - apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, ), diff --git a/tools/load_internlm2_model.py b/tools/load_internlm2_model.py index aa3dcb636..121c0842e 100644 --- a/tools/load_internlm2_model.py +++ b/tools/load_internlm2_model.py @@ -286,7 +286,6 @@ def get_default_parser(): num_chunks=1, checkpoint=0.2, dtype="torch.bfloat16", - embed_split_hidden=True, num_layers=32, hidden_size=4096, vocab_size=92544, @@ -298,7 +297,6 @@ def get_default_parser(): use_flash_attn=True, norm_type="rmsnorm", qk_interleaved=True, - apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, rope_base=1000000, diff --git a/web_demo_internlm.py b/web_demo_internlm.py index abe0568e7..b89e2ae24 100644 --- a/web_demo_internlm.py +++ b/web_demo_internlm.py @@ -21,14 +21,12 @@ "internlm-chat-7b": dict( checkpoint=False, num_attention_heads=32, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=False, hidden_size=4096, num_layers=32, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, @@ -39,14 +37,12 @@ "internlm-chat-7b-v1.1": dict( checkpoint=False, num_attention_heads=32, - embed_split_hidden=True, vocab_size=103168, embed_grad_scale=1, parallel_output=False, hidden_size=4096, num_layers=32, mlp_ratio=8 / 3, - apply_post_layer_norm=False, dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, From 6498236aee3a0dd327550e34220259a24416b297 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Mon, 3 Mar 2025 12:12:32 +0800 Subject: [PATCH 10/32] fix pylint --- internlm/checkpoint/load_funcs.py | 9 +- internlm/core/parallel/comm/zero.py | 4 +- internlm/initialize/initialize_model.py | 6 +- .../transformers/modeling_internlm.py | 2 - .../transformers/modeling_internlm2.py | 193 ----------------- .../transformers/modeling_llama.py | 57 ----- .../solver/optimizer/hybrid_zero_optim.py | 2 +- .../solver/optimizer/hybrid_zero_optim_v2.py | 6 +- tests/test_training/test_loss.py | 195 +++++++++++++++++- 9 files changed, 206 insertions(+), 268 deletions(-) diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index 28f4e06c8..13342afcb 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -1,14 +1,7 @@ # Copyright (c) InternLM. All rights reserved. 
-from internlm.model.model_implementations.transformers.modeling_internlm2 import ( - InternLM2, -) -from internlm.model.model_implementations.transformers.modeling_llama import Llama2 from internlm.utils.logger import get_logger logger = get_logger(__file__) -LOAD_FUNC_DICT = { - "llama": Llama2.load_llama_pretrained_weights, - "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size, -} +LOAD_FUNC_DICT = {} diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py index f33056778..3aa8f8a35 100644 --- a/internlm/core/parallel/comm/zero.py +++ b/internlm/core/parallel/comm/zero.py @@ -159,7 +159,9 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06 for working_param, all_splited_param in zip( self._block_working_params[block_name], all_splited_param_list ): - working_param.data.copy_(_flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)) + working_param.data.copy_( + _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param) + ) self._block_allgather_handles[block_name] = None self._block_gathered_params[block_name] = [] diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py index 352666dcc..51272afc0 100644 --- a/internlm/initialize/initialize_model.py +++ b/internlm/initialize/initialize_model.py @@ -93,7 +93,7 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]): - def _check_module(name, module): + def _check_module(module): # layer_norm if isinstance(module, (RMSNorm, nn.LayerNorm)): for param in module.parameters(): @@ -142,8 +142,8 @@ def _check_module(name, module): for _chunk in unwrap_naive_amp(model): if not is_using_fsdp(): # set param parallel attribute - for name, module in _chunk.named_modules(): - _check_module(name, module) + for _, module in _chunk.named_modules(): + _check_module(module) for name, param in _chunk.named_parameters(): assert ( diff --git a/internlm/model/model_implementations/transformers/modeling_internlm.py b/internlm/model/model_implementations/transformers/modeling_internlm.py index dfcd2eec8..22de76f1a 100644 --- a/internlm/model/model_implementations/transformers/modeling_internlm.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm.py @@ -13,7 +13,6 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module -from internlm.core.parallel.shard import partition_uniform from internlm.model.model_implementations.transformers.base_model import ( BaseTransformerModel, ) @@ -522,7 +521,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: internlm_accelerator.empty_cache() - @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: model_config = gpc.config.model diff --git a/internlm/model/model_implementations/transformers/modeling_internlm2.py b/internlm/model/model_implementations/transformers/modeling_internlm2.py index 26854527a..875ecc9ba 100644 --- a/internlm/model/model_implementations/transformers/modeling_internlm2.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm2.py @@ -2,7 +2,6 @@ import math import os from contextlib import nullcontext -from functools import reduce from typing import Optional import torch @@ -13,7 +12,6 @@ from internlm.core.context import ParallelMode from 
internlm.core.context import global_context as gpc from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context -from internlm.core.parallel.shard import partition_uniform from internlm.model.model_implementations.transformers.base_model import ( BaseTransformerModel, ) @@ -31,7 +29,6 @@ from internlm.model.model_ops.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, - get_parallel_size_from_file, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -636,196 +633,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: internlm_accelerator.empty_cache() - @staticmethod - def load_internlm2_with_dynamic_parallel_size(folder, model): - """Load InternLM2 with dynamic parallel size.""" - assert folder is not None, "Please specify the folder of the pretrained model" - assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2" - - fns = get_fns(folder) - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612 - - tp = gpc.get_world_size(ParallelMode.TENSOR) - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - assert old_tp % tp == 0 or tp % old_tp == 0, ( - f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in " - f"checkpoint and {tp} in current config" - ) - - correspond_tps = [] - - if old_tp <= tp: - correspond_tps.append(tp_rank // (tp // old_tp)) - ratio = tp // old_tp - rank = tp_rank % ratio - else: - for i in range(old_tp // tp): - correspond_tps.append(tp_rank * (old_tp // tp) + i) - rank = 0 - ratio = 1 - - current_states = {} - - pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612 - - assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary" - - old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1) - - for idx, parts in enumerate(old_pp_partition): - start, end = parts[0] - if model.last_layer <= start or model.first_layer >= end: - continue - tmp_states = {} - - for correspond_tp in correspond_tps: - model_name = f"model_tp{correspond_tp}_pp{idx}.pt" - states = llm_load(os.path.join(folder, model_name), map_location="cpu") - states = {k.replace("model.", ""): v for k, v in states.items()} - for i in range(start, end): - if i >= model.last_layer: - break - if i < model.first_layer: - continue - - for name in list(states.keys()): - if f".{i-start}." in name: - to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.") - - if gpc.config.model_type == "INTERNLM2": - if "norm" in name: - tmp_states[to_name] = [states.pop(name)] - elif any(x in name for x in ("wo", "w2")): - tmp_states[to_name] = tmp_states.get(to_name, []) - tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank]) - elif any(x in name for x in ("w1", "w3")): - tmp_states[to_name] = tmp_states.get(to_name, []) - tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) - elif any(x in name for x in ("wqkv",)): - tmp_states[to_name] = tmp_states.get(to_name, []) - if tp > gpc.config.model.num_kv_attention_heads: - assert old_tp <= gpc.config.model.num_kv_attention_heads, ( - f"`old_tp ({old_tp}) => tp ({tp})` is not supported. 
" - "At least one of `tp` and `old_tp` should be less than or " - "equal to `num_kv_attention_heads`" - ) - # Suitable for cases where the num_kv_attention_head is small, - # but you want to have a large TP Size - q_per_kv = ( - gpc.config.model.num_attention_heads - // gpc.config.model.num_kv_attention_heads - ) - head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads - index = torch.concat( - ( - torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio], - torch.tensor([q_per_kv, q_per_kv + 1]), - ) - ) - index = index + (q_per_kv + 2) * (tp_rank // ratio) - index = index % ( - (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp) - ) - index = index * head_dim - index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( - index.shape[0] - ) - tmp_states[to_name].append( - torch.index_select(states.pop(name), 0, index.to(torch.int32)) - ) - else: - tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) - else: - raise KeyError(f"Unknown key {name}.") - - else: - assert False, "unsupported model type" - - if "tok_embeddings.weight" in states and model.first_layer == 0: - tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", []) - tmp_states["tok_embeddings.weight"].append( - states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank] - ) - if "output.weight" in states and model.last_layer == gpc.config.model.num_layers: - tmp_states["norm.weight"] = [states["norm.weight"]] - tmp_states["output.weight"] = tmp_states.get("output.weight", []) - tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank]) - - states = {} - - for name in list(tmp_states.keys()): - data = tmp_states.pop(name) - if len(data) == 1: - current_states[name] = data[0] - else: - current_states[name] = torch.concat( - data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0 - ) - # Merge copied kv heads - if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads: - assert ( - tp <= gpc.config.model.num_kv_attention_heads - ), "new_tp should be less than or equal to num_kv_attention_heads" - head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads - q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads - copied_times = old_tp // gpc.config.model.num_kv_attention_heads - cur_q_per_kv = q_per_kv // copied_times - - # pylint: disable=all - def duplicate_kv_index(i): - if i % (cur_q_per_kv + 2) >= cur_q_per_kv: - return i - else: - return -100 - - def unique_kv_index(i): - if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv: - return i - else: - return -100 - - # pylint: enable=all - - # Verify - duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] - duplicate_index = [i for i in duplicate_index if i != -100] - duplicate_index = _duplicate_index = torch.tensor(duplicate_index) - for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): - duplicate_index = torch.concat( - (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0 - ) - duplicate_kv = [] - for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1): - index = index.reshape(-1) * head_dim - index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0]) - duplicate_kv.append(torch.index_select(current_states[name], 0, index)) - assert reduce( - lambda x, y: x and y, - 
[torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]], - ), "Copied kv heads are not equal after training!" - - # Merge - unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] - unique_index = [i for i in unique_index if i != -100] - unique_index = _unique_index = torch.tensor(unique_index) - for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): - unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0) - unique_index = unique_index * head_dim - unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( - unique_index.shape[0] - ) - current_states[name] = torch.index_select(current_states[name], 0, unique_index) - missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: model_config = gpc.config.model diff --git a/internlm/model/model_implementations/transformers/modeling_llama.py b/internlm/model/model_implementations/transformers/modeling_llama.py index 03dec8b1b..69c71ac80 100644 --- a/internlm/model/model_implementations/transformers/modeling_llama.py +++ b/internlm/model/model_implementations/transformers/modeling_llama.py @@ -586,63 +586,6 @@ def load_hf_weights(folder: str, model: nn.Module): internlm_accelerator.empty_cache() - @staticmethod - def load_llama_pretrained_weights(folder: str, model: nn.Module) -> None: - """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config.""" - """NOTE: specified for meta-llama/Llama-2-7b""" - assert folder is not None, "Please specify the folder of the pretrained model" - if gpc.is_rank_for_log(): - logger.info(f"Loading pretrained model from {folder}") - - fns = get_fns(folder) - model_fns = [] - for fn in fns: - if fn.startswith("model_t") and not fn.endswith("md5"): - model_fns.append(os.path.join(folder, fn)) - - if len(model_fns) == 0: - model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")] - - if len(model_fns) == 0: - raise FileNotFoundError(f"No checkpoint file found in {folder}") - - model_fns.sort() - - old_tp = len(model_fns) - cur_tp = gpc.get_world_size(ParallelMode.TENSOR) - # If the two tp are inconsistent, you need to consider the merge before splitting - if old_tp != cur_tp: - raise RuntimeError( - f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first" - ) - - states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu") - - current_states = {} - for idx, i in enumerate(range(model.first_layer, model.last_layer)): - for name in list(states.keys()): - if f".{i}." 
in name: - current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) - - model_state_keys = set(list(model.state_dict().keys())) - - if "tok_embeddings.weight" in model_state_keys: - current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"] - assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}" - if "output.weight" in model_state_keys: - current_states["norm.weight"] = states["norm.weight"] - current_states["output.weight"] = states["output.weight"] - missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) - - if gpc.get_local_rank(ParallelMode.DATA) == 0: - pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) - logger.info( - f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " - f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" - ) - - internlm_accelerator.empty_cache() - @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: model_config = gpc.config.model diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 8edbcf849..313319a48 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -8,8 +8,8 @@ import torch import torch.distributed as dist -from torch.optim import Optimizer from torch._utils import _flatten_dense_tensors +from torch.optim import Optimizer from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ( diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index 38a1c8b8a..c167b53eb 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -5,8 +5,8 @@ import torch import torch.distributed as dist -from torch.optim import Optimizer from torch._utils import _flatten_dense_tensors +from torch.optim import Optimizer from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, @@ -670,7 +670,9 @@ def step(self, closure=None): # Update working parameters for working_param, all_splited_param in zip(working_params_list[gather_idx], all_splited_param_list): - working_param.data.copy_(_flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)) + working_param.data.copy_( + _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param) + ) for group_id in range(self.num_param_groups): self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index d13225ac3..03b998914 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -1,5 +1,6 @@ import math import os +from functools import reduce import pytest import torch @@ -7,8 +8,10 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint import CheckpointManager +from internlm.checkpoint.load_funcs import LOAD_FUNC_DICT from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import partition_uniform from internlm.core.trainer import ( Trainer, TrainState, @@ -22,10 +25,12 @@ ) from internlm.initialize.initialize_optimizer import initialize_optimizer from internlm.model.model_ops.losses import InternLoss +from 
internlm.model.model_ops.utils import get_parallel_size_from_file from internlm.utils.common import BatchSkipper, launch_time from internlm.utils.config import Config from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.storage_manager import get_fns, llm_load CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py") INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt") @@ -46,11 +51,199 @@ 4.799427032470703, ] - cur_loss_list = [] internlm_accelerator = get_accelerator() +def load_internlm2_with_dynamic_parallel_size(folder, model): + """Load InternLM2 with dynamic parallel size.""" + assert folder is not None, "Please specify the folder of the pretrained model" + assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2" + + fns = get_fns(folder) + model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612 + + tp = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + assert old_tp % tp == 0 or tp % old_tp == 0, ( + f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in " + f"checkpoint and {tp} in current config" + ) + + correspond_tps = [] + + if old_tp <= tp: + correspond_tps.append(tp_rank // (tp // old_tp)) + ratio = tp // old_tp + rank = tp_rank % ratio + else: + for i in range(old_tp // tp): + correspond_tps.append(tp_rank * (old_tp // tp) + i) + rank = 0 + ratio = 1 + + current_states = {} + + pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612 + + assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary" + + old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1) + + for idx, parts in enumerate(old_pp_partition): + start, end = parts[0] + if model.last_layer <= start or model.first_layer >= end: + continue + tmp_states = {} + + for correspond_tp in correspond_tps: + model_name = f"model_tp{correspond_tp}_pp{idx}.pt" + states = llm_load(os.path.join(folder, model_name), map_location="cpu") + states = {k.replace("model.", ""): v for k, v in states.items()} + for i in range(start, end): + if i >= model.last_layer: + break + if i < model.first_layer: + continue + + for name in list(states.keys()): + if f".{i-start}." in name: + to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.") + + if gpc.config.model_type == "INTERNLM2": + if "norm" in name: + tmp_states[to_name] = [states.pop(name)] + elif any(x in name for x in ("wo", "w2")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank]) + elif any(x in name for x in ("w1", "w3")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + elif any(x in name for x in ("wqkv",)): + tmp_states[to_name] = tmp_states.get(to_name, []) + if tp > gpc.config.model.num_kv_attention_heads: + assert old_tp <= gpc.config.model.num_kv_attention_heads, ( + f"`old_tp ({old_tp}) => tp ({tp})` is not supported. 
" + "At least one of `tp` and `old_tp` should be less than or " + "equal to `num_kv_attention_heads`" + ) + # Suitable for cases where the num_kv_attention_head is small, + # but you want to have a large TP Size + q_per_kv = ( + gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads + ) + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + index = torch.concat( + ( + torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio], + torch.tensor([q_per_kv, q_per_kv + 1]), + ) + ) + index = index + (q_per_kv + 2) * (tp_rank // ratio) + index = index % ( + (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp) + ) + index = index * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + index.shape[0] + ) + tmp_states[to_name].append( + torch.index_select(states.pop(name), 0, index.to(torch.int32)) + ) + else: + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + else: + raise KeyError(f"Unknown key {name}.") + + else: + assert False, "unsupported model type" + + if "tok_embeddings.weight" in states and model.first_layer == 0: + tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", []) + tmp_states["tok_embeddings.weight"].append(states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank]) + if "output.weight" in states and model.last_layer == gpc.config.model.num_layers: + tmp_states["norm.weight"] = [states["norm.weight"]] + tmp_states["output.weight"] = tmp_states.get("output.weight", []) + tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank]) + + states = {} + + for name in list(tmp_states.keys()): + data = tmp_states.pop(name) + if len(data) == 1: + current_states[name] = data[0] + else: + current_states[name] = torch.concat( + data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0 + ) + # Merge copied kv heads + if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads: + assert ( + tp <= gpc.config.model.num_kv_attention_heads + ), "new_tp should be less than or equal to num_kv_attention_heads" + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads + copied_times = old_tp // gpc.config.model.num_kv_attention_heads + cur_q_per_kv = q_per_kv // copied_times + + # pylint: disable=all + def duplicate_kv_index(i): + if i % (cur_q_per_kv + 2) >= cur_q_per_kv: + return i + else: + return -100 + + def unique_kv_index(i): + if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv: + return i + else: + return -100 + + # pylint: enable=all + + # Verify + duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + duplicate_index = [i for i in duplicate_index if i != -100] + duplicate_index = _duplicate_index = torch.tensor(duplicate_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + duplicate_index = torch.concat( + (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0 + ) + duplicate_kv = [] + for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1): + index = index.reshape(-1) * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0]) + duplicate_kv.append(torch.index_select(current_states[name], 0, index)) + assert reduce( + lambda x, y: x and y, + 
[torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]], + ), "Copied kv heads are not equal after training!" + + # Merge + unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + unique_index = [i for i in unique_index if i != -100] + unique_index = _unique_index = torch.tensor(unique_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0) + unique_index = unique_index * head_dim + unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + unique_index.shape[0] + ) + current_states[name] = torch.index_select(current_states[name], 0, unique_index) + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + print( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}", + flush=True, + ) + + +LOAD_FUNC_DICT["internlm2_test"] = load_internlm2_with_dynamic_parallel_size + + def train( dp_size: int = 1, tp_size: int = 1, From bdf5b0bf0c326016a61082b33147fa49856496af Mon Sep 17 00:00:00 2001 From: caizheng Date: Wed, 5 Mar 2025 14:10:43 +0800 Subject: [PATCH 11/32] fix merge --- internlm/solver/optimizer/hybrid_zero_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 75e92cca1..82f3a2427 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -288,7 +288,7 @@ def _partition_param_list(self, group_id, param_group): if group_id not in self.meta_for_zero[rank_to_go]: self.meta_for_zero[rank_to_go][group_id] = {} - from internlm.train.pipeline import map_fqn_local_to_global + from internlm.initialize.initialize_model import map_fqn_local_to_global global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn self.meta_for_zero[rank_to_go][group_id][global_fqn] = { From 05b58a90aaa3a55ba9fa4ebe5f6aec905fd5afea Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Wed, 5 Mar 2025 17:38:53 +0800 Subject: [PATCH 12/32] rename transformers to huggingface_models to avoid name conflict --- README-ja-JP.md | 4 ++-- README-zh-Hans.md | 4 ++-- README.md | 4 ++-- ci_scripts/model/convert_to_hf.sh | 2 +- {transformers => huggingface_models}/README-zh-Hans.md | 0 {transformers => huggingface_models}/README.md | 0 {transformers => huggingface_models}/convert2hf_internlm.py | 0 {transformers => huggingface_models}/convert2hf_internlm2.py | 0 .../convert2hf_internlm_moe.py | 0 .../internlm2_model/__init__.py | 0 .../internlm2_model/configuration_internlm2.py | 0 .../internlm2_model/modeling_internlm2.py | 0 .../internlm2_model/tokenization_internlm2.py | 0 .../internlm2_model/tokenization_internlm2_fast.py | 0 .../internlm_model/__init__.py | 0 .../internlm_model/configuration_internlm.py | 0 .../internlm_model/modeling_internlm.py | 0 .../internlm_model/tokenization_internlm.py | 0 .../internlm_moe_model/__init__.py | 0 .../internlm_moe_model/configuration_internlm_moe.py | 0 .../internlm_moe_model/modeling_internlm_moe.py | 0 .../internlm_moe_model/tokenization_internlm.py | 0 {transformers => huggingface_models}/revert_internlm.py | 0 
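With the built-in loaders stripped out of internlm/checkpoint/load_funcs.py, LOAD_FUNC_DICT now starts empty and callers register the loaders they need, as the relocated test helper does above. The sketch below shows that registration pattern; the key "my_custom" and the function body are illustrative placeholders rather than part of the patch, and how the registered key is selected at load time is not shown here.

    from torch import nn

    from internlm.checkpoint.load_funcs import LOAD_FUNC_DICT


    def load_my_custom_weights(folder: str, model: nn.Module) -> None:
        # Project-specific logic: read checkpoint shards from `folder` and copy
        # them into the locally sharded `model` (the test helper above is a full
        # example that also reshards across different TP/PP sizes).
        raise NotImplementedError


    # Register the loader under a custom key before checkpoint loading runs,
    # mirroring LOAD_FUNC_DICT["internlm2_test"] = ... in the test above.
    LOAD_FUNC_DICT["my_custom"] = load_my_custom_weights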
{transformers => huggingface_models}/revert_internlm2.py | 0 tools/tokenizer.py | 2 +- 25 files changed, 8 insertions(+), 8 deletions(-) rename {transformers => huggingface_models}/README-zh-Hans.md (100%) rename {transformers => huggingface_models}/README.md (100%) rename {transformers => huggingface_models}/convert2hf_internlm.py (100%) rename {transformers => huggingface_models}/convert2hf_internlm2.py (100%) rename {transformers => huggingface_models}/convert2hf_internlm_moe.py (100%) rename {transformers => huggingface_models}/internlm2_model/__init__.py (100%) rename {transformers => huggingface_models}/internlm2_model/configuration_internlm2.py (100%) rename {transformers => huggingface_models}/internlm2_model/modeling_internlm2.py (100%) rename {transformers => huggingface_models}/internlm2_model/tokenization_internlm2.py (100%) rename {transformers => huggingface_models}/internlm2_model/tokenization_internlm2_fast.py (100%) rename {transformers => huggingface_models}/internlm_model/__init__.py (100%) rename {transformers => huggingface_models}/internlm_model/configuration_internlm.py (100%) rename {transformers => huggingface_models}/internlm_model/modeling_internlm.py (100%) rename {transformers => huggingface_models}/internlm_model/tokenization_internlm.py (100%) rename {transformers => huggingface_models}/internlm_moe_model/__init__.py (100%) rename {transformers => huggingface_models}/internlm_moe_model/configuration_internlm_moe.py (100%) rename {transformers => huggingface_models}/internlm_moe_model/modeling_internlm_moe.py (100%) rename {transformers => huggingface_models}/internlm_moe_model/tokenization_internlm.py (100%) rename {transformers => huggingface_models}/revert_internlm.py (100%) rename {transformers => huggingface_models}/revert_internlm2.py (100%) diff --git a/README-ja-JP.md b/README-ja-JP.md index 3a2711611..18db395f3 100644 --- a/README-ja-JP.md +++ b/README-ja-JP.md @@ -166,8 +166,8 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - diff --git a/README-zh-Hans.md b/README-zh-Hans.md index d955d3d1e..6a5503077 100644 --- a/README-zh-Hans.md +++ b/README-zh-Hans.md @@ -166,8 +166,8 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - diff --git a/README.md b/README.md index 7f3247b40..90c700bcd 100644 --- a/README.md +++ b/README.md @@ -166,8 +166,8 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar diff --git a/ci_scripts/model/convert_to_hf.sh b/ci_scripts/model/convert_to_hf.sh index 3bf381c74..c0280be5d 100644 --- a/ci_scripts/model/convert_to_hf.sh +++ b/ci_scripts/model/convert_to_hf.sh @@ -25,7 +25,7 @@ if [[ -d ${CKPTS_OUTPUT} ]]; then fi fi -python ./transformers/convert2hf_internlm.py --src ${CKPTS_INPUT} --tgt ${CKPTS_OUTPUT} --tokenizer ./tools/tokenizer_internlm.model +python ./huggingface_models/convert2hf_internlm.py --src ${CKPTS_INPUT} --tgt ${CKPTS_OUTPUT} --tokenizer ./tools/tokenizer_internlm.model [[ $? 
-ne 0 ]] && { echo "test convert2hf_internlm.py failed."; exit_code=$(($exit_code + 1)); } #assert exists model diff --git a/transformers/README-zh-Hans.md b/huggingface_models/README-zh-Hans.md similarity index 100% rename from transformers/README-zh-Hans.md rename to huggingface_models/README-zh-Hans.md diff --git a/transformers/README.md b/huggingface_models/README.md similarity index 100% rename from transformers/README.md rename to huggingface_models/README.md diff --git a/transformers/convert2hf_internlm.py b/huggingface_models/convert2hf_internlm.py similarity index 100% rename from transformers/convert2hf_internlm.py rename to huggingface_models/convert2hf_internlm.py diff --git a/transformers/convert2hf_internlm2.py b/huggingface_models/convert2hf_internlm2.py similarity index 100% rename from transformers/convert2hf_internlm2.py rename to huggingface_models/convert2hf_internlm2.py diff --git a/transformers/convert2hf_internlm_moe.py b/huggingface_models/convert2hf_internlm_moe.py similarity index 100% rename from transformers/convert2hf_internlm_moe.py rename to huggingface_models/convert2hf_internlm_moe.py diff --git a/transformers/internlm2_model/__init__.py b/huggingface_models/internlm2_model/__init__.py similarity index 100% rename from transformers/internlm2_model/__init__.py rename to huggingface_models/internlm2_model/__init__.py diff --git a/transformers/internlm2_model/configuration_internlm2.py b/huggingface_models/internlm2_model/configuration_internlm2.py similarity index 100% rename from transformers/internlm2_model/configuration_internlm2.py rename to huggingface_models/internlm2_model/configuration_internlm2.py diff --git a/transformers/internlm2_model/modeling_internlm2.py b/huggingface_models/internlm2_model/modeling_internlm2.py similarity index 100% rename from transformers/internlm2_model/modeling_internlm2.py rename to huggingface_models/internlm2_model/modeling_internlm2.py diff --git a/transformers/internlm2_model/tokenization_internlm2.py b/huggingface_models/internlm2_model/tokenization_internlm2.py similarity index 100% rename from transformers/internlm2_model/tokenization_internlm2.py rename to huggingface_models/internlm2_model/tokenization_internlm2.py diff --git a/transformers/internlm2_model/tokenization_internlm2_fast.py b/huggingface_models/internlm2_model/tokenization_internlm2_fast.py similarity index 100% rename from transformers/internlm2_model/tokenization_internlm2_fast.py rename to huggingface_models/internlm2_model/tokenization_internlm2_fast.py diff --git a/transformers/internlm_model/__init__.py b/huggingface_models/internlm_model/__init__.py similarity index 100% rename from transformers/internlm_model/__init__.py rename to huggingface_models/internlm_model/__init__.py diff --git a/transformers/internlm_model/configuration_internlm.py b/huggingface_models/internlm_model/configuration_internlm.py similarity index 100% rename from transformers/internlm_model/configuration_internlm.py rename to huggingface_models/internlm_model/configuration_internlm.py diff --git a/transformers/internlm_model/modeling_internlm.py b/huggingface_models/internlm_model/modeling_internlm.py similarity index 100% rename from transformers/internlm_model/modeling_internlm.py rename to huggingface_models/internlm_model/modeling_internlm.py diff --git a/transformers/internlm_model/tokenization_internlm.py b/huggingface_models/internlm_model/tokenization_internlm.py similarity index 100% rename from transformers/internlm_model/tokenization_internlm.py rename to 
huggingface_models/internlm_model/tokenization_internlm.py diff --git a/transformers/internlm_moe_model/__init__.py b/huggingface_models/internlm_moe_model/__init__.py similarity index 100% rename from transformers/internlm_moe_model/__init__.py rename to huggingface_models/internlm_moe_model/__init__.py diff --git a/transformers/internlm_moe_model/configuration_internlm_moe.py b/huggingface_models/internlm_moe_model/configuration_internlm_moe.py similarity index 100% rename from transformers/internlm_moe_model/configuration_internlm_moe.py rename to huggingface_models/internlm_moe_model/configuration_internlm_moe.py diff --git a/transformers/internlm_moe_model/modeling_internlm_moe.py b/huggingface_models/internlm_moe_model/modeling_internlm_moe.py similarity index 100% rename from transformers/internlm_moe_model/modeling_internlm_moe.py rename to huggingface_models/internlm_moe_model/modeling_internlm_moe.py diff --git a/transformers/internlm_moe_model/tokenization_internlm.py b/huggingface_models/internlm_moe_model/tokenization_internlm.py similarity index 100% rename from transformers/internlm_moe_model/tokenization_internlm.py rename to huggingface_models/internlm_moe_model/tokenization_internlm.py diff --git a/transformers/revert_internlm.py b/huggingface_models/revert_internlm.py similarity index 100% rename from transformers/revert_internlm.py rename to huggingface_models/revert_internlm.py diff --git a/transformers/revert_internlm2.py b/huggingface_models/revert_internlm2.py similarity index 100% rename from transformers/revert_internlm2.py rename to huggingface_models/revert_internlm2.py diff --git a/tools/tokenizer.py b/tools/tokenizer.py index e67874f9d..8eba0c64a 100644 --- a/tools/tokenizer.py +++ b/tools/tokenizer.py @@ -7,7 +7,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) model_path = os.path.join(current_dir, "tokenizer_internlm.model") -sys.path.append(os.path.join(current_dir, "../transformers")) +sys.path.append(os.path.join(current_dir, "../huggingface_models")) from internlm_model import InternLMTokenizer # noqa: E402 # pylint: disable=C0413 tokenizer = InternLMTokenizer(vocab_file=model_path, add_bos_token=True, add_eos_token=True) From 4cd207e0f975238c950afa958b55e7186249ab19 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Wed, 5 Mar 2025 18:48:39 +0800 Subject: [PATCH 13/32] update args sanity checks and add support for FP8 --- internlm/core/engine.py | 31 ++++++++++- internlm/data/streaming/dataset.py | 2 +- internlm/initialize/initialize_launcher.py | 55 ++++++++++++++++++- .../model/model_implementations/builder.py | 36 ++++++++++++ .../transformers/modeling_internlm.py | 10 ++-- .../transformers/modeling_internlm2.py | 10 ++-- .../transformers/modeling_llama.py | 10 ++-- 7 files changed, 136 insertions(+), 18 deletions(-) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index 97cb41db0..f6de9aebf 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -3,6 +3,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine +from contextlib import nullcontext from typing import List, Optional import torch @@ -10,11 +11,20 @@ from torch.nn.modules.loss import _Loss from torch.optim.lr_scheduler import _LRScheduler +from internlm.core.context import global_context as gpc from internlm.core.gradient_handler import BaseGradientHandler from internlm.solver.optimizer import BaseOptimizer from internlm.solver.schedulers import Beta2Scheduler from internlm.utils.common import get_batch_size, 
move_to_device +try: + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import DelayedScaling, Format + + HAS_TE = True +except (ModuleNotFoundError, ImportError): + HAS_TE = False + class Engine: """ @@ -78,6 +88,23 @@ def __init__( # build gradient handler self._gradient_handlers = gradient_handlers if gradient_handlers else [] + # FP8 GEMM + fp8_cfg = gpc.config.get("fp8", None) + self.use_fp8 = HAS_TE and fp8_cfg is not None + if self.use_fp8: + self.fp8_recipe = DelayedScaling( + margin=fp8_cfg.get("margin", 0), # int, default = 0. Margin for scaling factor computation + fp8_format=Format[ + fp8_cfg.get("fp8_format", "HYBRID") + ], # {Format.E4M3, Format.HYBRID}, default = Format.HYBRID. FP8 Data format + amax_history_len=fp8_cfg.get( + "amax_history_len", 1024 + ), # int, default = 1024. Amax history window used for scaling factor computation + amax_compute_algo=fp8_cfg.get( + "amax_compute_algo", "max" + ), # {'max', 'most_recent'}, default = "max". Algorithm used for choosing amax + ) + @property def model(self): """Returns the model attached to the engine.""" @@ -166,7 +193,9 @@ def __call__(self, *args, **kwargs): Returns: torch.Tensor: The output of the model. """ - return self.model(*args, **kwargs) + with te.fp8_autocast(enabled=self.use_fp8, fp8_recipe=self.fp8_recipe) if self.use_fp8 else nullcontext(): + output = self.model(*args, **kwargs) + return output def load_batch(self, data_iter, to_gpu=True): """ diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py index 8b0755edf..5d7f22445 100644 --- a/internlm/data/streaming/dataset.py +++ b/internlm/data/streaming/dataset.py @@ -8,10 +8,10 @@ from datasets.distributed import split_dataset_by_node from PIL import Image from torch.utils.data import Dataset +from transformers import AutoTokenizer from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from transformers import AutoTokenizer class StreamingDataset(Dataset): diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index a162d24c0..4992df851 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -17,7 +17,7 @@ from internlm.utils.gputest import warmup_process_group from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_hf +from internlm.utils.parallel import is_using_fsdp, is_using_hf from internlm.utils.timeout import llm_timeout from internlm.utils.utils import DataType, ModelType, TensorParallelMode @@ -154,6 +154,9 @@ def args_sanity_check(): if gpc.config.parallel.pipeline["mode"] == "ZBV": gpc.v_shape = True + if "fsdp" not in gpc.config.parallel: + gpc.config.parallel._add_item("fsdp", dict(enable=False)) + # processing the data config in gpc data = gpc.config.data @@ -642,6 +645,56 @@ def args_sanity_check(): gpc.config.data.use_packed_dataset is False ), "only unpacked data is supported when using 2D sequence parallel." 
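The FSDP sanity checks added just below pin down how the new parallel.fsdp block may be combined with the other parallel settings: pipeline, zero1, weight and the expert-related sizes must all be 1, tensor parallel must either be 1 or run in "isp" mode, and mode/init_method become mandatory once FSDP is enabled. The fragment below is a config sketch that satisfies those assertions; it assumes the enable flag (defaulted to False earlier in this file's changes) is what is_using_fsdp() keys off, and the concrete values are only one valid combination.

    parallel = dict(
        zero1=dict(size=1),               # fsdp requires zero1 size == 1
        tensor=dict(size=1, mode="mtp"),  # alternatively size > 1 together with mode="isp"
        pipeline=dict(size=1),            # fsdp requires pipeline size == 1
        weight=dict(size=1),              # fsdp requires weight size == 1
        fsdp=dict(
            enable=True,                  # patch default is enable=False
            mode="v2",                    # "v1" needs torch>=2.4.0, "v2" needs torch>=2.5.1
            init_method="meta",           # one of "cuda", "cpu", "meta"
        ),
    )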
+ # fsdp checks + if is_using_fsdp(): + assert ( + gpc.config.parallel.pipeline.size == 1 + ), f"fsdp only compatible with pp size = 1, but get pipeline size = {gpc.config.parallel.pipeline.size}" + assert ( + gpc.config.parallel.tensor.size == 1 or gpc.config.parallel.tensor.get("mode", "mtp") == "isp" + ), ( + f"fsdp only compatible with tp size > 1 in isp mode, but get tp size = " + f"{gpc.config.parallel.tensor.size} and tp mode = {gpc.config.parallel.tensor.mode}" + ) + assert ( + gpc.config.parallel.zero1.size == 1 + ), f"fsdp only compatible with zero1 size = 1, but get zero1 size = {gpc.config.parallel.zero1.size}" + assert ( + gpc.config.parallel.weight.size == 1 + ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}" + if "expert" in gpc.config.parallel: + assert ( + gpc.config.parallel.expert.size == 1 + ), f"fsdp only compatible with expert size = 1, but get expert size = {gpc.config.parallel.expert.size}" + if "expert_zero1" in gpc.config.parallel: + assert gpc.config.parallel.expert_zero1.size == 1, ( + f"fsdp only compatible with expert_zero1 size = 1, " + f"but get expert_zero1 size = {gpc.config.parallel.expert_zero1.size}" + ) + if "expert_weight" in gpc.config.parallel: + assert gpc.config.parallel.expert_weight.size == 1, ( + f"fsdp only compatible with expert_weight size = 1, " + f"but get expert_weight size = {gpc.config.parallel.expert_weight.size}" + ) + assert "mode" in gpc.config.parallel.fsdp, "mode must be specified in fsdp when enabled" + fsdp_mode = gpc.config.parallel.fsdp.mode + assert "init_method" in gpc.config.parallel.fsdp, "init_method must be specified in fsdp when enabled" + fsdp_init_method = gpc.config.parallel.fsdp.init_method + if fsdp_mode == "v1": + assert ( + torch.__version__ >= "2.4.0" + ), f"requires torch>=2.4.0 when using fsdp v1 but current version is {torch.__version__}" + elif fsdp_mode == "v2": + assert ( + torch.__version__ >= "2.5.1" + ), f"requires torch>=2.5.1 when using fsdp v2 but current version is {torch.__version__}" + else: + raise ValueError(f"fsdp mode {fsdp_mode} not supported") + assert fsdp_init_method in ["cuda", "cpu", "meta"], f"fsdp init_method {fsdp_init_method} not supported" + + # fp8 checks + + # loss operator type loss_cfg = gpc.config.loss if loss_cfg.get("op_type", None) is None: diff --git a/internlm/model/model_implementations/builder.py b/internlm/model/model_implementations/builder.py index 8b3113faa..7393837de 100644 --- a/internlm/model/model_implementations/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -19,9 +19,42 @@ from internlm.utils.logger import get_logger from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp +try: + import transformer_engine.pytorch as te + + HAS_TE = True +except (ModuleNotFoundError, ImportError): + HAS_TE = False + + logger = get_logger(__file__) +def simple_swap(model, device): + for submodule_name, submodule in model.named_modules(): + if isinstance(submodule, torch.nn.Linear): + path_in_state_dict = submodule_name.split(".") + current_module = model + + # traverse to leaf module + leaf_path = path_in_state_dict[:-1] + leaf_name = path_in_state_dict[-1] + for child_name in leaf_path: + current_module = getattr(current_module, child_name) + + # perform a swap + old_leaf = getattr(current_module, leaf_name) + new_leaf = te.Linear(old_leaf.in_features, old_leaf.out_features, old_leaf.bias is not None, device=device) + with torch.no_grad(): + new_leaf.weight.copy_(old_leaf.weight) + 
assert torch.equal(new_leaf.weight, old_leaf.weight) + if old_leaf.bias is not None: + new_leaf.bias.copy_(old_leaf.bias) + assert torch.equal(new_leaf.bias, old_leaf.bias) + + setattr(current_module, leaf_name, new_leaf) + + def create_model() -> Union[nn.Module, List[nn.Module]]: if is_using_hf(): model = create_model_hf(hf=gpc.config.hf) @@ -128,4 +161,7 @@ def traverse(module): else: traverse(model) + if HAS_TE and gpc.config.get("fp8", None) is not None: + simple_swap(model=model, device=fsdp_init_method) + return model diff --git a/internlm/model/model_implementations/transformers/modeling_internlm.py b/internlm/model/model_implementations/transformers/modeling_internlm.py index 22de76f1a..a201abe3b 100644 --- a/internlm/model/model_implementations/transformers/modeling_internlm.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm.py @@ -8,6 +8,11 @@ import torch from torch import nn from tqdm import tqdm +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -34,11 +39,6 @@ from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) internlm_accelerator = get_accelerator() logger = get_logger(__file__) diff --git a/internlm/model/model_implementations/transformers/modeling_internlm2.py b/internlm/model/model_implementations/transformers/modeling_internlm2.py index 875ecc9ba..e22e51a76 100644 --- a/internlm/model/model_implementations/transformers/modeling_internlm2.py +++ b/internlm/model/model_implementations/transformers/modeling_internlm2.py @@ -7,6 +7,11 @@ import torch from torch import nn from tqdm import tqdm +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -33,11 +38,6 @@ from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) internlm_accelerator = get_accelerator() logger = get_logger(__file__) diff --git a/internlm/model/model_implementations/transformers/modeling_llama.py b/internlm/model/model_implementations/transformers/modeling_llama.py index 69c71ac80..b2e1aef9e 100644 --- a/internlm/model/model_implementations/transformers/modeling_llama.py +++ b/internlm/model/model_implementations/transformers/modeling_llama.py @@ -5,6 +5,11 @@ import torch from torch import nn from tqdm import tqdm +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -31,11 +36,6 @@ from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from transformers.modeling_utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - shard_checkpoint, -) internlm_accelerator = get_accelerator() 
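# ---------------------------------------------------------------------------------------------
# [Editor's note: illustrative aside, not part of the patch.] A hedged sketch of what the
# simple_swap() helper added to builder.py does: every nn.Linear is replaced by a
# transformer_engine te.Linear with copied weights, so the FP8 autocast in the Engine has TE
# modules to act on. The toy two-layer model below is hypothetical; it assumes transformer_engine
# is installed and a CUDA device is available.
import torch
from torch import nn
import transformer_engine.pytorch as te

toy = nn.Sequential(nn.Linear(128, 256, bias=False), nn.GELU(), nn.Linear(256, 128, bias=False))
for name, child in list(toy.named_children()):
    if isinstance(child, nn.Linear):
        swapped = te.Linear(child.in_features, child.out_features, bias=child.bias is not None, device="cuda")
        with torch.no_grad():
            swapped.weight.copy_(child.weight)
        setattr(toy, name, swapped)  # same traverse-and-setattr pattern as simple_swap()
assert isinstance(toy[0], te.Linear) and isinstance(toy[2], te.Linear)
# ---------------------------------------------------------------------------------------------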
logger = get_logger(__file__) From 1cf9ee630e1398f98ba8f3b478e96f0508fc7bd0 Mon Sep 17 00:00:00 2001 From: caizheng Date: Wed, 5 Mar 2025 21:52:48 +0800 Subject: [PATCH 14/32] add 7B_internlm2_hf config and refine some fsdp or fp8 codes --- configs/7B_internlm2_hf.py | 262 ++++++++++++++++++ .../internlm2_model/modeling_internlm2.py | 81 +++++- internlm/core/parallel/comm/isp.py | 19 +- internlm/initialize/initialize_launcher.py | 46 +-- internlm/initialize/initialize_optimizer.py | 15 +- internlm/model/model_ops/ops/attention.py | 16 +- 6 files changed, 374 insertions(+), 65 deletions(-) create mode 100644 configs/7B_internlm2_hf.py diff --git a/configs/7B_internlm2_hf.py b/configs/7B_internlm2_hf.py new file mode 100644 index 000000000..4e6ed9042 --- /dev/null +++ b/configs/7B_internlm2_hf.py @@ -0,0 +1,262 @@ +JOB_NAME = "7b_internlm2_train" +DO_ALERT = False + + +VOCAB_SIZE = 92544 +SEQ_LEN = 2048 + + +MODEL_ONLY_FOLDER = None +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + load_ckpt_folder="local:llm_ckpts/", + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=True, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. 
+) + +TRAIN_FOLDER = None +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=20000, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. 
+loss = dict(label_smoothing=0, op_type="py_vocab_parallel") + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False + +model = dict( + dtype="torch.bfloat16", + checkpoint=0, + parallel_output=True, +) + +hf = dict( + cfg="huggingface_models.internlm2_model.configuration_internlm2", + cfg_cls="InternLM2Config", + cfg_extra_kwargs=dict( + vocab_size=VOCAB_SIZE, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=False, + pad_token_id=None, # We actually didn't use pad_token_id in this framework + # bos_token_id=1, + # eos_token_id=2, + # pretraining_tp=1, + tie_word_embeddings=False, + bias=False, + rope_theta=1000000, + rope_scaling=None, + attn_implementation="flash_attention_2", + dtype=model["dtype"], + return_dict=False, + ), + mod="huggingface_models.internlm2_model.modeling_internlm2", + mod_cls="InternLM2ForCausalLM", +) + +fsdp_wrap_cls = [ + dict( + mod=hf["mod"], + mod_cls="InternLM2DecoderLayer", + ), +] + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. + 3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The defalut is 1f1b. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. 
+""" +parallel = dict( + fsdp=dict(enable=True, mode="v1", init_method="meta"), + zero1=dict(size=1), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"), + weight=dict(size=1, overlap=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + + +# fp8 = dict( +# margin=0, +# fp8_format="HYBRID", +# amax_history_len=1024, +# amax_compute_algo="max", +# ) diff --git a/huggingface_models/internlm2_model/modeling_internlm2.py b/huggingface_models/internlm2_model/modeling_internlm2.py index f026e5f9d..18eaaa1c4 100644 --- a/huggingface_models/internlm2_model/modeling_internlm2.py +++ b/huggingface_models/internlm2_model/modeling_internlm2.py @@ -40,6 +40,15 @@ replace_return_docstrings, ) +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.model.model_ops.ops.attention import ( + isp_flash_attn_func, + isp_flash_attn_varlen_func, +) +from internlm.model.model_ops.ops.fused_rmsnorm import fused_rms_norm_fn +from internlm.solver.activation_checkpoint import apply_ac_to_transformer_block + try: from transformers.generation.streamers import BaseStreamer except: # noqa # pylint: disable=bare-except @@ -53,17 +62,24 @@ flash_attn_func, flash_attn_varlen_func = None, None pad_input, index_first_axis, unpad_input = None, None, None + + def _import_flash_attn(): global flash_attn_func, flash_attn_varlen_func global pad_input, index_first_axis, unpad_input try: - from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func - from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input except ImportError: raise ImportError("flash_attn is not installed.") + # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -121,11 +137,15 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * 
torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + # input_dtype = hidden_states.dtype + # hidden_states = hidden_states.to(torch.float32) + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight * hidden_states.to(input_dtype) + return fused_rms_norm_fn(hidden_states, self.weight, self.variance_epsilon) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2 @@ -164,6 +184,13 @@ def forward(self, x, seq_len=None): self.sin_cached[:seq_len].to(dtype=x.dtype), ) + def reset_parameters(self): + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(self.inv_freq.device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._set_cos_sin_cache( + seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2 class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): @@ -443,6 +470,12 @@ def forward( bsz, q_len, _ = hidden_states.size() + use_packed_dataset = gpc.config.data.get("use_packed_dataset", False) + if use_packed_dataset: + assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True" + cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] + max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] + qkv_states = self.wqkv(hidden_states) qkv_states = rearrange( @@ -480,9 +513,31 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len - ) + # attn_output = self._flash_attention_forward( + # query_states, key_states, value_states, attention_mask, q_len + # ) + if use_packed_dataset: + attn_output = isp_flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + ) + else: + attn_output = isp_flash_attn_func( + query_states, + key_states, + value_states, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.wo(attn_output) @@ -584,6 +639,7 @@ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, quer (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + INTERNLM2_ATTENTION_CLASSES = { "eager": InternLM2Attention, "flash_attention_2": InternLM2FlashAttention2, @@ -794,6 +850,11 @@ def __init__(self, config: InternLM2Config): self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + for layer_id, transformer_block in self.layers.named_children(): + checkpoint = gpc.config.model.checkpoint + if checkpoint > 0: + transformer_block = apply_ac_to_transformer_block(transformer_block, checkpoint) + self.layers.register_module(layer_id, transformer_block) self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False diff --git 
a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 416dc0fb9..c7a1e3b9b 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -1504,6 +1504,17 @@ def _q_k_v(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwa Returns: * output (Tensor): context output """ + # if the num head of kv is not enough to be splitted by sp + # then we could copy the kv head + num_head_k = k.shape[2] + if self.sp_size > num_head_k: + assert self.sp_size % num_head_k == 0, "the num_head_k should be divided by sp size." + k = expandKVPacked(k, self.sp_size // num_head_k, 2) + num_head_v = v.shape[2] + if self.sp_size > num_head_v: + assert self.sp_size % num_head_v == 0, "the num_head_v should be divided by sp size." + v = expandKVPacked(v, self.sp_size // num_head_v, 2) + # self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] # q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] # scatter in n_head and gather in seqlen(packlen) @@ -1562,8 +1573,10 @@ def _attetion_constructor(*args, attn_impl: type, **kwargs) -> Callable: if tp_mode != TensorParallelMode.isp.name: return attn_impl(*args, **kwargs) else: - return DistributedAttention( - local_attention=attn_impl, sequence_process_group=gpc.get_group(ParallelMode.TENSOR) - )(*args, **kwargs) + if gpc.config.parallel.sequence_2D.enable is True: + spg = gpc.get_group(ParallelMode.HEAD) + else: + spg = gpc.get_group(ParallelMode.TENSOR) + return DistributedAttention(local_attention=attn_impl, sequence_process_group=spg)(*args, **kwargs) return partial(_attetion_constructor, attn_impl=attn_impl) diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index 4992df851..32aa3cd84 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -35,30 +35,12 @@ internlm_accelerator = get_accelerator() -def dispatch_hf_config_before_launch(model_config) -> None: - # dispatch HuggingFace model config into InternEvo model config as much as we know - if hasattr(model_config, "vocab_size"): - gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size - if hasattr(model_config, "num_hidden_layers"): - gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers - if hasattr(model_config, "num_attention_heads"): - gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads - if hasattr(model_config, "num_key_value_heads"): - gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads - if hasattr(model_config, "hidden_size"): - gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size - if hasattr(model_config, "intermediate_size"): - gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size - if hasattr(model_config, "num_experts"): - gpc.config.model.num_experts = model_config.num_experts - - -def inject_hf_config_before_launch(hf: dict): +def dispatch_hf_config_before_launch(hf: dict) -> None: # get HuggingFace model config cfg = LazyObject(hf.cfg, hf.cfg_cls) cfg = cfg.build() model_config = cfg(**hf.cfg_extra_kwargs) - # inject HuggingFace model config into InternTrain as much as we know + # dispatch HuggingFace model config into InternEvo model config as much as we know if hasattr(model_config, "vocab_size"): gpc.config.model.vocab_size = 
gpc.config.VOCAB_SIZE = model_config.vocab_size if hasattr(model_config, "num_hidden_layers"): @@ -87,13 +69,6 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) - # dispatch HuggingFace model config into InternEvo model config - if is_using_hf(): - cfg = LazyObject(gpc.config.hf.cfg, gpc.config.hf.cfg_cls) - cfg = cfg.build() - model_config = cfg(**gpc.config.hf.cfg_extra_kwargs) - dispatch_hf_config_before_launch(model_config) - if gpc.config.model_type == "InternLM3_M": # TODO: need check for isp overlap num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers @@ -156,7 +131,7 @@ def args_sanity_check(): if "fsdp" not in gpc.config.parallel: gpc.config.parallel._add_item("fsdp", dict(enable=False)) - + # processing the data config in gpc data = gpc.config.data @@ -650,9 +625,7 @@ def args_sanity_check(): assert ( gpc.config.parallel.pipeline.size == 1 ), f"fsdp only compatible with pp size = 1, but get pipeline size = {gpc.config.parallel.pipeline.size}" - assert ( - gpc.config.parallel.tensor.size == 1 or gpc.config.parallel.tensor.get("mode", "mtp") == "isp" - ), ( + assert gpc.config.parallel.tensor.size == 1 or gpc.config.parallel.tensor.get("mode", "mtp") == "isp", ( f"fsdp only compatible with tp size > 1 in isp mode, but get tp size = " f"{gpc.config.parallel.tensor.size} and tp mode = {gpc.config.parallel.tensor.mode}" ) @@ -664,9 +637,9 @@ def args_sanity_check(): ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}" if "expert" in gpc.config.parallel: assert ( - gpc.config.parallel.expert.size == 1 + gpc.config.parallel.expert.size == 1 or gpc.config.parallel.expert.size == -1 ), f"fsdp only compatible with expert size = 1, but get expert size = {gpc.config.parallel.expert.size}" - if "expert_zero1" in gpc.config.parallel: + if "expert_zero1" in gpc.config.parallel: assert gpc.config.parallel.expert_zero1.size == 1, ( f"fsdp only compatible with expert_zero1 size = 1, " f"but get expert_zero1 size = {gpc.config.parallel.expert_zero1.size}" @@ -691,9 +664,6 @@ def args_sanity_check(): else: raise ValueError(f"fsdp mode {fsdp_mode} not supported") assert fsdp_init_method in ["cuda", "cpu", "meta"], f"fsdp init_method {fsdp_init_method} not supported" - - # fp8 checks - # loss operator type loss_cfg = gpc.config.loss @@ -743,6 +713,10 @@ def launch( # init default process group gpc.init_global_dist(rank, world_size, backend, host, port) + # dispatch HuggingFace model config into InternEvo + if is_using_hf(): + dispatch_hf_config_before_launch(gpc.config.hf) + # init process groups for different parallel modes from config gpc.init_parallel_groups() diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py index 28082cb85..7fb5c038d 100644 --- a/internlm/initialize/initialize_optimizer.py +++ b/internlm/initialize/initialize_optimizer.py @@ -48,13 +48,20 @@ def split_params_into_different_groups_for_optimizer( elif not isinstance(param_groups, list): raise ValueError(f"Unknown param group type of {type(param_groups)}") + if is_using_fsdp(): + optimizer_mode = ParallelMode.GLOBAL + optimizer_mode_expert = ParallelMode.GLOBAL + else: + optimizer_mode = ParallelMode.ZERO1 + optimizer_mode_expert = ParallelMode.EXPERT_DATA + new_groups = {} # create new groups for fp32 parameter group - new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1} + 
new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": optimizer_mode} if gpc.config.model.get("num_experts", 1) > 1: for key in gpc.expert_parallel_group_names: - new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": ParallelMode.EXPERT_DATA} + new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": optimizer_mode_expert} for pgroup in param_groups: # copy attribute from origin group, we assume the input param_groups only @@ -76,12 +83,12 @@ def split_params_into_different_groups_for_optimizer( # default param group, which is the first group in the param groups pgroup["params"] = origin_params - pgroup["optimizer_mode"] = ParallelMode.ZERO1 + pgroup["optimizer_mode"] = optimizer_mode # param groups may contain empty groups, such as fp32 param_groups.extend(new_groups.values()) - return tuple(param_groups) + return list(param_groups) def create_param_groups(model, weight_decay): diff --git a/internlm/model/model_ops/ops/attention.py b/internlm/model/model_ops/ops/attention.py index e2b622412..5beccba9e 100644 --- a/internlm/model/model_ops/ops/attention.py +++ b/internlm/model/model_ops/ops/attention.py @@ -1188,12 +1188,9 @@ def isp_flash_attn_varlen_func( causal=False, softmax_scale=None, attention_dropout=0.0, - return_attn_probs=False, ): - assert ( - device_backend == AcceleratorType.GPU and gpu_flash_attn_impl - ), "isp_flash_attn_varlen_func currently only support GPU." - return _flash_varlen_qkvsplited_func( + _, op = _select_attn_op(AttnOpType.VarLenQKVSplited) + return op( q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1), @@ -1204,7 +1201,6 @@ def isp_flash_attn_varlen_func( dropout_p=attention_dropout, softmax_scale=softmax_scale, causal=causal, - return_attn_probs=return_attn_probs, ).unsqueeze(0) @@ -1216,17 +1212,13 @@ def isp_flash_attn_func( causal=False, softmax_scale=None, attention_dropout=0.0, - return_attn_probs=False, ): - assert ( - device_backend == AcceleratorType.GPU and gpu_flash_attn_impl - ), "isp_flash_attn_func currently only support GPU." 
- return _flash_fixedlen_qkvsplited_func( + _, op = _select_attn_op(AttnOpType.FixedLenQKVSplited) + return op( q, k, v, dropout_p=attention_dropout, softmax_scale=softmax_scale, causal=causal, - return_attn_probs=return_attn_probs, ) From f77b4f217016c160fe3f00412abaf21054fb2945 Mon Sep 17 00:00:00 2001 From: caizheng Date: Wed, 5 Mar 2025 21:57:45 +0800 Subject: [PATCH 15/32] fix pylint --- internlm/initialize/initialize_launcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index 32aa3cd84..47007336c 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -636,8 +636,9 @@ def args_sanity_check(): gpc.config.parallel.weight.size == 1 ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}" if "expert" in gpc.config.parallel: - assert ( - gpc.config.parallel.expert.size == 1 or gpc.config.parallel.expert.size == -1 + assert gpc.config.parallel.expert.size in ( + 1, + -1, ), f"fsdp only compatible with expert size = 1, but get expert size = {gpc.config.parallel.expert.size}" if "expert_zero1" in gpc.config.parallel: assert gpc.config.parallel.expert_zero1.size == 1, ( From e854c9088016d97113ee1f47e43092c743f57ce2 Mon Sep 17 00:00:00 2001 From: caizheng Date: Fri, 7 Mar 2025 16:18:48 +0800 Subject: [PATCH 16/32] typo fix --- internlm/initialize/initialize_launcher.py | 6 ++++-- internlm/initialize/initialize_model.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index 47007336c..bc8eb4bc5 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -55,6 +55,8 @@ def dispatch_hf_config_before_launch(hf: dict) -> None: gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size if hasattr(model_config, "num_experts"): gpc.config.model.num_experts = model_config.num_experts + elif hasattr(model_config, "n_routed_experts"): + gpc.config.model.num_experts = model_config.n_routed_experts def args_sanity_check(): @@ -580,7 +582,7 @@ def args_sanity_check(): assert gpc.config.parallel.zero1.size in ( -1, gpc.get_world_size(ParallelMode.DATA), - ), "moe only support zero1, set zero1=dict(size=-1,...) can fix this" + ) or is_using_fsdp(), "moe only support zero1, set zero1=dict(size=-1,...) 
can fix this" if gpc.config.parallel.tensor.mode != "isp": assert gpc.config.parallel.expert_weight.size <= 1, "expert weight parallel is only supported with isp" @@ -639,7 +641,7 @@ def args_sanity_check(): assert gpc.config.parallel.expert.size in ( 1, -1, - ), f"fsdp only compatible with expert size = 1, but get expert size = {gpc.config.parallel.expert.size}" + ), f"fsdp only compatible with expert size = (-1, 1), but get expert size = {gpc.config.parallel.expert.size}" if "expert_zero1" in gpc.config.parallel: assert gpc.config.parallel.expert_zero1.size == 1, ( f"fsdp only compatible with expert_zero1 size = 1, " diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py index 5a9580014..6f7514285 100644 --- a/internlm/initialize/initialize_model.py +++ b/internlm/initialize/initialize_model.py @@ -353,4 +353,8 @@ def initialize_model_and_parallel_communicator(model: Optional[Union[nn.Module, model = wrap_FSDP_model(model) + if gpc.is_rank_for_log(): + logger.info(f"show model: {model}") + logger.info(f"model params: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B") + return model, isp_communicator From 8e04b09f7a81c6677ab5e727da2a5ce19f40a2d9 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Wed, 12 Mar 2025 15:24:50 +0800 Subject: [PATCH 17/32] update fsdp wrap --- internlm/core/fsdp.py | 5 ++--- internlm/initialize/initialize_launcher.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index 3f74b7b34..1803334ac 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -1,5 +1,4 @@ import collections -import functools import itertools from typing import List, Optional, Set, Union @@ -11,7 +10,7 @@ BackwardPrefetch, ShardingStrategy, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -170,7 +169,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): module=model, process_group=gpc.get_group(ParallelMode.GLOBAL), sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD - auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)), + auto_wrap_policy=ModuleWrapPolicy(wrap_cls), sync_module_states=fsdp_init_method != "cuda", # sync model paramters forward_prefetch=True, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index bc8eb4bc5..18d029aa1 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -579,10 +579,14 @@ def args_sanity_check(): assert ( not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param ), "not support overlap and moe at the same time" - assert gpc.config.parallel.zero1.size in ( - -1, - gpc.get_world_size(ParallelMode.DATA), - ) or is_using_fsdp(), "moe only support zero1, set zero1=dict(size=-1,...) can fix this" + assert ( + gpc.config.parallel.zero1.size + in ( + -1, + gpc.get_world_size(ParallelMode.DATA), + ) + or is_using_fsdp() + ), "moe only support zero1, set zero1=dict(size=-1,...) 
can fix this" if gpc.config.parallel.tensor.mode != "isp": assert gpc.config.parallel.expert_weight.size <= 1, "expert weight parallel is only supported with isp" @@ -637,11 +641,6 @@ def args_sanity_check(): assert ( gpc.config.parallel.weight.size == 1 ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}" - if "expert" in gpc.config.parallel: - assert gpc.config.parallel.expert.size in ( - 1, - -1, - ), f"fsdp only compatible with expert size = (-1, 1), but get expert size = {gpc.config.parallel.expert.size}" if "expert_zero1" in gpc.config.parallel: assert gpc.config.parallel.expert_zero1.size == 1, ( f"fsdp only compatible with expert_zero1 size = 1, " From ca942a2394b686088e77d142b2108dc129ff2a85 Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 13 Mar 2025 15:02:24 +0800 Subject: [PATCH 18/32] typo fix --- internlm/utils/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 843c31fcb..b02af1e66 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -13,7 +13,7 @@ ParallelMode, ) from internlm.core.context import global_context as gpc -from internlm.model.modules.utils import is_gate_param +from internlm.model.model_ops.modules.utils import is_gate_param from internlm.utils.utils import TensorParallelMode From bea38df1d817a88449bdde833e80d91789c3a883 Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 3 Apr 2025 12:27:09 +0800 Subject: [PATCH 19/32] support ep for fsdp --- configs/7B_sft.py | 13 +- internlm/checkpoint/checkpoint_manager.py | 2 +- internlm/core/fsdp.py | 52 ++- internlm/initialize/initialize_launcher.py | 7 +- internlm/initialize/initialize_optimizer.py | 2 +- internlm/model/model_ops/ops/cross_entropy.py | 14 +- .../ops/cross_entropy_ops/__init__.py | 2 + .../ops/cross_entropy_ops/flash_loss.py | 412 ++++++++++++++++++ internlm/solver/optimizer/fsdp_optimizer.py | 21 +- train.py | 1 + 10 files changed, 485 insertions(+), 41 deletions(-) create mode 100644 internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py create mode 120000 train.py diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 27847a5e8..0a19f137c 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -22,16 +22,7 @@ CHECKPOINT_EVERY = 50 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. - enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", - # 'load_ckpt_info' setting guide: - # 1. the 'path' indicate ckpt path, - # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined - # load function such as "llama" - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) # with an automatic restart mechanism upon training reboot. @@ -39,7 +30,7 @@ # path specified in `load_ckpt_info` by default. # If you want to initialize your model weights from another model, you must set `auto_resume` to False. 
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. - auto_resume=True, + auto_resume=False, checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. @@ -144,14 +135,12 @@ model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, - embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, - apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 8e55883ea..8e36b7745 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -582,7 +582,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""): f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) - elif is_using_fsdp() and is_using_hf() and not self.auto_resume: + elif is_using_fsdp() and not self.auto_resume: pass else: load_path = self.load_ckpt_info["path"] diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index 1803334ac..64435b180 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -33,8 +33,10 @@ FSDP2_SUPPORTED = False try: + import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.state_dict import ( StateDictOptions, + get_model_state_dict, set_model_state_dict, ) @@ -163,8 +165,29 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): ) fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1") fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + if gpc.is_using_parallel_mode(ParallelMode.EXPERT): + assert gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT) == gpc.get_world_size(ParallelMode.GLOBAL) if fsdp_mode == "v1": + ignored_mod = [] + if gpc.is_using_parallel_mode(ParallelMode.EXPERT): + for layer_id, layer in enumerate(model.model.layers): + if layer_id >= gpc.config.model.first_k_dense_replace: + # Should follow this modeling pattern if EP is enabled. + # Change the expert module name if needed. + # TODO: Make this part hard-coded or config-driven? 
+ layer.feed_forward.moe_layer.experts = FSDP( + layer.feed_forward.moe_layer.experts, + process_group=gpc.get_group(ParallelMode.EXPERT_DATA), + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=fsdp_init_method != "cuda", # sync model paramters + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + use_orig_params=True, + device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states + ) + ignored_mod.append(layer.feed_forward.moe_layer.experts) model = FSDP( module=model, process_group=gpc.get_group(ParallelMode.GLOBAL), @@ -176,6 +199,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): limit_all_gathers=True, use_orig_params=True, device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states + ignored_modules=ignored_mod, ) # For FSDP v1, to get ckpt resuming work normally, we do dummy forward. # This hack is needed due to FSDP v1 lazy initialization in model construction. @@ -196,7 +220,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): else: raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}") - if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False): + if not gpc.config.ckpt.get("auto_resume", False): load_ckpt_info = gpc.config.ckpt.load_ckpt_info load_ckpt_path = load_ckpt_info.get("path", None) load_ckpt_content = load_ckpt_info.get("content", []) @@ -205,16 +229,22 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): "model", ), "If auto_resume=False and checkpoint path is given, only model can be loaded" if DCP_SUPPORTED: - hf = gpc.config.hf - mod = LazyObject(hf.mod, hf.mod_cls) - mod = mod.build() - state_dict = mod.from_pretrained( - pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True - ).state_dict() - state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} - set_model_state_dict( - model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) - ) + if is_using_hf(): + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + else: + state_dict = get_model_state_dict(model=model) + state_dict = {key: state_dict[key].clone().detach() for key in state_dict} + dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path) + set_model_state_dict(model=model, model_state_dict=state_dict) del state_dict internlm_accelerator.empty_cache() else: diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index 18d029aa1..d0956d82a 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -57,6 +57,8 @@ def dispatch_hf_config_before_launch(hf: dict) -> None: gpc.config.model.num_experts = model_config.num_experts elif hasattr(model_config, "n_routed_experts"): gpc.config.model.num_experts = model_config.n_routed_experts + if hasattr(model_config, "first_k_dense_replace"): + gpc.config.model.first_k_dense_replace = model_config.first_k_dense_replace def args_sanity_check(): @@ -306,8 +308,9 @@ def args_sanity_check(): logger.info(f"clip_grad_norm: {clip_grad_norm}") model = 
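# ---------------------------------------------------------------------------------------------
# [Editor's note: illustrative aside, not part of the patch.] A minimal sketch of the non-HF
# loading branch added to wrap_FSDP_model in internlm/core/fsdp.py above: the sharded state dict
# is materialized from the wrapped model, filled in place by torch.distributed.checkpoint (DCP),
# and written back. `fsdp_model` and `ckpt_dir` are placeholder names.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

def load_model_with_dcp(fsdp_model, ckpt_dir: str) -> None:
    state_dict = get_model_state_dict(model=fsdp_model)      # per-rank (sharded) template
    dcp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)  # each rank reads only its shards
    set_model_state_dict(model=fsdp_model, model_state_dict=state_dict)
# ---------------------------------------------------------------------------------------------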
gpc.config.model - if "enable_qkv_fusion" not in model: - model._add_item("enable_qkv_fusion", True) + # TODO: should we set default value for enable_qkv_fusion? + # if "enable_qkv_fusion" not in model: + # model._add_item("enable_qkv_fusion", True) if "dtype" not in model: logger.warning("dtype is not set, use torch.float16 by defalut!") diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py index 7fb5c038d..2c2180ade 100644 --- a/internlm/initialize/initialize_optimizer.py +++ b/internlm/initialize/initialize_optimizer.py @@ -50,7 +50,7 @@ def split_params_into_different_groups_for_optimizer( if is_using_fsdp(): optimizer_mode = ParallelMode.GLOBAL - optimizer_mode_expert = ParallelMode.GLOBAL + optimizer_mode_expert = ParallelMode.EXPERT_DATA else: optimizer_mode = ParallelMode.ZERO1 optimizer_mode_expert = ParallelMode.EXPERT_DATA diff --git a/internlm/model/model_ops/ops/cross_entropy.py b/internlm/model/model_ops/ops/cross_entropy.py index 35de1b6ef..17b9f8c05 100644 --- a/internlm/model/model_ops/ops/cross_entropy.py +++ b/internlm/model/model_ops/ops/cross_entropy.py @@ -18,6 +18,7 @@ CrossEntropyApexVocabParallel, CrossEntropyLossApex, CrossEntropyPython, + CrossEntropyLossFlash, ) from internlm.utils.logger import get_logger @@ -86,17 +87,8 @@ def new_cross_entropy( assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None." - try: - from flash_attn.losses.cross_entropy import ( - CrossEntropyLoss as FlashCrossEntropyLoss, - ) - - flash_cross_entropy_impl = True - except (ModuleNotFoundError, ImportError): - flash_cross_entropy_impl = False - assert ( - gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl + gpc.config.model.get("use_flash_attn", False) ), "Only flash cross entropy support parallel_output" assert ( @@ -108,7 +100,7 @@ def new_cross_entropy( which may result loss divergency in long sequence." ) - return FlashCrossEntropyLoss( + return CrossEntropyLossFlash( ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, diff --git a/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py b/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py index 1f4b6630d..ad8c208b0 100644 --- a/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py +++ b/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py @@ -2,10 +2,12 @@ from .py_naive_loss import CrossEntropyPython from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss +from .flash_loss import CrossEntropyLossFlash __all__ = [ "CrossEntropyLossApex", "CrossEntropyPython", "CrossEntropyApexVocabParallel", "VocabSequenceParallelCrossEntropyLoss", + "CrossEntropyLossFlash", ] diff --git a/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py new file mode 100644 index 000000000..baab79e54 --- /dev/null +++ b/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py @@ -0,0 +1,412 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import triton +import triton.language as tl + +from typing import Tuple, Optional, Union + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. 
They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, + PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) +): + row_idx = tl.program_id(0) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + sum_logits = 0.0 # For smoothing + if not PRECOMPUTED_LSE: + # Statistics for online softmax + m_i = -float("inf") + l_i = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + cols = col_offset + tl.arange(0, BLOCK_SIZE) + logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + if HAS_SMOOTHING: + sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) + m_i_new = tl.maximum(m_i, tl.max(logits)) + l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) + m_i = m_i_new + lse = tl.log(l_i) + m_i + tl.store(lse_ptr + row_idx, lse) + else: + lse = tl.load(lse_ptr + row_idx) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx == ignore_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= 0 and label_idx < n_cols: + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - smoothing * sum_logits / total_classes + - (1 - smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the smoothing loss + if HAS_SMOOTHING: + loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + row_idx, z_loss) + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignore_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_positive = 1.0 - smoothing + smooth_negative = smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) + + +class CrossEntropyLoss(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + precomputed_lse=None, + smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignore_index=-100, + inplace_backward=False, + process_group=None, + ): + # For some reason Triton generates wrong code when labels has dtype long and its address + # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. 
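        # [Editor's note: descriptive comment, not part of the patch.] The padding trick below
        # re-materializes `labels` into a freshly allocated tensor (F.pad copies, the trailing
        # slice drops the pad element), and fresh CUDA allocations are at least 16-byte aligned,
        # so the values are unchanged while the alignment assumption checked by the assert holds.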
+ if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: + labels = F.pad(labels, (0, 1))[..., :-1] + assert labels.data_ptr() % 16 == 0 + assert logit_scale > 0.0 + n_rows, n_cols = logits.shape + assert labels.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 + + if logits.stride(-1) != 1: + logits = logits.contiguous() + MAX_BLOCK_SIZE = 16 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if use_precomputed_lse: + assert precomputed_lse.shape == (n_rows,) + lse = precomputed_lse.contiguous() + else: + lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) + z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_fwd_kernel[(n_rows,)]( + losses, # data ptrs + lse, + z_losses, + logits, + labels, + smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + SPLIT=world_size > 1, + PRECOMPUTED_LSE=use_precomputed_lse, + num_warps=num_warps, + ) + + if world_size > 1: + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For labels not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. 
+ losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(labels == ignore_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(labels == ignore_index, 0.0) + + ctx.save_for_backward(logits, lse, labels) + ctx.mark_non_differentiable(z_losses) + ctx.smoothing = smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignore_index = ignore_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + return losses, z_losses + + @staticmethod + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. + + logits, lse, labels = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + labels, + ctx.smoothing, + ctx.logit_scale, + ctx.lse_square_scale, + ctx.ignore_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + labels: torch.Tensor, + precomputed_lse: Optional[torch.Tensor] = None, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: (batch, vocab_size) + labels: (batch,) + label_smoothing: float + logit_scale: float. Multiply logits by this scale before calculating the loss. + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + Returns: + losses: (batch,), float + z_losses: (batch,), float + """ + return CrossEntropyLoss.apply( + logits, + labels, + precomputed_lse, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + inplace_backward, + process_group, + ) + + + +class CrossEntropyLossFlash(nn.Module): + def __init__( + self, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + inplace_backward=False, + process_group=None, + return_z_loss=False, + ): + """ + Arguments: + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. 
If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + return_z_loss: bool. If True, we return the component of the loss contributed by + the lse_square_scale value. This value is only for logging and does not support + backprop. + """ + super().__init__() + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.inplace_backward = inplace_backward + self.process_group = process_group + self.return_z_loss = return_z_loss + + def forward(self, input, target, precomputed_lse=None): + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss, z_loss = cross_entropy_loss( + input, + target, + precomputed_lse=precomputed_lse, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + lse_square_scale=self.lse_square_scale, + ignore_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, + ) + if self.reduction == "mean": + loss = loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + loss = loss.sum() + else: + loss = loss + + if not self.return_z_loss: + return loss + + if self.reduction == "mean": + z_loss = z_loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + + return loss, z_loss diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 2d9bb755b..517051ef0 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -16,7 +16,7 @@ get_norm, release_param_grad, ) -from internlm.utils.common import get_tensor_norm, move_norm_to_cuda +from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda from internlm.utils.config import Config from internlm.utils.logger import get_logger @@ -37,6 +37,7 @@ def compute_norm( gradients: Iterable[torch.Tensor], parameters: Iterable[torch.Tensor], + zero_mode, ) -> float: """Get L2 norm Arguments: @@ -61,7 +62,17 @@ def compute_norm( if DTENSOR_SUPPORTED and isinstance(total_norm, DTensor): total_norm = total_norm.full_tensor() - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.GLOBAL)) + if gpc.is_using_parallel_mode(zero_mode): + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode)) + + # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce + # model and zero have been reduced!!! 
+ if zero_mode == ParallelMode.EXPERT_DATA: + pg = gpc.get_group(ParallelMode.EXPERT) + scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm_tensor, group=pg) + total_norm = scaled_norm_tensor.item() if torch.is_tensor(total_norm): total_norm = total_norm.item() @@ -112,10 +123,14 @@ def __init__( # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group self._fp16_param_groups = dict() self._fp32_param_tensor_groups = dict() + self._broadcast_parallel_mode = [] # init fp16 and fp32 params for group_idx, param_group in enumerate(self.optim.param_groups): group_params = param_group["params"] + + zero_mode = param_group["optimizer_mode"] + self._broadcast_parallel_mode.append(zero_mode) # fp16 FlatParam storage self._fp16_param_groups[group_idx] = group_params @@ -142,7 +157,7 @@ def _compute_norm_with_fsdp_flatten(self, group_id): norm_group = 0 if len(params) <= 0 or len(gradients) <= 0: return norm_group - norm_group = compute_norm(gradients=gradients, parameters=params) + norm_group = compute_norm(gradients=gradients, parameters=params, zero_mode=self._broadcast_parallel_mode[group_id]) return norm_group diff --git a/train.py b/train.py new file mode 120000 index 000000000..744178299 --- /dev/null +++ b/train.py @@ -0,0 +1 @@ +internlm/launcher/launch.py \ No newline at end of file From 3ca045fb914d783d96619b6ab76de66e1a082eb5 Mon Sep 17 00:00:00 2001 From: caizheng Date: Wed, 9 Apr 2025 14:48:00 +0800 Subject: [PATCH 20/32] experimentally support fsdp2 with ep, has problem with gmm kernel --- internlm/core/fsdp.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index 64435b180..bce5047d4 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -27,7 +27,7 @@ try: from torch.distributed._composable.fsdp import fully_shard - + from torch.distributed.tensor import DeviceMesh FSDP2_SUPPORTED = True except (ImportError, ModuleNotFoundError): FSDP2_SUPPORTED = False @@ -209,10 +209,23 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): fsdp_kwargs = { "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True } - for module in model.modules(): - if isinstance(module, wrap_cls): - fully_shard(module, **fsdp_kwargs) - fully_shard(model, **fsdp_kwargs) + device_mesh = DeviceMesh.from_group( + group=[gpc.get_group(ParallelMode.EXPERT), gpc.get_group(ParallelMode.EXPERT_DATA)], + device_type="cuda", + mesh=torch.arange( + gpc.get_world_size(ParallelMode.GLOBAL), + dtype=torch.int, + ).view((gpc.get_world_size(ParallelMode.EXPERT), gpc.get_world_size(ParallelMode.EXPERT_DATA))), + mesh_dim_names=("ep", "edp"), + ) + for layer_id, layer in enumerate(model.model.layers): + if gpc.is_using_parallel_mode(ParallelMode.EXPERT) and layer_id >= gpc.config.model.first_k_dense_replace: + # Should follow this modeling pattern if EP is enabled. + # Change the expert module name if needed. + # TODO: Make this part hard-coded or config-driven? 
+ fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh, **fsdp_kwargs) + fully_shard(layer, mesh=device_mesh._flatten(), **fsdp_kwargs) + fully_shard(model, mesh=device_mesh._flatten(), **fsdp_kwargs) if fsdp_init_method == "meta": _materialize_meta_module(model, set(), get_current_device()) elif fsdp_init_method == "cpu": From de37b035255f633ce3518e2970b9634b7cff73cf Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 10 Apr 2025 15:00:32 +0800 Subject: [PATCH 21/32] typo fix --- internlm/utils/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index c9abe3a5a..c45f561fa 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -50,13 +50,7 @@ class ModelType(Enum): INTERNLM2 = 2 LLAMA2 = 3 INTERNLM_MoE = 4 - LLAVA = 5 - QWEN2 = 6 - BAICHUAN2 = 7 - GEMMA = 8 - QWEN2MOE = 9 - MIXTRALMOE = 10 - INTERNLM3 = 11 + INTERNLM3 = 5 class DataType(Enum): From ee0e5c98a5ddfa55973bb9f1614f3e7cecc8c067 Mon Sep 17 00:00:00 2001 From: caizheng Date: Fri, 11 Apr 2025 14:07:59 +0800 Subject: [PATCH 22/32] typo fix --- internlm/core/fsdp.py | 2 +- internlm/initialize/initialize_launcher.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index bce5047d4..e56cde271 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -223,7 +223,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): # Should follow this modeling pattern if EP is enabled. # Change the expert module name if needed. # TODO: Make this part hard-coded or config-driven? - fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh, **fsdp_kwargs) + fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh["edp"], **fsdp_kwargs) fully_shard(layer, mesh=device_mesh._flatten(), **fsdp_kwargs) fully_shard(model, mesh=device_mesh._flatten(), **fsdp_kwargs) if fsdp_init_method == "meta": diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index d0956d82a..4f971cb99 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -659,13 +659,15 @@ def args_sanity_check(): assert "init_method" in gpc.config.parallel.fsdp, "init_method must be specified in fsdp when enabled" fsdp_init_method = gpc.config.parallel.fsdp.init_method if fsdp_mode == "v1": + fsdp_v1_min_version = "1.13.0" assert ( - torch.__version__ >= "2.4.0" - ), f"requires torch>=2.4.0 when using fsdp v1 but current version is {torch.__version__}" + torch.__version__ >= fsdp_v1_min_version + ), f"requires torch>={fsdp_v1_min_version} when using fsdp v1 but current version is {torch.__version__}" elif fsdp_mode == "v2": + fsdp_v2_min_version = "2.6.0" assert ( - torch.__version__ >= "2.5.1" - ), f"requires torch>=2.5.1 when using fsdp v2 but current version is {torch.__version__}" + torch.__version__ >= fsdp_v2_min_version + ), f"requires torch>={fsdp_v2_min_version} when using fsdp v2 but current version is {torch.__version__}" else: raise ValueError(f"fsdp mode {fsdp_mode} not supported") assert fsdp_init_method in ["cuda", "cpu", "meta"], f"fsdp init_method {fsdp_init_method} not supported" From 8369b463dec9ad13f0bf6b9db07b781cf97001b2 Mon Sep 17 00:00:00 2001 From: caizheng Date: Mon, 14 Apr 2025 14:48:30 +0800 Subject: [PATCH 23/32] experimentally support fsdp2+ep, but gmm has error when enabling checkpoint and ep together --- internlm/core/fsdp.py | 24 +++++++++++++----------- 
internlm/solver/optimizer/utils.py | 12 +++++++++++- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index e56cde271..30c85170f 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -209,23 +209,25 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): fsdp_kwargs = { "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True } - device_mesh = DeviceMesh.from_group( - group=[gpc.get_group(ParallelMode.EXPERT), gpc.get_group(ParallelMode.EXPERT_DATA)], - device_type="cuda", - mesh=torch.arange( - gpc.get_world_size(ParallelMode.GLOBAL), - dtype=torch.int, - ).view((gpc.get_world_size(ParallelMode.EXPERT), gpc.get_world_size(ParallelMode.EXPERT_DATA))), - mesh_dim_names=("ep", "edp"), - ) + if gpc.is_using_parallel_mode(ParallelMode.EXPERT): + device_mesh = DeviceMesh.from_group( + group=[gpc.get_group(ParallelMode.EXPERT), gpc.get_group(ParallelMode.EXPERT_DATA)], + device_type="cuda", + mesh=torch.arange( + gpc.get_world_size(ParallelMode.GLOBAL), + ).view((gpc.get_world_size(ParallelMode.EXPERT), gpc.get_world_size(ParallelMode.EXPERT_DATA))), + mesh_dim_names=("ep", "edp"), + ) for layer_id, layer in enumerate(model.model.layers): if gpc.is_using_parallel_mode(ParallelMode.EXPERT) and layer_id >= gpc.config.model.first_k_dense_replace: # Should follow this modeling pattern if EP is enabled. # Change the expert module name if needed. # TODO: Make this part hard-coded or config-driven? fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh["edp"], **fsdp_kwargs) - fully_shard(layer, mesh=device_mesh._flatten(), **fsdp_kwargs) - fully_shard(model, mesh=device_mesh._flatten(), **fsdp_kwargs) + for module in model.modules(): + if isinstance(module, wrap_cls): + fully_shard(module, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) if fsdp_init_method == "meta": _materialize_meta_module(model, set(), get_current_device()) elif fsdp_init_method == "cpu": diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index b3532dfb0..55f9629f7 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -36,6 +36,13 @@ logger.warning("The torch implementation for cal_l2norm is slower than apex. 
Please note this!") APEX_AVAILABLE = False +try: + from torch.distributed.tensor import DTensor + + DTENSOR_SUPPORTED = True +except (ModuleNotFoundError, ImportError): + DTENSOR_SUPPORTED = False + inf = math.inf @@ -177,7 +184,10 @@ def sync_param(flat_tensor, tensor_list): def multi_tensor_l2norm_torch(tensor_list, per_tensor): # Convert tensor_list elements to torch.float32 - tensor_list = [tensor.float() for tensor in tensor_list] + tensor_list = [ + tensor.full_tensor().float() if DTENSOR_SUPPORTED and isinstance(tensor, DTensor) else tensor.float() + for tensor in tensor_list + ] norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list]) l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0) From 051df45ab93a6b298a58b88010378c279c23901b Mon Sep 17 00:00:00 2001 From: caizheng Date: Mon, 14 Apr 2025 15:17:41 +0800 Subject: [PATCH 24/32] bug fix for fsdp2+ep --- internlm/solver/optimizer/utils.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 55f9629f7..b3532dfb0 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -36,13 +36,6 @@ logger.warning("The torch implementation for cal_l2norm is slower than apex. Please note this!") APEX_AVAILABLE = False -try: - from torch.distributed.tensor import DTensor - - DTENSOR_SUPPORTED = True -except (ModuleNotFoundError, ImportError): - DTENSOR_SUPPORTED = False - inf = math.inf @@ -184,10 +177,7 @@ def sync_param(flat_tensor, tensor_list): def multi_tensor_l2norm_torch(tensor_list, per_tensor): # Convert tensor_list elements to torch.float32 - tensor_list = [ - tensor.full_tensor().float() if DTENSOR_SUPPORTED and isinstance(tensor, DTensor) else tensor.float() - for tensor in tensor_list - ] + tensor_list = [tensor.float() for tensor in tensor_list] norms_tensor = torch.stack([torch.norm(tensor, p=2) for tensor in tensor_list]) l2_norm = torch.norm(norms_tensor, p=2).unsqueeze(0) From 889dbb0f68f0ea1629533d02f48320e38e3c1e19 Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 17 Apr 2025 11:56:29 +0800 Subject: [PATCH 25/32] support mtp --- internlm/core/engine.py | 12 ++++ internlm/core/naive_amp.py | 2 + .../core/scheduler/no_pipeline_scheduler.py | 64 ++++++++++++++++++- internlm/core/trainer_builder.py | 18 ++++++ internlm/initialize/initialize_trainer.py | 2 + 5 files changed, 96 insertions(+), 2 deletions(-) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index f6de9aebf..cfa3ac6a1 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -72,6 +72,7 @@ def __init__( lr_scheduler: Optional[_LRScheduler] = None, beta2_scheduler: Optional[Beta2Scheduler] = None, criterion: Optional[_Loss] = None, + mtp_criterions: Optional[List[_Loss]] = None, gradient_handlers: Optional[List[BaseGradientHandler]] = None, clip_grad_norm: float = 0.0, ): @@ -80,6 +81,7 @@ def __init__( self._lr_scheduler = lr_scheduler self._beta2_scheduler = beta2_scheduler self._criterion = criterion + self._mtp_criterions = mtp_criterions self._clip_grad_norm = clip_grad_norm # state @@ -105,6 +107,16 @@ def __init__( ), # {'max', 'most_recent'}, default = "max". 
Algorithm used for choosing amax ) + @property + def mtp_criterions(self): + """Returns the criterion (loss function) attached to the engine.""" + return self._mtp_criterions + + @mtp_criterions.setter + def mtp_criterions(self, mtp_criterions): + """Set the criterion (loss function) attached to the engine.""" + self._mtp_criterions = mtp_criterions + @property def model(self): """Returns the model attached to the engine.""" diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 177a5c1c4..006c1a2ce 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -94,6 +94,8 @@ def _convert_to_fp32(self, input_: Any): """Converts the input to fp32 if it is a Tensor of dtype float16.""" if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16): input_ = input_.float() + elif isinstance(input_, (tuple, list)): + input_ = [self._convert_to_fp32(val) for val in input_] return input_ def convert_to_fp32(self, out): diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 7e309beb6..ff7b86e64 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -59,6 +59,44 @@ def __init__( super().__init__(data_process_func) + def _call_engine_mtp_criterion(self, engine: Engine, outputs: Any, labels: Any): + """Calls the engine's criterion with the given outputs and labels. + Args: + engine (internlm.core.Engine): InternLM engine for training and inference. + outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict. + labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict. + """ + assert isinstance( + outputs, (torch.Tensor, list, tuple, dict) + ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}" + + mtp_losses = [] + for i, (output, label) in enumerate(zip(outputs, labels)): + if isinstance(output, torch.Tensor): + output = (output,) + if isinstance(label, torch.Tensor): + label = (label,) + + self._call_hooks("before_criterion", output, label) + if isinstance(output, (tuple, list)) and isinstance(label, (tuple, list)): + mtp_loss = engine.mtp_criterions[i](*output, *label) + elif isinstance(output, (tuple, list)) and isinstance(label, dict): + mtp_loss = engine.mtp_criterions[i](*output, **label) + elif isinstance(output, dict) and isinstance(label, dict): + mtp_loss = engine.mtp_criterions[i](**output, **label) + elif isinstance(output, dict) and isinstance(label, (list, tuple)): + raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(label)}") + else: + raise TypeError( + f"Expected model outputs and labels to be of type torch.Tensor ' \ + '(which is auto-converted to tuple), list, tuple, or dict, ' \ + 'but got {type(output)} (model outputs) and {type(label)} (labels)" + ) + self._call_hooks("after_criterion", mtp_loss) + mtp_losses.append(mtp_loss) + + return mtp_losses + def pre_processing(self, engine: Engine): """Performs actions before running the schedule. 
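As an illustrative aside on how the per-depth criterions above get their targets: for MTP depth i, the base labels are shifted left by (i + 1) positions and the freed tail is filled with the ignore index so those positions contribute no loss. The sketch below mirrors the label construction in the `_train_one_batch` hunk that follows; `build_mtp_labels` is a helper name made up for this example only.

import torch

def build_mtp_labels(label, num_mtp_layers, ignore_index=-100):
    # label: (batch, seq_len) token ids; one shifted copy is built per MTP depth.
    mtp_labels = []
    for i in range(num_mtp_layers):
        pad = torch.full(
            (label.size(0), i + 1), ignore_index, dtype=label.dtype, device=label.device
        )
        mtp_labels.append(torch.cat([label[:, i + 1 :], pad], dim=1))
    return mtp_labels
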
@@ -116,8 +154,10 @@ def _train_one_batch( with conditional_context(torch.no_grad(), enable=forward_only): self._call_hooks("before_forward", data) if hasattr(gpc.config.model, "num_experts"): - # moe is used - output, moe_losses = self._call_engine(engine, data) + if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0: + output, moe_losses, mtp_outputs = self._call_engine(engine, data) + else: + output, moe_losses = self._call_engine(engine, data) else: output = self._call_engine(engine, data) self._call_hooks("after_forward", output) @@ -128,6 +168,26 @@ def _train_one_batch( self._call_hooks("before_criterion", output, label) loss = self._call_engine_criterion(engine, output, label) self._call_hooks("after_criterion", loss) + + if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0: + mtp_labels = [] + for i in range(gpc.config.model.num_mtp_layers): + mtp_labels.append( + torch.cat( + [ + label[:, i + 1 :], + torch.full((label.size(0), i + 1), -100, dtype=label.dtype, device=label.device), + ], + dim=1, + ) + ) + mtp_losses = self._call_engine_mtp_criterion(engine, mtp_outputs, mtp_labels) + mtp_loss = sum(mtp_losses) * gpc.config.loss.mtp_loss_coeff + mtp_loss /= scale_loss + loss += mtp_loss + else: + mtp_loss = None + moe_loss = ( sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606 if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1 diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index a442da773..5efe892cb 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -114,6 +114,9 @@ def __init__( # initialize loss function criterion = self._initialize_criterion() + # initialize mtp loss function + mtp_criterions = self._initialize_mtp_criterion() + # initialize cpu offload manager for selective checkpoint initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) @@ -149,6 +152,7 @@ def __init__( model=model, optimizer=optimizer, criterion=criterion, + mtp_criterions=mtp_criterions, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator), @@ -161,6 +165,20 @@ def __init__( super().__init__(engine, scheduler) + def _initialize_mtp_criterion(self) -> InternLoss: + if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0: + mtp_criterions = [] + for _ in range(gpc.config.model.num_mtp_layers): + mtp_criterion = InternLoss( + parallel_output=gpc.config.model.parallel_output, + label_smoothing=gpc.config.loss.label_smoothing, + op_type=gpc.config.loss.op_type, + ) + mtp_criterions.append(mtp_criterion) + else: + mtp_criterions = [] + return mtp_criterions + def _setup_time_and_logging(self) -> str: current_time = launch_time() objs = [current_time] diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index 5b8cd9f35..71e974899 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -36,6 +36,7 @@ def initialize_trainer( model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, + mtp_criterions: Optional[List[_Loss]] = None, lr_scheduler: Optional[_LRScheduler] = None, beta2_scheduler: Optional[Beta2Scheduler] = None, scheduler_hooks: Optional[List[SchedulerHook]] = None, @@ -166,6 +167,7 @@ def _data_preparation_func(_data, _label): lr_scheduler=lr_scheduler, 
beta2_scheduler=beta2_scheduler, criterion=criterion, + mtp_criterions=mtp_criterions, gradient_handlers=gradient_handlers, clip_grad_norm=clip_grad_norm, ) From 38d48aad92d6f73547d0cac45dccdb473ea5c95a Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 17 Apr 2025 14:36:39 +0800 Subject: [PATCH 26/32] fix for mtp mode --- internlm/initialize/initialize_launcher.py | 2 ++ internlm/initialize/initialize_model.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index 4f971cb99..de249e2cc 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -59,6 +59,8 @@ def dispatch_hf_config_before_launch(hf: dict) -> None: gpc.config.model.num_experts = model_config.n_routed_experts if hasattr(model_config, "first_k_dense_replace"): gpc.config.model.first_k_dense_replace = model_config.first_k_dense_replace + if hasattr(model_config, "num_nextn_predict_layers"): + gpc.config.model.num_mtp_layers = model_config.num_nextn_predict_layers def args_sanity_check(): diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py index 6f7514285..94541c529 100644 --- a/internlm/initialize/initialize_model.py +++ b/internlm/initialize/initialize_model.py @@ -304,8 +304,8 @@ def initialize_model_and_parallel_communicator(model: Optional[Union[nn.Module, register_model_initializer() model = create_model() - # For non-HF cases, set tracking name for parameters - if not is_using_hf(): + # For non-HF or non-FSDP cases, set tracking name for parameters + if not is_using_hf() and not is_using_fsdp(): set_param_unique_tracking_name(model) # should be set before NaiveAMPModel From 4edba34492560f24a98155e000b32c7d74354313 Mon Sep 17 00:00:00 2001 From: caizheng Date: Fri, 18 Apr 2025 15:11:56 +0800 Subject: [PATCH 27/32] decouple expert_group_name from modeling to optimizer param groups --- internlm/initialize/initialize_optimizer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py index 2c2180ade..41e12e91a 100644 --- a/internlm/initialize/initialize_optimizer.py +++ b/internlm/initialize/initialize_optimizer.py @@ -51,16 +51,19 @@ def split_params_into_different_groups_for_optimizer( if is_using_fsdp(): optimizer_mode = ParallelMode.GLOBAL optimizer_mode_expert = ParallelMode.EXPERT_DATA + expert_group_name = f"moe_ep_size_{gpc.get_world_size(ParallelMode.EXPERT)}" + expert_parallel_group_names = [expert_group_name] else: optimizer_mode = ParallelMode.ZERO1 optimizer_mode_expert = ParallelMode.EXPERT_DATA + expert_parallel_group_names = gpc.expert_parallel_group_names new_groups = {} # create new groups for fp32 parameter group new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": optimizer_mode} if gpc.config.model.get("num_experts", 1) > 1: - for key in gpc.expert_parallel_group_names: + for key in expert_parallel_group_names: new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": optimizer_mode_expert} for pgroup in param_groups: @@ -75,7 +78,10 @@ def split_params_into_different_groups_for_optimizer( for param in pgroup["params"]: # moe param means MoE is enabled if is_moe_param(param): - new_groups[param.group_name]["params"].append(param) + if is_using_fsdp(): + new_groups[expert_group_name]["params"].append(param) + else: + 
new_groups[param.group_name]["params"].append(param) elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: new_groups["fp32"]["params"].append(param) else: From 97fafcccb48a05bdaa8b0c02ab3af9e4b6d5a4f1 Mon Sep 17 00:00:00 2001 From: caizheng Date: Fri, 18 Apr 2025 16:05:43 +0800 Subject: [PATCH 28/32] add ep_group as tmp workaround --- internlm/model/model_implementations/builder.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internlm/model/model_implementations/builder.py b/internlm/model/model_implementations/builder.py index 7393837de..ac43c7fb0 100644 --- a/internlm/model/model_implementations/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -78,6 +78,13 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: model_buidler = model_initializer.get_module(module_name=model_type) + if ( + is_using_fsdp() + and hasattr(gpc.config.model, "num_experts") + and gpc.config.model.num_experts > 1 + ): + kwargs["ep_group"] = gpc.get_group(ParallelMode.EXPERT) + if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): kwargs["first"] = kwargs["last"] = True kwargs["start_layer_idx"] = 0 From e2b27ab9894c8eb5a9949cc00f1117ece5080278 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 22 Apr 2025 19:45:12 +0800 Subject: [PATCH 29/32] some refinement for fsdp+ep --- internlm/core/trainer.py | 6 +++--- internlm/core/trainer_builder.py | 4 ++-- internlm/initialize/initialize_launcher.py | 4 +++- internlm/model/model_implementations/builder.py | 8 ++++++++ internlm/solver/optimizer/fsdp_optimizer.py | 7 ++----- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 6fd85886d..0579c2e8c 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -387,7 +387,7 @@ def record_current_batch_training_metrics( infos = { "tflops": tflops, "step": batch_count, - "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), + "loss": loss - moe_loss if moe_loss is not None else loss, "real_tgs": real_tgs, "tgs (tokens/gpu/second)": tgs_origin, "tgs/last_tgs_1": last_tgs_1, @@ -401,7 +401,7 @@ def record_current_batch_training_metrics( "grad_norm": grad_norm, } if moe_loss is not None: - infos["moe_loss"] = moe_loss.item() + infos["moe_loss"] = moe_loss infos["micro_num"] = len(batch[1]) infos["num_consumed_tokens"] = train_state.num_consumed_tokens @@ -434,5 +434,5 @@ def record_current_batch_training_metrics( mm.monitor_loss_spike( alert_address=gpc.config.monitor.alert.feishu_alert_address, step_count=batch_count, - cur_step_loss=loss.item(), + cur_step_loss=loss, ) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 5efe892cb..53440979b 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -383,8 +383,8 @@ def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, s engine=self.engine, start_time=start_time, very_begining_time=self.very_beginning_time, - loss=loss, - moe_loss=moe_loss, + loss=loss.item() if isinstance(loss, torch.Tensor) else loss, + moe_loss=moe_loss.item() if isinstance(moe_loss, torch.Tensor) else moe_loss, grad_norm=grad_norm_groups, metric=self.metric, ) diff --git a/internlm/initialize/initialize_launcher.py b/internlm/initialize/initialize_launcher.py index de249e2cc..b80b261c1 100644 --- a/internlm/initialize/initialize_launcher.py +++ b/internlm/initialize/initialize_launcher.py @@ -59,7 +59,9 @@ def 
dispatch_hf_config_before_launch(hf: dict) -> None: gpc.config.model.num_experts = model_config.n_routed_experts if hasattr(model_config, "first_k_dense_replace"): gpc.config.model.first_k_dense_replace = model_config.first_k_dense_replace - if hasattr(model_config, "num_nextn_predict_layers"): + if hasattr(model_config, "num_mtp_layers"): + gpc.config.model.num_mtp_layers = model_config.num_mtp_layers + elif hasattr(model_config, "num_nextn_predict_layers"): gpc.config.model.num_mtp_layers = model_config.num_nextn_predict_layers diff --git a/internlm/model/model_implementations/builder.py b/internlm/model/model_implementations/builder.py index ac43c7fb0..b1593183f 100644 --- a/internlm/model/model_implementations/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -82,6 +82,7 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: is_using_fsdp() and hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1 + and "ep_group" in kwargs ): kwargs["ep_group"] = gpc.get_group(ParallelMode.EXPERT) @@ -108,6 +109,13 @@ def create_model_hf(hf: dict) -> nn.Module: cfg = cfg.build() mod = LazyObject(hf.mod, hf.mod_cls) mod = mod.build() + if ( + is_using_fsdp() + and hasattr(gpc.config.model, "num_experts") + and gpc.config.model.num_experts > 1 + and hasattr(cfg, "ep_group") + ): + cfg.ep_group = gpc.get_group(ParallelMode.EXPERT) assert is_using_fsdp(), "Curently HF models can only train with FSDP." diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 517051ef0..1e200e3aa 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -68,11 +68,8 @@ def compute_norm( # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce # model and zero have been reduced!!! if zero_mode == ParallelMode.EXPERT_DATA: - pg = gpc.get_group(ParallelMode.EXPERT) - scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) - dist.all_reduce(scaled_norm_tensor, group=pg) - total_norm = scaled_norm_tensor.item() + total_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.EXPERT)) if torch.is_tensor(total_norm): total_norm = total_norm.item() From 837a91b1ce3071c8ff3f72c04afe5c7f525d1fe5 Mon Sep 17 00:00:00 2001 From: caizheng Date: Tue, 22 Apr 2025 19:49:44 +0800 Subject: [PATCH 30/32] some refinement for fsdp+ep --- internlm/core/fsdp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py index 30c85170f..5e4176ff3 100644 --- a/internlm/core/fsdp.py +++ b/internlm/core/fsdp.py @@ -218,12 +218,12 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): ).view((gpc.get_world_size(ParallelMode.EXPERT), gpc.get_world_size(ParallelMode.EXPERT_DATA))), mesh_dim_names=("ep", "edp"), ) - for layer_id, layer in enumerate(model.model.layers): - if gpc.is_using_parallel_mode(ParallelMode.EXPERT) and layer_id >= gpc.config.model.first_k_dense_replace: - # Should follow this modeling pattern if EP is enabled. - # Change the expert module name if needed. - # TODO: Make this part hard-coded or config-driven? 
- fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh["edp"], **fsdp_kwargs) + for layer_id, layer in enumerate(model.model.layers): + if layer_id >= gpc.config.model.first_k_dense_replace: + # Should follow this modeling pattern if EP is enabled. + # Change the expert module name if needed. + # TODO: Make this part hard-coded or config-driven? + fully_shard(layer.feed_forward.moe_layer.experts, mesh=device_mesh["edp"], **fsdp_kwargs) for module in model.modules(): if isinstance(module, wrap_cls): fully_shard(module, **fsdp_kwargs) From 1d6be7d612a67227f61c257b45ac3ca4f6f6b459 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Thu, 24 Apr 2025 19:35:35 +0800 Subject: [PATCH 31/32] add fix when using HF pretrained model and fsdp --- internlm/initialize/initialize_optimizer.py | 16 +++++++++++----- internlm/solver/optimizer/fsdp_optimizer.py | 5 +++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py index 41e12e91a..20d832c5c 100644 --- a/internlm/initialize/initialize_optimizer.py +++ b/internlm/initialize/initialize_optimizer.py @@ -15,7 +15,7 @@ ) from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.solver.schedulers import Beta2Scheduler, FineTuneCosineAnnealingWarmupLR -from internlm.utils.parallel import is_using_fsdp +from internlm.utils.parallel import is_using_fsdp, is_using_hf from internlm.utils.timeout import llm_timeout @@ -75,11 +75,17 @@ def split_params_into_different_groups_for_optimizer( group[ori_key] = pgroup[ori_key] # assign param origin_params = [] - for param in pgroup["params"]: + for named_param in pgroup["params"]: # moe param means MoE is enabled - if is_moe_param(param): + name, param = named_param + # NOTICE: param attribute would get lost with PretrainedModel+FSDP + # DoHack: we split expert param via name as complementary method + if is_moe_param(param) or "wrapped_experts" in name: if is_using_fsdp(): - new_groups[expert_group_name]["params"].append(param) + if gpc.is_using_parallel_mode(ParallelMode.EXPERT) or not is_using_hf(): + new_groups[expert_group_name]["params"].append(param) + else: + origin_params.append(param) else: new_groups[param.group_name]["params"].append(param) elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: @@ -99,7 +105,7 @@ def split_params_into_different_groups_for_optimizer( def create_param_groups(model, weight_decay): parameters = { - "params": [param for param in model.parameters() if param.requires_grad], + "params": [(name, param) for name, param in model.named_parameters() if param.requires_grad], "name": "default", "weight_decay": weight_decay, } diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 1e200e3aa..d4f3cc811 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -68,8 +68,9 @@ def compute_norm( # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce # model and zero have been reduced!!! 
if zero_mode == ParallelMode.EXPERT_DATA: - total_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.EXPERT)) + scaled_norm = torch.tensor(total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)), device=get_current_device(), dtype=torch.float) + dist.all_reduce(scaled_norm, group=gpc.get_group(ParallelMode.EXPERT)) + total_norm = scaled_norm.item() if torch.is_tensor(total_norm): total_norm = total_norm.item() From f8cffba1a5a34f1a4173db38be45dc9d2fda7abd Mon Sep 17 00:00:00 2001 From: caizheng Date: Thu, 8 May 2025 03:15:49 +0000 Subject: [PATCH 32/32] fix moe --- internlm/initialize/initialize_optimizer.py | 2 +- internlm/model/model_implementations/builder.py | 14 -------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py index 20d832c5c..4303cf5ca 100644 --- a/internlm/initialize/initialize_optimizer.py +++ b/internlm/initialize/initialize_optimizer.py @@ -80,7 +80,7 @@ def split_params_into_different_groups_for_optimizer( name, param = named_param # NOTICE: param attribute would get lost with PretrainedModel+FSDP # DoHack: we split expert param via name as complementary method - if is_moe_param(param) or "wrapped_experts" in name: + if is_moe_param(param) or "experts" in name: if is_using_fsdp(): if gpc.is_using_parallel_mode(ParallelMode.EXPERT) or not is_using_hf(): new_groups[expert_group_name]["params"].append(param) diff --git a/internlm/model/model_implementations/builder.py b/internlm/model/model_implementations/builder.py index b1593183f..63bbe468f 100644 --- a/internlm/model/model_implementations/builder.py +++ b/internlm/model/model_implementations/builder.py @@ -78,13 +78,6 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: model_buidler = model_initializer.get_module(module_name=model_type) - if ( - is_using_fsdp() - and hasattr(gpc.config.model, "num_experts") - and gpc.config.model.num_experts > 1 - and "ep_group" in kwargs - ): - kwargs["ep_group"] = gpc.get_group(ParallelMode.EXPERT) if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): kwargs["first"] = kwargs["last"] = True @@ -109,13 +102,6 @@ def create_model_hf(hf: dict) -> nn.Module: cfg = cfg.build() mod = LazyObject(hf.mod, hf.mod_cls) mod = mod.build() - if ( - is_using_fsdp() - and hasattr(gpc.config.model, "num_experts") - and gpc.config.model.num_experts > 1 - and hasattr(cfg, "ep_group") - ): - cfg.ep_group = gpc.get_group(ParallelMode.EXPERT) assert is_using_fsdp(), "Curently HF models can only train with FSDP."
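
For reference, a minimal usage sketch of the CrossEntropyLossFlash module introduced earlier in this series. It assumes CUDA and Triton are available; the import path below is an assumption (the new file's path is not repeated in this excerpt), and the shapes are arbitrary.

import torch

from internlm.model.ops.cross_entropy import CrossEntropyLossFlash  # import path assumed

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (batch, vocab_size)
labels = torch.randint(0, 32000, (8,), device="cuda")
labels[0] = -100  # ignored position: contributes zero loss and zero z-loss

loss_fn = CrossEntropyLossFlash(
    reduction="mean",
    label_smoothing=0.0,
    lse_square_scale=1e-4,  # > 0 adds the auxiliary z-loss term
    return_z_loss=True,
)
loss, z_loss = loss_fn(logits, labels)  # z_loss is returned for logging only
loss.backward()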