Skip to content

Commit b36cbb5

Browse files
committed
hybrid data pipeline
1 parent adcbf27 commit b36cbb5

20 files changed

+2566
-4
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
[
2+
{
3+
"messages": [
4+
{
5+
"role": "user",
6+
"content": "I want to know today's weather in Shanghai"
7+
},
8+
9+
{
10+
"role": "assistant",
11+
"content": "Sure, I will search for the weather of Shanghai.",
12+
"function_call": {
13+
"name": "get_current_weather",
14+
"parameters": {
15+
"location": "Shanghai"
16+
}
17+
}
18+
},
19+
20+
{
21+
"role": "function",
22+
"name": "get_current_weather",
23+
"content": "{'temperature': 22}"
24+
},
25+
{
26+
"role": "assistant",
27+
"content": "The weather in Shanghai is 22 celsius"
28+
}
29+
30+
31+
],
32+
33+
"functions": [
34+
{
35+
"name": "get_current_weather",
36+
"description": "Get the current weather in a given location",
37+
"parameters": {
38+
"type": "object",
39+
"properties": {
40+
"location": {
41+
"type": "string",
42+
"description": "The city and state, e.g. San Francisco, CA"
43+
},
44+
"unit": {"type": "string"}
45+
},
46+
"required": ["location"]
47+
}
48+
}
49+
]
50+
}
51+
52+
]
53+
54+
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
from mmengine.dataset import DefaultSampler
4+
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5+
LoggerHook, ParamSchedulerHook)
6+
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7+
8+
from torch.optim import AdamW
9+
from transformers import AutoModelForCausalLM, AutoTokenizer
10+
11+
12+
from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn
13+
from xtuner.dataset.hybrid.mappings import openai_to_raw_training
14+
from xtuner.engine.hooks import DatasetInfoHook
15+
from xtuner.engine.runner import TrainLoop
16+
from xtuner.model import HybridFinetune
17+
from xtuner.types import HybridChatTemplate
18+
19+
#######################################################################
20+
# PART 1 Settings #
21+
#######################################################################
22+
# Model
23+
# Hugging Face Hub id of the base LLM, so the config is usable on any machine.
# A local directory may be substituted to load the weights offline.
# (Previously a cluster-local cache path pinned to snapshot
# 2292b86b21cb856642782cebed0a453997453b1f of this same model.)
llm_name_or_path = 'internlm/internlm2-chat-7b'
24+
visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336'
25+
# Specify the pretrained pth
26+
pretrained_pth = None
27+
# Data
28+
data_dir = './'
29+
data_files = ['function_call.json']
30+
max_length = 2048
31+
32+
# Chat Template
33+
chat_template = dict(
34+
type=HybridChatTemplate,
35+
system='<|im_start|>system\n{system}<|im_end|>\n',
36+
user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
37+
assistant='{assistant}<|im_end|>\n',
38+
stop_words=['<|im_end|>'],
39+
image_token='<image>',
40+
function_call=
41+
'{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
42+
function_result=
43+
'<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
44+
functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n')
45+
46+
# Scheduler & Optimizer
47+
batch_size = 1 # per_device
48+
accumulative_counts = 1
49+
dataloader_num_workers = 0
50+
max_epochs = 1
51+
optim_type = AdamW
52+
lr = 2e-4
53+
betas = (0.9, 0.999)
54+
weight_decay = 0
55+
max_norm = 1 # grad clip
56+
warmup_ratio = 0.03
57+
58+
# Save
59+
save_steps = 500
60+
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
61+
62+
# Evaluate the generation performance during the training
63+
evaluation_freq = 500
64+
SYSTEM = ''
65+
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
66+
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
67+
68+
#######################################################################
69+
# PART 2 Model & Tokenizer & Image Processor #
70+
#######################################################################
71+
tokenizer = dict(
72+
type=AutoTokenizer.from_pretrained,
73+
pretrained_model_name_or_path=llm_name_or_path,
74+
trust_remote_code=True,
75+
padding_side='right')
76+
77+
78+
model = dict(
79+
type=HybridFinetune,
80+
llm=dict(
81+
type=AutoModelForCausalLM.from_pretrained,
82+
pretrained_model_name_or_path=llm_name_or_path,
83+
trust_remote_code=True,
84+
torch_dtype=torch.float16))
85+
86+
#######################################################################
87+
# PART 3 Dataset & Dataloader #
88+
#######################################################################
89+
llava_dataset = dict(
90+
type=HybridDataset,
91+
data_dir=data_dir,
92+
data_files=data_files,
93+
sample_ratio=1,
94+
tokenizer=tokenizer,
95+
chat_template=chat_template,
96+
max_length=max_length,
97+
pack_to_max_length=True,
98+
num_workers = dataloader_num_workers,
99+
mappings=[openai_to_raw_training])
100+
101+
train_dataloader = dict(
102+
batch_size=batch_size,
103+
num_workers=dataloader_num_workers,
104+
dataset=llava_dataset,
105+
sampler=dict(type=DefaultSampler, shuffle=True),
106+
collate_fn=dict(type=hybrid_collate_fn))
107+
108+
#######################################################################
109+
# PART 4 Scheduler & Optimizer #
110+
#######################################################################
111+
# optimizer
112+
optim_wrapper = dict(
113+
type=AmpOptimWrapper,
114+
optimizer=dict(
115+
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
116+
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
117+
accumulative_counts=accumulative_counts,
118+
loss_scale='dynamic',
119+
dtype='float16')
120+
121+
# learning policy
122+
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
123+
param_scheduler = [
124+
dict(
125+
type=LinearLR,
126+
start_factor=1e-5,
127+
by_epoch=True,
128+
begin=0,
129+
end=warmup_ratio * max_epochs,
130+
convert_to_iter_based=True),
131+
dict(
132+
type=CosineAnnealingLR,
133+
eta_min=0.0,
134+
by_epoch=True,
135+
begin=warmup_ratio * max_epochs,
136+
end=max_epochs,
137+
convert_to_iter_based=True)
138+
]
139+
140+
# train, val, test setting
141+
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
142+
143+
#######################################################################
144+
# PART 5 Runtime #
145+
#######################################################################
146+
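# Custom hooks. Only DatasetInfoHook is enabled here; the optional hook that
# logs sample dialogues periodically during training (EvaluateChatHook) is
# commented out below because it references `image_processor` and
# `prompt_template`, which this config does not define.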
147+
custom_hooks = [
148+
dict(type=DatasetInfoHook, tokenizer=tokenizer),
149+
# dict(
150+
# type=EvaluateChatHook,
151+
# tokenizer=tokenizer,
152+
# image_processor=image_processor,
153+
# every_n_iters=evaluation_freq,
154+
# evaluation_inputs=evaluation_inputs,
155+
# evaluation_images=evaluation_images,
156+
# system=SYSTEM,
157+
# prompt_template=prompt_template)
158+
]
159+
160+
# configure default hooks
161+
default_hooks = dict(
162+
# record the time of every iteration.
163+
timer=dict(type=IterTimerHook),
164+
# print log every 10 iterations.
165+
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
166+
# enable the parameter scheduler.
167+
param_scheduler=dict(type=ParamSchedulerHook),
168+
# save checkpoint per `save_steps`.
169+
checkpoint=dict(
170+
type=CheckpointHook,
171+
by_epoch=False,
172+
interval=save_steps,
173+
max_keep_ckpts=save_total_limit),
174+
# set sampler seed in distributed environment.
175+
sampler_seed=dict(type=DistSamplerSeedHook),
176+
)
177+
178+
# configure environment
179+
env_cfg = dict(
180+
# whether to enable cudnn benchmark
181+
cudnn_benchmark=False,
182+
# set multi process parameters
183+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
184+
# set distributed parameters
185+
dist_cfg=dict(backend='nccl'),
186+
)
187+
188+
# set visualizer
189+
visualizer = None
190+
191+
# set log level
192+
log_level = 'INFO'
193+
194+
# load from which checkpoint
195+
load_from = None
196+
197+
# whether to resume training from the loaded checkpoint
198+
resume = False
199+
200+
# Defaults to use random seed and disable `deterministic`
201+
randomness = dict(seed=None, deterministic=False)
202+
203+
# set log processor
204+
log_processor = dict(by_epoch=False)

0 commit comments

Comments
 (0)