
Commit 081c8ca

Add internlm2 5 cfgs (#872)
* add internlm2.5 configs
* limit transformers <= 4.42.4
1 parent d81b366 commit 081c8ca

File tree

3 files changed: +424 -1 lines changed


requirements/runtime.txt

Lines changed: 3 additions & 1 deletion
@@ -21,5 +21,7 @@ torchvision
 # Registering a causal mask in `LlamaModel` is not friendly for very large
 # `max_position_embeddings`. Refer to
 # https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
-transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2
+# transformers >= 4.43.0 calls the module-level _flash_attention_forward instead of
+# self._flash_attention_forward to compute the attention output, which breaks backward compatibility
+transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4
 transformers_stream_generator
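The comment above explains why the new upper bound exists; in practice, environments just need a transformers release inside this range. A minimal sanity-check sketch, not part of this commit, assuming the `packaging` library is installed; it reuses the exact specifier string from requirements/runtime.txt:

# Hypothetical sanity check (not part of this commit): verify the installed
# transformers version against the range pinned in requirements/runtime.txt.
from importlib.metadata import version          # stdlib, Python >= 3.8
from packaging.specifiers import SpecifierSet   # assumed to be available

constraint = SpecifierSet('>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4')
installed = version('transformers')
if installed not in constraint:
    raise RuntimeError(
        f'transformers {installed} is outside the supported range {constraint}')

Equivalently, pip install "transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4" resolves to a version inside the same range.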
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2_5-20b-chat'
use_varlen_attn = False

# Data
alpaca_en_path = 'tatsu-lab/alpaca'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True))

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
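A config like this is normally passed to xtuner's CLI (for example `xtuner train <config>`), which loads it through mmengine's Config before building the runner. Below is a minimal, hypothetical sketch of inspecting the settings defined above; `internlm2_5_cfg.py` is a placeholder filename for a local copy of this file, not a name taken from the commit.

# Hypothetical usage sketch (not part of this commit): load a local copy of the
# config with mmengine and read back a few of the scalar settings defined above.
from mmengine.config import Config

cfg = Config.fromfile('internlm2_5_cfg.py')  # placeholder filename
print(cfg.max_length)              # 2048
print(cfg.pack_to_max_length)      # True
print(cfg.train_cfg['max_epochs'])  # 3

Because the file is plain Python evaluated at load time, derived values such as accumulative_counts (scaled by sequence_parallel_size) are already resolved in the loaded config.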

0 commit comments
