Commit 24180aa

feat(*): Add internlm3 config (#403)
1 parent 02ea919 commit 24180aa

20 files changed: +387 -50 lines

README-ja-JP.md

Lines changed: 3 additions & 2 deletions
@@ -141,8 +141,9 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
 </td>
 <td>
 <ul>
-<li><a href="configs/_base_/models/internlm/internlm_7B.py">InternLM</a></li>
-<li><a href="configs/_base_/models/internlm/internlm2_7B.py">InternLM2</a></li>
+<li><a href="configs/7B_isp_sft.py">InternLM</a></li>
+<li><a href="configs/7B_internlm2.py">InternLM2</a></li>
+<li><a href="configs/8B_internlm3.py">InternLM3</a></li>
 <li><a href="configs/7B_llama2.py">Llama2</a></li>
 <li><a href="configs/7B_qwen2.py">Qwen2</a></li>
 <li><a href="configs/7B_baichuan2.py">Baichuan2</a></li>

README-zh-Hans.md

Lines changed: 3 additions & 2 deletions
@@ -141,8 +141,9 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
 </td>
 <td>
 <ul>
-<li><a href="configs/_base_/models/internlm/internlm_7B.py">InternLM</a></li>
-<li><a href="configs/_base_/models/internlm/internlm2_7B.py">InternLM2</a></li>
+<li><a href="configs/7B_isp_sft.py">InternLM</a></li>
+<li><a href="configs/7B_internlm2.py">InternLM2</a></li>
+<li><a href="configs/8B_internlm3.py">InternLM3</a></li>
 <li><a href="configs/7B_llama2.py">Llama2</a></li>
 <li><a href="configs/7B_qwen2.py">Qwen2</a></li>
 <li><a href="configs/7B_baichuan2.py">Baichuan2</a></li>

README.md

Lines changed: 3 additions & 2 deletions
@@ -141,8 +141,9 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
 </td>
 <td>
 <ul>
-<li><a href="configs/_base_/models/internlm/internlm_7B.py">InternLM</a></li>
-<li><a href="configs/_base_/models/internlm/internlm2_7B.py">InternLM2</a></li>
+<li><a href="configs/7B_isp_sft.py">InternLM</a></li>
+<li><a href="configs/7B_internlm2.py">InternLM2</a></li>
+<li><a href="configs/8B_internlm3.py">InternLM3</a></li>
 <li><a href="configs/7B_llama2.py">Llama2</a></li>
 <li><a href="configs/7B_qwen2.py">Qwen2</a></li>
 <li><a href="configs/7B_baichuan2.py">Baichuan2</a></li>

configs/7B_internlm2.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 JOB_NAME = "7b_internlm2_train"
-model_type = "INTERNLM2_PUBLIC"
+model_type = "INTERNLM2"
 DO_ALERT = False

 VOCAB_SIZE = 92544

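The rename matters because `model_type` is the string the trainer uses to look up a model builder, so a config carrying a stale key cannot resolve a model. A minimal sketch of such a string-keyed registry (the decorator and names below are hypothetical, not InternEvo's actual internals):

```python
# Hypothetical sketch of a registry keyed by model_type; InternEvo's real
# registration API may differ.
MODEL_INITIALIZER = {}

def register_model(model_type):
    def decorator(builder):
        MODEL_INITIALIZER[model_type] = builder
        return builder
    return decorator

@register_model("INTERNLM2")
def build_internlm2(**model_cfg):
    ...  # construct and return the model

def build_model(model_type, **model_cfg):
    try:
        return MODEL_INITIALIZER[model_type](**model_cfg)
    except KeyError:
        raise KeyError(f"unknown model_type: {model_type!r}") from None
```

With a dict-based registry, a stale key such as "INTERNLM2_PUBLIC" fails fast at startup instead of silently building the wrong model.
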
configs/7B_isp_sft.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 JOB_NAME = "7b_train"
-model_type = "INTERNLM2_PUBLIC"
+model_type = "INTERNLM2"
 DO_ALERT = False

 VOCAB_SIZE = 103168

configs/8B_internlm3.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Copyright (c) InternLM. All rights reserved.
JOB_NAME = "8b_internlm3_train"
model_type = "INTERNLM3"
DO_ALERT = False

VOCAB_SIZE = 128512
SEQ_LEN = 4096
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
NUM_KV_ATTENTION_HEAD = 2
MLP_RATIO = 2.5
NUM_LAYER = 48


MODEL_ONLY_FOLDER = None  # "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = None  # "local:llm_ckpts"
# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # 'load_ckpt_info' setting guide:
    # 1. the 'path' indicates the ckpt path,
    # 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
    #    load functions such as "llama"
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
    # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
    # with an automatic restart mechanism upon training reboot.
    # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
    # path specified in `load_ckpt_info` by default.
    # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
    # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
    auto_resume=False,
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
    # 'enable_internevo2hf_ckpt' converts the saved model checkpoint from internevo format to huggingface format.
    enable_internevo2hf_ckpt=False,
)

TRAIN_FOLDER = None
VALID_FOLDER = None  # "/path/to/dataset"
data = dict(
    seq_len=SEQ_LEN,
    # micro_num means the number of micro_batch contained in one gradient update
    micro_num=1,
    # packed_length = micro_bsz * SEQ_LEN
    micro_bsz=2,
    # defaults to the value of micro_num
    valid_micro_num=4,
    # defaults to 0, means disable evaluate
    valid_every=0,
    pack_sample_into_one=False,
    total_steps=20000,
    skip_batches="",
    # rampup_batch_size (str): A string with three space-separated integers representing the
    #    starting batch size, the increment, and the number of steps between
    #    each increment. For example, "192 24 8" means that the batch size (micro_num)
    #    starts at 192 and increases by 24 every 8 steps. Defaults to None.
    #    (IMPORTANT): The interval step size is 'micro_bsz'.
    rampup_batch_size="",
    # Datasets with less than 50 rows will be discarded
    min_length=50,
    train_folder=TRAIN_FOLDER,
    valid_folder=VALID_FOLDER,
    empty_cache_and_diag_interval=200,
    diag_outlier_ratio=1.1,
    # use_packed_dataset=False,
)

grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**16,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    overlap_sync_grad=True,
    overlap_sync_param=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)


# loss config (dict):
#     1. label_smoothing
#     2. op_type: cross_entropy operator type, we support five types for loss computing,
#        including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"]
#        default is "py_vocab_parallel".
#        "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
#        "apex_naive": cross_entropy from apex
#        "py_naive": self-implemented cross_entropy
#        "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
#        "py_vocab_parallel": self-implemented vocab parallel cross_entropy

#     * op_types that end with "naive" only support parallel_output=False;
#     * in a no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
loss = dict(label_smoothing=0, op_type="py_vocab_parallel")

adam = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=0.01,
)

lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=0,  # optimizer_warmup_step
    warmup_ratio=0.01,
    eta_min=1e-5,
    last_epoch=-1,
)

beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)

use_fp32_norm = False
model = dict(
    checkpoint=False,
    num_chunks=1,
    num_attention_heads=NUM_ATTENTION_HEAD,
    num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    no_bias=True,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    # Whether the odd and even columns of the query and key in the model are normally interleaved.
    # If it's True, the model's odd and even columns are normally ordered; if it's False,
    # it means that the model has prematurely concatenated all odd columns and even columns in front
    # and back, in order to improve the RoPE's computational efficiency.
    # Example:
    # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
    # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
    qk_interleaved=False,
    rope_base=50000000,
    enable_qkv_fusion=False,
)

"""
zero1 parallel (dict):
    1. size: int
        * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
          so parameters will be divided within the range of dp.
        * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
        * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
          For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
    1. size: int, the size of tensor parallel.
    2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
       defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
       msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
       fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
       isp: custom intern sequence parallel without tensor parallel, can be used with weight parallel.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
       defaults to False.
weight parallel (dict):
    1. size: int, the size of weight parallel.
    2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
    3. launch_allgather_before: str, before which module to launch the all gather communication to
       prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
       Must be used with forward_overlap_per 'layer'.
    4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
sequence_2D (dict):
    1. enable: bool, whether to enable 2D sequence parallel or not.
    2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
       head_size * context_size should equal the tensor parallel size.
    3. context_size: int, the parallel degree of context parallelism.
       head_size * context_size should equal the tensor parallel size.
    4. window_size: int, the sliding window size in context parallelism.
    5. device_placement_strategy: dict,
       head_first: bool, if `True`, ranks of the same head parallel group are
           given high priority for colocation on the same node;
           if `False`, ranks of the same context parallel group are
           given high priority for colocation on the same node;
       interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config can
           interleave the ranks in the same window to make full use of the NIC as much as possible.
"""
parallel = dict(
    zero1=dict(size=1),
    tensor=dict(size=1, mode="isp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=16, overlap=True, launch_allgather_before="wo", forward_overlap_per="module"),
    sequence_2D=dict(
        enable=False,
        head_size=2,
        context_size=4,
        window_size=1,
        device_placement_strategy=dict(head_first=True, interleaved=False),
    ),
)

cudnn_deterministic = False
cudnn_benchmark = False

monitor = dict(
    # feishu alert configs
    alert=dict(
        enable_feishu_alert=DO_ALERT,
        feishu_alert_address=None,  # feishu webhook to send alert message
        light_monitor_address=None,  # light_monitor address to send heartbeat
        alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
    ),
    tensorboard=dict(
        queue_max_length=10,
    ),
)

# metric_dtype can be "fp32" or other string
# only when set to "fp32" will fp32 be used to compute metrics
# metric_dtype = "fp32"

generation = dict(
    ckpt_folder="/path/to/saved/ckpt",
    output_folder="/path/to/save/generation",
    batch_size=1,
    eos_id=[2, 0],
    bos_id=1,
    max_length=100,
    do_sample=True,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1,
    length_penalty=1.0,
)
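
The `rampup_batch_size` comment in the config above describes a three-integer schedule string. A minimal sketch of how such a string could be interpreted (the helper is illustrative, not InternEvo's actual parser):

```python
# Illustrative only: interprets a "start increment interval" schedule string.
def micro_num_at(step: int, rampup: str, target_micro_num: int) -> int:
    if not rampup:
        return target_micro_num
    start, increment, interval = map(int, rampup.split())
    ramped = start + (step // interval) * increment
    return min(ramped, target_micro_num)

# "192 24 8": batch size starts at 192 and grows by 24 every 8 steps.
assert micro_num_at(0, "192 24 8", 256) == 192
assert micro_num_at(8, "192 24 8", 256) == 216
assert micro_num_at(24, "192 24 8", 256) == 256  # capped at the target
```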

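The parallelism docstring implies a few consistency rules that a config should satisfy. A hedged sanity check distilled from it (the helper is hypothetical; InternEvo runs its own validation at launch):

```python
# Hypothetical validator for the constraints stated in the docstring above.
def check_parallel_cfg(parallel: dict) -> None:
    assert parallel["tensor"]["mode"] in ("mtp", "msp", "fsp", "isp")
    assert parallel["weight"]["launch_allgather_before"] in ("wqkv", "attn", "wo", "w1")
    assert parallel["weight"]["forward_overlap_per"] in ("module", "layer")
    seq2d = parallel["sequence_2D"]
    if seq2d["enable"]:
        # head_size * context_size should equal the tensor parallel size.
        assert seq2d["head_size"] * seq2d["context_size"] == parallel["tensor"]["size"]

check_parallel_cfg(parallel)  # passes: sequence_2D is disabled in this config
```

Note that enabling `sequence_2D` as configured (head_size=2, context_size=4) would require a tensor parallel size of 8, not the 1 set here.
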
configs/_base_/models/internlm2_1B.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) InternLM. All rights reserved.

-model_type = "INTERNLM2_PUBLIC"
+model_type = "INTERNLM2"

 VOCAB_SIZE = 92544
 HIDDEN_SIZE = 2048

configs/_base_/models/internlm2_20B.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) InternLM. All rights reserved.

-model_type = "INTERNLM2_PUBLIC"
+model_type = "INTERNLM2"

 VOCAB_SIZE = 92544
 HIDDEN_SIZE = 6144

configs/_base_/models/internlm2_7B.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Copyright (c) InternLM. All rights reserved.

-model_type = "INTERNLM2_PUBLIC"
+model_type = "INTERNLM2"

 VOCAB_SIZE = 92544
 HIDDEN_SIZE = 4096

internlm/initialize/launch.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+# Copyright (c) InternLM. All rights reserved.

 import argparse
 import os
@@ -307,6 +308,9 @@ def args_sanity_check():
     logger.info(f"clip_grad_norm: {clip_grad_norm}")

     model = gpc.config.model
+    if "enable_qkv_fusion" not in model:
+        model._add_item("enable_qkv_fusion", True)
+
     if "dtype" not in model:
         logger.warning("dtype is not set, use torch.float16 by default!")
         model._add_item("dtype", torch.float16)

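The new check backfills `enable_qkv_fusion=True` for configs written before the flag existed, while configs that set it explicitly (such as `configs/8B_internlm3.py`, which sets `enable_qkv_fusion=False`) keep their value. A minimal illustration of that defaulting pattern, using a plain dict in place of `gpc.config.model`:

```python
# Plain-dict stand-in for gpc.config.model (which exposes _add_item instead).
legacy_model_cfg = {"dtype": "torch.bfloat16"}       # predates the flag
internlm3_model_cfg = {"enable_qkv_fusion": False}   # opts out explicitly

for cfg in (legacy_model_cfg, internlm3_model_cfg):
    if "enable_qkv_fusion" not in cfg:
        cfg["enable_qkv_fusion"] = True              # backfilled default

assert legacy_model_cfg["enable_qkv_fusion"] is True
assert internlm3_model_cfg["enable_qkv_fusion"] is False
```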