Commit e571688: support code interpreter finetune

1 parent 5c8c265

File tree: 9 files changed, +249 −70 lines
Lines changed: 62 additions & 0 deletions (new file)

````json
{
    "messages": [
        {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"},
        {
            "role": "user",
            "content": "Please help me process and visualize this dataset.",
            "files": [{"path": "data.csv", "size": "10K"}]
        },
        {
            "role": "assistant",
            "content": "I have processed the data and visualized it for you.",
            "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': 'Rainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='<b>Date: %{text}</b><br>Wind Direction 9am: %{x}<br>Rainfall: %{y}<br>Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```"
        },
        {
            "role": "code_interpreter",
            "content": "![image](xxx.png)"
        },
        {
            "role": "assistant",
            "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background."
        },
        {
            "role": "user",
            "content": "I want to know today's weather in Shanghai"
        },
        {
            "role": "assistant",
            "content": "Sure, I will search for the weather of Shanghai.",
            "function_call": {
                "name": "get_current_weather",
                "parameters": {"location": "Shanghai"}
            }
        },
        {
            "role": "function",
            "name": "get_current_weather",
            "content": "{'temperature': 22}"
        },
        {
            "role": "assistant",
            "content": "The weather in Shanghai is 22 celsius"
        }
    ],

    "functions": [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {"type": "string"}
                },
                "required": ["location"]
            }
        }
    ],

    "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"
}
````
Lines changed: 29 additions & 0 deletions (new file)

```python
import json

from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages

# InternLM2 chat template, extended with function-call and
# code-interpreter fields.
chat_template = HybridChatTemplate(
    system='<|im_start|>system\n{system}<|im_end|>\n',
    user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
    assistant='{assistant}<|im_end|>\n',
    stop_words=['<|im_end|>'],
    image_token='<image>',
    files='<|im_start|>user name=file\n{files}<|im_end|>\n',
    function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
    function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
    functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n',
    code_interpreter_call='{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
    code_interpreter_result='<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
    code_interpreter='<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n'
)

# Render the multi-turn agent sample as a single training string.
agent_data = json.load(open('agent.json'))

msg = TrainingHybridChatMessages.from_dict(agent_data)
print(msg.apply_chat_template(chat_template))

# Tokenize the same sample with the InternLM2 tokenizer.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b',
                                          trust_remote_code=True)
print(msg.tokenize(tokenizer, chat_template))
```
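Each `HybridChatTemplate` field above is an ordinary Python format string, so the serialization of a single turn can be previewed without xtuner at all. A small sketch using the `function_call` entry, with the assistant text and tool-call payload taken from the agent.json sample (plain `str.format`, for illustration only):

```python
# The function_call template from the HybridChatTemplate above.
function_call_tmpl = ('{assistant}<|action_start|><|plugin|>\n'
                      '{function_call}<|action_end|><|im_end|>\n')

# Fill in one assistant turn and its tool-call payload.
rendered = function_call_tmpl.format(
    assistant='Sure, I will search for the weather of Shanghai.',
    function_call='{"name": "get_current_weather", '
                  '"parameters": {"location": "Shanghai"}}')
print(rendered)
```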

xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -4,10 +4,8 @@
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                             LoggerHook, ParamSchedulerHook)
 from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-
 from torch.optim import AdamW
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
 
 from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn
 from xtuner.dataset.hybrid.mappings import openai_to_raw_training
@@ -74,7 +72,6 @@
     trust_remote_code=True,
     padding_side='right')
 
-
 model = dict(
     type=HybridFinetune,
     llm=dict(
@@ -95,7 +92,7 @@
     chat_template=chat_template,
     max_length=max_length,
     pack_to_max_length=True,
-    num_workers = dataloader_num_workers,
+    num_workers=dataloader_num_workers,
     mappings=[openai_to_raw_training])
 
 train_dataloader = dict(
```

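The configs in this commit all use mmengine's lazy dict style: each dict carries a `type` key holding a class or factory, plus that factory's kwargs, and the runner instantiates it later. A hand-rolled sketch of the idea (not mmengine's actual builder, just the shape of it, with stand-in classes):

```python
from dataclasses import dataclass

@dataclass
class Optim:
    lr: float

@dataclass
class Model:
    optim: Optim
    freeze_llm: bool

def build(cfg):
    # Pop the factory, recursively build nested dict configs, then call it.
    cfg = dict(cfg)
    factory = cfg.pop('type')
    kwargs = {k: build(v) if isinstance(v, dict) and 'type' in v else v
              for k, v in cfg.items()}
    return factory(**kwargs)

model = build(dict(type=Model,
                   freeze_llm=False,
                   optim=dict(type=Optim, lr=2e-4)))
print(model)  # -> Model(optim=Optim(lr=0.0002), freeze_llm=False)
```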
xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py

Lines changed: 37 additions & 13 deletions

```diff
@@ -4,9 +4,11 @@
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                             LoggerHook, ParamSchedulerHook)
 from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from peft import LoraConfig
 from torch.optim import AdamW
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          CLIPImageProcessor, CLIPVisionModel)
+                          BitsAndBytesConfig, CLIPImageProcessor,
+                          CLIPVisionModel)
 
 from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn
 from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens,
@@ -21,15 +23,17 @@
 #                           PART 1  Settings                          #
 #######################################################################
 # Model
+# llm_name_or_path = '/mnt/petrelfs/share_data/basemodel/checkpoints/llm/hf_hub/models--internlm--internlm2-chat-1_8b/snapshots/aa8a7450c2227a3b6733b3c6fe33fefbb2ca54f9/'
 llm_name_or_path = '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f/'
 visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336'
+use_varlen_attn = False
 # Specify the pretrained pth
 pretrained_pth = None
 # Data
 data_dir = './llava_data/'
 data_files = ['LLaVA-Instruct-150K/llava_v1_5_mix665k.json']
 image_dir = data_dir + 'llava_images'
-max_length = 1024 * 32
+max_length = 1024 * 2
 
 # Chat Template
 chat_template = dict(
@@ -46,12 +50,12 @@
     functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n')
 
 # Scheduler & Optimizer
-batch_size = 1  # per_device
+batch_size = 16  # per_device
 accumulative_counts = 1
-dataloader_num_workers = 4
+dataloader_num_workers = 0
 max_epochs = 1
 optim_type = AdamW
-lr = 2e-4
+lr = 0
 betas = (0.9, 0.999)
 weight_decay = 0
 max_norm = 1  # grad clip
@@ -86,14 +90,34 @@
     freeze_llm=False,
     freeze_visual_encoder=True,
     pretrained_pth=pretrained_pth,
+    use_varlen_attn=use_varlen_attn,
     llm=dict(
         type=AutoModelForCausalLM.from_pretrained,
         pretrained_model_name_or_path=llm_name_or_path,
         trust_remote_code=True,
-        torch_dtype=torch.float16),
+        torch_dtype=torch.bfloat16,
+        attn_implementation='flash_attention_2',
+        quantization_config=dict(
+            type=BitsAndBytesConfig,
+            load_in_4bit=True,
+            load_in_8bit=False,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4')),
+    llm_lora=dict(
+        type=LoraConfig,
+        r=512,
+        lora_alpha=256,
+        lora_dropout=0.05,
+        bias='none',
+        task_type='CAUSAL_LM'),
     visual_encoder=dict(
         type=CLIPVisionModel.from_pretrained,
-        pretrained_model_name_or_path=visual_encoder_name_or_path))
+        pretrained_model_name_or_path=visual_encoder_name_or_path),
+    visual_encoder_lora=dict(
+        type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))
 
 #######################################################################
 #                      PART 3  Dataset & Dataloader                   #
@@ -102,16 +126,16 @@
     type=HybridDataset,
     data_dir=data_dir,
     data_files=data_files,
-    data_cached='cached_llava',
+    # data_cached='cached_llava',
     image_dir=image_dir,
-    sample_ratio=1,
+    sample_ratio=0.1,
     tokenizer=tokenizer,
     chat_template=chat_template,
     image_processor=image_processor,
     pad_img_to_squared=True,
     max_length=max_length,
-    pack_to_max_length=True,
-    num_workers=dataloader_num_workers,
+    pack_to_max_length=False,
+    num_workers=4,
     mappings=[
         llava_to_openai,
         openai_to_raw_training,
@@ -120,7 +144,7 @@
 
 train_dataloader = dict(
     batch_size=batch_size,
-    num_workers=dataloader_num_workers,
+    num_workers=4,
     dataset=llava_dataset,
     sampler=dict(type=DefaultSampler, shuffle=True),
     collate_fn=dict(type=hybrid_collate_fn))
@@ -182,7 +206,7 @@
     # record the time of every iteration.
     timer=dict(type=IterTimerHook),
     # print log every 10 iterations.
-    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
     # enable the parameter scheduler.
     param_scheduler=dict(type=ParamSchedulerHook),
     # save checkpoint per `save_steps`.
```

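The `quantization_config` and `llm_lora` dicts added in this file amount to a standard QLoRA setup. Outside xtuner's config system, the same combination can be expressed directly with transformers and peft; a minimal sketch using the hyperparameters from this diff (the Hub model name and the surrounding training loop are assumptions, not part of the commit):

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization, matching the quantization_config dict above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

model = AutoModelForCausalLM.from_pretrained(
    'internlm/internlm2-chat-7b',  # assumed; the config points at a local snapshot
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config)

# LoRA adapters on top of the quantized LLM, matching the llm_lora dict above.
lora_config = LoraConfig(
    r=512, lora_alpha=256, lora_dropout=0.05,
    bias='none', task_type='CAUSAL_LM')
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```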
xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json

Lines changed: 1 addition & 3 deletions

The only in-hunk edit is whitespace: the indentation of the `"type"` line is corrected (the exact indentation shown here is approximate), and two trailing blank lines are dropped at end of file.

```diff
@@ -13,7 +13,7 @@
             "image_url": "image2.jpg"
         },
         {
-          "type": "text",
+            "type": "text",
             "text": "What are the colors of the bus in the first image?"
         }
     ]
@@ -37,5 +37,3 @@
     ]
 }
]
-
-
```

xtuner/dataset/hybrid/dataset.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -287,7 +287,6 @@ def img_sample_counter(item):
         def img_counter(item):
             return len(item['image_urls'])
 
-
         with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
             images = list(
                 tqdm(
@@ -403,8 +402,10 @@ def __getitem__(self, item: int) -> Dict[str, List]:
         assistant='{assistant}<|im_end|>\n',
         stop_words=['<|im_end|>'],
         image_token='<image>',
-        function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
-        function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
+        function_call=
+        '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
+        function_result=
+        '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
         functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n'
     )
```