diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json
new file mode 100644
index 000000000..e7d8966e9
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json
@@ -0,0 +1,62 @@
+{
+ "messages": [
+ {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"},
+ {
+ "role": "user",
+ "content": "Please help me process and visualize this dataset.",
+ "files": [{"path": "data.csv", "size": "10K"}]
+ },
+ {
+ "role": "assistant",
+ "content": "I have processed the data and visualized it for you.",
+ "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n                 labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': 'Rainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n                 title='Rainfall vs Wind Direction',\n                 template='plotly_dark',\n                 width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='Date: %{text}<br>Wind Direction 9am: %{x}<br>Rainfall: %{y}<br>Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```"
+ },
+ {
+ "role": "code_interpreter",
+ "content": ""
+ },
+ {
+ "role": "assistant",
+ "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background."
+ },
+ {
+ "role": "user",
+ "content": "I want to know today's weather in Shanghai"
+ },
+ {
+ "role": "assistant",
+ "content": "Sure, I will search for the weather of Shanghai.",
+ "function_call": {
+ "name": "get_current_weather",
+ "parameters": {"location": "Shanghai"}
+ }
+ },
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": "{'temperature': 22}"
+ },
+ {
+ "role": "assistant",
+ "content": "The weather in Shanghai is 22 celsius"
+ }
+ ],
+
+ "functions": [
+ {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA"
+ },
+ "unit": {"type": "string"}
+ },
+ "required": ["location"]
+ }
+ }
+ ],
+
+ "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"}
diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py
new file mode 100644
index 000000000..e9d5796bc
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py
@@ -0,0 +1,33 @@
+import json
+
+from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages
+
+chat_template = HybridChatTemplate(
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>\n',
+ stop_words=['<|im_end|>'],
+    image_token='<image>',
+ files='<|im_start|>user name=file\n{files}<|im_end|>\n',
+ function_call=
+ '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+ function_result=
+ '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+ functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n',
+ code_interpreter_call=
+ '{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+ code_interpreter_result=
+ '<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+ code_interpreter=
+ '<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n')
+
+agent_data = json.load(open('agent.json'))
+
+msg = TrainingHybridChatMessages.from_dict(agent_data)
+print(msg.apply_chat_template(chat_template))
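+# A rough sketch of the rendered text for the final function-calling
+# exchange above (the exact serialization of the call arguments is an
+# assumption):
+#
+#   <|im_start|>user
+#   I want to know today's weather in Shanghai<|im_end|>
+#   <|im_start|>assistant
+#   Sure, I will search for the weather of Shanghai.<|action_start|><|plugin|>
+#   {"name": "get_current_weather", "parameters": {"location": "Shanghai"}}<|action_end|><|im_end|>
+#   <|im_start|>environment name=<|plugin|>
+#   {'temperature': 22}<|im_end|>
+#   <|im_start|>assistant
+#   The weather in Shanghai is 22 celsius<|im_end|>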
+
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(
+ 'internlm/internlm2-chat-7b', trust_remote_code=True)
+print(msg.tokenize(tokenizer, chat_template))
diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json
new file mode 100644
index 000000000..a7a1abbf3
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json
@@ -0,0 +1,52 @@
+[
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "I want to know today's weather in Shanghai"
+ },
+
+ {
+ "role": "assistant",
+ "content": "Sure, I will search for the weather of Shanghai.",
+ "function_call": {
+ "name": "get_current_weather",
+ "parameters": {
+ "location": "Shanghai"
+ }
+ }
+ },
+
+ {
+ "role": "function",
+ "name": "get_current_weather",
+ "content": "{'temperature': 22}"
+ },
+ {
+ "role": "assistant",
+ "content": "The weather in Shanghai is 22 celsius"
+ }
+
+
+ ],
+
+ "functions": [
+ {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA"
+ },
+ "unit": {"type": "string"}
+ },
+ "required": ["location"]
+ }
+ }
+ ]
+ }
+
+]
diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py
new file mode 100644
index 000000000..481b22620
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py
@@ -0,0 +1,193 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+from transformers import AutoTokenizer
+
+from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn
+from xtuner.dataset.hybrid.mappings import openai_to_raw_training
+from xtuner.engine.hooks import DatasetInfoHook
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import AgentFinetune, AutoModelForCausalLM
+from xtuner.types import HybridChatTemplate
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+llm_name_or_path = 'internlm/internlm2-chat-1_8b'
+
+# Data
+data_dir = './'
+data_files = ['agentlego.json']
+max_length = 2048
+
+# Chat Template
+chat_template = dict(
+ type=HybridChatTemplate,
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>\n',
+ stop_words=['<|im_end|>'],
+    image_token='<image>',
+ function_call=
+ '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+ function_result=
+ '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+ functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n')
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+accumulative_counts = 1
+dataloader_num_workers = 0
+max_epochs = 1
+optim_type = AdamW
+lr = 2e-4
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+# PART 2 Model & Tokenizer & Image Processor #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=llm_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=AgentFinetune,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=llm_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.float16))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+dataset = dict(
+ type=HybridDataset,
+ data_dir=data_dir,
+ data_files=data_files,
+ sample_ratio=1,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ max_length=max_length,
+ pack_to_max_length=True,
+ num_workers=4,
+ mappings=[openai_to_raw_training])
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=dataset,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ collate_fn=dict(type=hybrid_collate_fn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+ dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-5,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=0.0,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ # dict(
+ # type=EvaluateChatHook,
+ # tokenizer=tokenizer,
+ # image_processor=image_processor,
+ # every_n_iters=evaluation_freq,
+ # evaluation_inputs=evaluation_inputs,
+ # evaluation_images=evaluation_images,
+ # system=SYSTEM,
+ # prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every 10 iterations.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py
new file mode 100644
index 000000000..a010f44b5
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py
@@ -0,0 +1,248 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from peft import LoraConfig
+from torch.optim import AdamW
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig, CLIPImageProcessor,
+ CLIPVisionModel)
+
+from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn
+from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens,
+ llava_to_openai,
+ openai_to_raw_training)
+from xtuner.engine.hooks import DatasetInfoHook
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import HybridFinetune
+from xtuner.types import HybridChatTemplate
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+# llm_name_or_path = 'internlm/internlm2-chat-1_8b'
+llm_name_or_path = 'internlm/internlm2-chat-7b'
+visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336'
+use_varlen_attn = False
+# Specify the pretrained pth
+pretrained_pth = None
+# Data
+data_dir = './llava_data/'
+data_files = ['LLaVA-Instruct-150K/llava_v1_5_mix665k.json']
+image_dir = data_dir + 'llava_images'
+max_length = 1024 * 2
+
+# Chat Template
+chat_template = dict(
+ type=HybridChatTemplate,
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>\n',
+ stop_words=['<|im_end|>'],
+    image_token='<image>',
+ function_call=
+ '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+ function_result=
+ '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+ functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n')
+
+# Scheduler & Optimizer
+batch_size = 16 # per_device
+accumulative_counts = 1
+dataloader_num_workers = 0
+max_epochs = 1
+optim_type = AdamW
+lr = 2e-4
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = ''
+evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
+evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
+
+#######################################################################
+# PART 2 Model & Tokenizer & Image Processor #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=llm_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+image_processor = dict(
+ type=CLIPImageProcessor.from_pretrained,
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
+ trust_remote_code=True)
+
+model = dict(
+ type=HybridFinetune,
+ freeze_llm=False,
+ freeze_visual_encoder=True,
+ pretrained_pth=pretrained_pth,
+ use_varlen_attn=use_varlen_attn,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=llm_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ attn_implementation='flash_attention_2',
+ quantization_config=dict(
+ type=BitsAndBytesConfig,
+ load_in_4bit=True,
+ load_in_8bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4')),
+ llm_lora=dict(
+ type=LoraConfig,
+ r=512,
+ lora_alpha=256,
+ lora_dropout=0.05,
+ bias='none',
+ task_type='CAUSAL_LM'),
+ visual_encoder=dict(
+ type=CLIPVisionModel.from_pretrained,
+ pretrained_model_name_or_path=visual_encoder_name_or_path),
+ visual_encoder_lora=dict(
+ type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+llava_dataset = dict(
+ type=HybridDataset,
+ data_dir=data_dir,
+ data_files=data_files,
+ # data_cached='cached_llava',
+ image_dir=image_dir,
+ sample_ratio=0.1,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ image_processor=image_processor,
+ pad_img_to_squared=True,
+ max_length=max_length,
+ pack_to_max_length=False,
+ num_workers=4,
+ mappings=[
+ llava_to_openai,
+ openai_to_raw_training,
+ insert_img_pad_tokens,
+ ])
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=4,
+ dataset=llava_dataset,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ collate_fn=dict(type=hybrid_collate_fn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+ dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-5,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=0.0,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ # dict(
+ # type=EvaluateChatHook,
+ # tokenizer=tokenizer,
+ # image_processor=image_processor,
+ # every_n_iters=evaluation_freq,
+ # evaluation_inputs=evaluation_inputs,
+ # evaluation_images=evaluation_images,
+ # system=SYSTEM,
+ # prompt_template=prompt_template)
+]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every 10 iterations.
+    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json
new file mode 100644
index 000000000..ebe5cf457
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json
@@ -0,0 +1,39 @@
+[
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": "image1.jpg"
+ },
+ {
+ "type": "image_url",
+ "image_url": "image2.jpg"
+ },
+ {
+ "type": "text",
+ "text": "What are the colors of the bus in the first image?"
+ }
+ ]
+ },
+
+ {
+ "role": "assistant",
+ "content": "The bus in the image is white and red."
+ },
+
+ {
+ "role": "user",
+ "content": "Where is the cat positioned in the second image?"
+ },
+
+ {
+ "role": "assistant",
+ "content": "The cat is positioned on top of the back of the couch in the living room."
+ }
+
+ ]
+ }
+]
diff --git a/xtuner/dataset/hybrid/__init__.py b/xtuner/dataset/hybrid/__init__.py
new file mode 100644
index 000000000..febf1a497
--- /dev/null
+++ b/xtuner/dataset/hybrid/__init__.py
@@ -0,0 +1,14 @@
+from .collate import hybrid_collate_fn
+from .dataset import HybridDataset
+from .mappings import (insert_img_pad_tokens, llava_to_openai, map_protocol,
+ map_sequential, openai_to_raw_training)
+
+__all__ = [
+ 'hybrid_collate_fn',
+ 'HybridDataset',
+ 'insert_img_pad_tokens',
+ 'llava_to_openai',
+ 'map_protocol',
+ 'map_sequential',
+ 'openai_to_raw_training',
+]
diff --git a/xtuner/dataset/hybrid/_pack.py b/xtuner/dataset/hybrid/_pack.py
new file mode 100644
index 000000000..12b29fbca
--- /dev/null
+++ b/xtuner/dataset/hybrid/_pack.py
@@ -0,0 +1,133 @@
+import bisect
+import itertools
+import random
+
+import torch
+
+
+class _PackDataset(torch.utils.data.Dataset):
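+    """Pack shuffled samples into fixed-length training sequences.
+
+    The wrapped dataset is shuffled, conceptually concatenated into one long
+    token stream and sliced into chunks of ``max_length`` tokens. Image
+    ranges are remapped into chunk-local coordinates so the visual features
+    can still be inserted at the correct offsets.
+    """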
+
+ def __init__(self, dataset, max_length=2048):
+ super().__init__()
+
+ self.max_length = max_length
+
+ # unpack dataset
+ self.dataset = dataset
+
+ self._ori_img_urls = dataset['image_urls']
+ self._ori_img_rngs = dataset['image_ranges']
+ self._ori_lens = dataset['tokens']
+
+ self._num_packed_samples = sum(self._ori_lens) // self.max_length
+
+ inds = [i for i in range(len(self.dataset))]
+ random.shuffle(inds)
+ self.shfl_inds = inds
+
+ shfl_lens = [self._ori_lens[i] for i in inds]
+ # shuffled cumulative lengths
+ shfl_acc_lens = list(itertools.accumulate(shfl_lens))
+
+ self._shfl_item_rngs_left = [0] + shfl_acc_lens[:-1]
+ self._shfl_item_rngs_right = shfl_acc_lens
+
+ shfl_img_urls = [self._ori_img_urls[i] for i in inds]
+ self._flat_shfl_img_urls = list(itertools.chain(*shfl_img_urls))
+
+ flat_shfl_acc_img_rngs = []
+ flat_shfl_acc_img_rngs_left = []
+ flat_shfl_acc_img_rngs_right = []
+ for i in range(len(self.dataset)):
+ shfl_i = self.shfl_inds[i]
+ img_rngs = self._ori_img_rngs[shfl_i]
+ for left, right in img_rngs:
+ acc_left = left + self._shfl_item_rngs_left[i]
+ acc_right = right + self._shfl_item_rngs_left[i]
+
+ flat_shfl_acc_img_rngs_left.append(acc_left)
+ flat_shfl_acc_img_rngs_right.append(acc_right)
+ flat_shfl_acc_img_rngs.append([acc_left, acc_right])
+ assert len(flat_shfl_acc_img_rngs) == len(self._flat_shfl_img_urls)
+
+ self._flat_shfl_acc_img_rngs = flat_shfl_acc_img_rngs
+ self._flat_shfl_acc_img_rngs_left = flat_shfl_acc_img_rngs_left
+ self._flat_shfl_acc_img_rngs_right = flat_shfl_acc_img_rngs_right
+
+ def _pack_img_urls_and_rngs_in_range(self, begin, end):
+
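+        # Binary-search the flattened, accumulated image ranges for images
+        # overlapping [begin, end), then shift them into chunk-local offsets.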
+ left = bisect.bisect(self._flat_shfl_acc_img_rngs_right, begin)
+ right = bisect.bisect(self._flat_shfl_acc_img_rngs_left, end)
+
+ filter_urls = self._flat_shfl_img_urls[left:right]
+ filter_rngs = self._flat_shfl_acc_img_rngs[left:right]
+
+ inner_rngs = []
+ inner_urls = []
+ for url, rng in zip(filter_urls, filter_rngs):
+ inner_left = max(begin, rng[0]) - begin
+ inner_right = min(end, rng[1]) - begin
+
+ if inner_right - inner_left > 0:
+ inner_rngs.append([inner_left, inner_right])
+ inner_urls.append(url)
+ return inner_urls, inner_rngs
+
+ def _pack_ids_and_labels_in_range(self, begin, end):
+
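+        # Find the original samples overlapping [begin, end) and collect the
+        # token/label slices falling inside this packed chunk, along with the
+        # per-sample cumulative lengths and restarted position ids.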
+ left = bisect.bisect(self._shfl_item_rngs_right, begin)
+ right = bisect.bisect(self._shfl_item_rngs_left, end)
+
+ trunc_ids = []
+ trunc_labels = []
+ cumulative_len = []
+ position_ids = []
+ for i in range(left, right):
+ cumulative_len.append(len(trunc_ids))
+
+ item_begin = self._shfl_item_rngs_left[i]
+ item_end = self._shfl_item_rngs_right[i]
+
+ inner_l = max(begin, item_begin) - item_begin
+ inner_r = min(end, item_end) - item_begin
+ position_ids.extend([i for i in range(inner_r - inner_l)])
+
+ ori_idx = self.shfl_inds[i]
+ ori_input_ids = self.dataset[ori_idx]['input_ids']
+ ori_labels = self.dataset[ori_idx]['labels']
+
+ trunc_ids.extend(ori_input_ids[inner_l:inner_r])
+ trunc_labels.extend(ori_labels[inner_l:inner_r])
+
+ cumulative_len.append(self.max_length)
+ return trunc_ids, trunc_labels, cumulative_len, position_ids
+
+ def __len__(self):
+ return self._num_packed_samples
+
+ def __getitem__(self, item):
+
+ begin = item * self.max_length
+ end = (item + 1) * self.max_length
+
+ _res = self._pack_ids_and_labels_in_range(begin, end)
+ packed_ids, packed_labels, cumulative_len, position_ids = _res
+ assert self.max_length == len(packed_ids) == len(packed_labels)
+
+ _res = self._pack_img_urls_and_rngs_in_range(begin, end)
+ packed_img_urls, packed_img_rngs = _res
+
+ for left, right in packed_img_rngs:
+ assert len(set(packed_ids[left:right])) == 1
+
+ packed = {
+ 'input_ids': packed_ids,
+ 'labels': packed_labels,
+ 'tokens': self.max_length,
+ 'image_urls': packed_img_urls,
+ 'image_ranges': packed_img_rngs,
+ 'cumulative_len': cumulative_len,
+ 'position_ids': position_ids
+ }
+
+ return packed
diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py
new file mode 100644
index 000000000..e6ed2288a
--- /dev/null
+++ b/xtuner/dataset/hybrid/collate.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
+
+
+def hybrid_collate_fn(instances: Sequence[Dict],
+ pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
+ return_hf_format: bool = False):
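+    """Collate hybrid samples into a single training batch.
+
+    Sequences are right-padded to the longest sample in the batch, pixel
+    values of all images are stacked, and ``image_belongs`` records which
+    sample in the batch each image belongs to.
+    """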
+
+ input_ids = []
+ labels = []
+ pixel_values = []
+ cumulative_len = []
+ image_ranges = []
+ image_belongs = []
+ position_ids = []
+
+ for i, data in enumerate(instances):
+ input_ids.append(torch.LongTensor(data['input_ids']))
+ labels.append(torch.LongTensor(data['labels']))
+ position_ids.append(torch.IntTensor(data['position_ids']))
+
+ if 'cumulative_len' in data:
+ cumulative_len.append(torch.IntTensor(data['cumulative_len']))
+
+ _values = data['pixel_values']
+ _ranges = data['image_ranges']
+
+ assert len(_values) == len(_ranges)
+ for v, rng in zip(_values, _ranges):
+ pixel_values.append(v)
+ image_ranges.append(rng)
+ image_belongs.append(i)
+
+ if len(pixel_values) > 0:
+ assert len(image_ranges) > 0
+ assert len(image_belongs) > 0
+
+ pixel_values = torch.stack(pixel_values)
+ # image_belongs = torch.IntTensor(image_belongs)
+ else:
+ pixel_values = None
+ image_ranges = None
+ image_belongs = None
+
+ if len(instances) > 1:
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=pad_index)
+ labels = pad_sequence(
+ labels, batch_first=True, padding_value=IGNORE_INDEX)
+ position_ids = pad_sequence(
+ position_ids, batch_first=True, padding_value=0)
+ else:
+ input_ids = torch.stack(input_ids)
+ labels = torch.stack(labels)
+ position_ids = torch.stack(position_ids)
+
+ if len(cumulative_len) == 0:
+ cumulative_len = None
+
+ data_dict = {
+ 'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'attention_mask': input_ids.ne(pad_index),
+ 'labels': labels,
+ 'pixel_values': pixel_values,
+ 'cumulative_len': cumulative_len,
+ 'image_ranges': image_ranges,
+ 'image_belongs': image_belongs
+ }
+
+ if return_hf_format:
+ return data_dict
+ else:
+ return {'data': data_dict, 'data_samples': None}
diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py
new file mode 100644
index 000000000..b2699e048
--- /dev/null
+++ b/xtuner/dataset/hybrid/dataset.py
@@ -0,0 +1,474 @@
+import json
+import os
+import random
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+from functools import partial
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from datasets import Dataset, load_from_disk
+from mmengine import print_log
+from PIL import Image
+from torch import distributed as dist
+from torch import nn
+from tqdm import tqdm
+
+from xtuner.dataset.hybrid._pack import _PackDataset
+from xtuner.dataset.hybrid.mappings import map_protocol, map_sequential
+from xtuner.dataset.utils import expand2square
+from xtuner.registry import BUILDER
+from xtuner.types import HybridChatTemplate
+from xtuner.utils import build_tokenizer
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+
+
+@map_protocol(
+ input_keys=dict(input_ids=list),
+ added_keys=dict(tokens=int),
+)
+def _register_tokens(data, tokenizer=None, chat_template=None):
+ data['tokens'] = len(data['input_ids'])
+ return data
+
+
+@map_protocol(
+ input_keys=dict(input_ids=list),
+ added_keys=dict(cumulative_len=list),
+)
+def _register_cumulative_len(data, tokenizer=None, chat_template=None):
+ data['cumulative_len'] = [0, len(data['input_ids'])]
+ return data
+
+
+@map_protocol(
+ input_keys=dict(input_ids=list),
+ added_keys=dict(position_ids=list),
+)
+def _register_position_ids(data, tokenizer=None, chat_template=None):
+ data['position_ids'] = [i for i in range(len(data['input_ids']))]
+ return data
+
+
+@map_protocol(
+ added_keys=dict(image_ranges=list), )
+def _register_empty_img_ranges(data, tokenizer=None, chat_template=None):
+ if 'image_ranges' not in data:
+ data['image_ranges'] = []
+ return data
+
+
+@map_protocol(
+ input_keys=dict(
+ input_ids=list,
+ labels=list,
+ tokens=int,
+ image_urls=list,
+ image_ranges=list,
+ position_ids=list,
+ cumulative_len=list),
+ output_keys=dict(
+ input_ids=list,
+ labels=list,
+ tokens=int,
+ image_urls=list,
+ image_ranges=list,
+ position_ids=list,
+ cumulative_len=list))
+def _check_mapped_data(item, tokenizer=None, chat_template=None):
+ assert isinstance(item['input_ids'][0], int)
+ assert isinstance(item['labels'][0], int)
+
+ if len(item['image_urls']) > 0:
+ assert isinstance(item['image_urls'][0], str)
+
+ if len(item['image_ranges']) > 0:
+ assert isinstance(item['image_ranges'][0], list)
+ assert isinstance(item['image_ranges'][0][0], int)
+
+ return item
+
+
+class HybridDataset(torch.utils.data.Dataset):
+    """Dataset that mixes plain-text, multimodal and agent samples.
+
+    Args:
+        tokenizer: The tokenizer (or a config dict to build one) that encodes
+            raw text into token ids.
+        chat_template: A ``HybridChatTemplate`` (or a config dict) describing
+            how messages, function calls and images are rendered.
+        sample_ratio: Ratio of samples randomly kept from the raw data.
+        max_length: Max length of a sequence.
+        pack_to_max_length: Whether to pack multiple samples into sequences
+            of ``max_length``. This usually improves GPU utilization and
+            therefore reduces training time.
+        num_workers: Number of threads used to map and analyse the data.
+        mappings: A callable or list of callables applied to each raw sample
+            to convert it into the training format.
+        data_dir: Directory that contains the raw JSON files.
+        data_files: One or more JSON files (relative to ``data_dir``).
+        data_cached: Path of a dataset previously saved with ``cache()``.
+        image_dir: Root directory of the images referenced by the samples.
+        image_processor: Image processor (or a config dict) used to
+            preprocess the loaded images.
+        pad_img_to_squared: Whether to pad images to a square before
+            preprocessing.
+    """
+
+ def __init__(self,
+ tokenizer,
+ chat_template: Union[Dict, HybridChatTemplate],
+                 sample_ratio: float = 1.0,
+ max_length: int = 2048,
+ pack_to_max_length: bool = False,
+ num_workers: int = 8,
+ mappings: Union[Callable, List[Callable]] = [],
+ data_dir: Optional[str] = None,
+ data_files: Optional[Union[str, List[str]]] = None,
+ data_cached: Optional[str] = None,
+ image_dir: Optional[str] = None,
+ image_processor: Optional[nn.Module] = None,
+ pad_img_to_squared: bool = True):
+ super().__init__()
+
+ assert data_dir or data_files or data_cached
+
+ self.tokenizer = build_tokenizer(tokenizer)
+
+ if isinstance(chat_template, HybridChatTemplate):
+ self.chat_template = chat_template
+ elif isinstance(chat_template, dict):
+ self.chat_template = BUILDER.build(chat_template)
+ else:
+ raise TypeError
+
+ if isinstance(image_processor, dict):
+ image_processor = BUILDER.build(image_processor)
+ self.image_processor = image_processor
+
+ if image_dir:
+ self.image_dir = Path(image_dir)
+ else:
+ self.image_dir = Path('')
+
+ self.pad_img_to_squared = pad_img_to_squared
+
+ self.sample_ratio = sample_ratio
+ self.max_length = max_length
+ self.pack_to_max_length = pack_to_max_length
+
+        # Copy the list to avoid mutating the caller's (or the default)
+        # `mappings` argument; also accept a single callable.
+        if not isinstance(mappings, list):
+            mappings = [mappings]
+        mappings = mappings + [
+            _register_cumulative_len,
+            _register_position_ids,
+            _register_tokens,
+            _register_empty_img_ranges,
+            _check_mapped_data,
+        ]
+        map_fn = map_sequential(mappings)
+ self.map_fn = partial(
+ map_fn, tokenizer=self.tokenizer, chat_template=self.chat_template)
+
+ self.num_workers = num_workers
+ if data_cached:
+ self.data_dir = data_dir
+ self.data_files = data_files
+ self.data_cached = data_cached
+ else:
+ data_dir = Path(data_dir)
+ if data_files is None:
+ data_files = [str(f) for f in data_dir.rglob('*.json')]
+ elif isinstance(data_files, list):
+ data_files = [str(data_dir / Path(f)) for f in data_files]
+ elif isinstance(data_files, str):
+ data_files = [str(data_dir / data_files)]
+ else:
+ raise TypeError
+
+ self.data_dir = str(data_dir)
+ self.data_files = data_files
+ self.data_cached = data_cached
+
+ self.dataset = self.build_dataset()
+
+ def build_dataset(self):
+
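+        # In distributed training only rank 0 performs the expensive build
+        # (mapping, filtering, packing); the result is then broadcast to the
+        # other ranks over a gloo group.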
+ if not (dist.is_available() and dist.is_initialized()):
+ return self._build_dataset()
+
+ timeout = timedelta(
+ minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30)))
+ print_log(f'xtuner_dataset_timeout = {timeout}', logger='current')
+
+ gloo_group = dist.new_group(backend='gloo', timeout=timeout)
+
+ if dist.get_rank() == 0:
+ dataset = self._build_dataset()
+ objects = [dataset]
+ else:
+ objects = [None]
+
+ dist.monitored_barrier(group=gloo_group, timeout=timeout)
+ dist.broadcast_object_list(objects, src=0)
+
+ return objects[0]
+
+ def _build_dataset(self):
+
+ if self.data_cached:
+ dataset = load_from_disk(self.data_cached)
+ if self.pack_to_max_length:
+ dataset = self._pack_dataset(dataset)
+ return dataset
+
+ dataset = []
+ for file in self.data_files:
+ dataset.extend(json.load(open(file)))
+ print_log(f'Loaded json data from {file}', logger='current')
+
+ if self.sample_ratio < 1:
+ num_samples = int(self.sample_ratio * len(dataset))
+ dataset = random.sample(dataset, num_samples)
+ print_log(
+ f'Randomly selected {num_samples} samples', logger='current')
+
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+ dataset = list(
+ tqdm(
+ executor.map(self.map_fn, dataset),
+ desc='Map Dataset',
+ total=len(dataset)))
+
+ dataset = self.filter_non_labels_data(dataset)
+
+ self.analysis_tokens_labels(dataset)
+ self.analysis_image_samples(dataset)
+
+ dataset = Dataset.from_list(dataset)
+
+ if self.pack_to_max_length:
+ dataset = self._pack_dataset(dataset)
+
+ return dataset
+
+ def _pack_dataset(self, dataset):
+
+ unpacked_samples = len(dataset)
+ dataset = _PackDataset(dataset, self.max_length)
+ packed_samples = len(dataset)
+ print_log(
+ 'Before pack multi samples to max length: '
+ f'{unpacked_samples} samples',
+ logger='current')
+ print_log(
+ 'After pack multi samples to max length: '
+ f'{packed_samples} samples',
+ logger='current')
+ return dataset
+
+ def filter_non_labels_data(self, dataset):
+
+ def filter_fn(item):
+            # Iterate over the actual labels; samples are still variable
+            # length here, so indexing up to `max_length` could go out of
+            # range.
+            return any(label >= 0 for label in item['labels'])
+
+ ori_samples = len(dataset)
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+ results = list(
+ tqdm(
+ executor.map(filter_fn, dataset),
+ desc='Filter Dataset',
+ total=len(dataset)))
+
+ new_dataset = [x for x, passed in zip(dataset, results) if passed]
+
+ new_samples = len(new_dataset)
+ print_log(f'Before filter: {ori_samples} samples', logger='current')
+ print_log(f'After filter: {new_samples} samples', logger='current')
+ print_log(
+ f'Filtered {ori_samples - new_samples} samples '
+ '(all labels are ignore)',
+ logger='current')
+ return new_dataset
+
+ def analysis_image_samples(self, dataset):
+
+ def img_sample_counter(item):
+ return len(item['image_urls']) > 0
+
+ def img_counter(item):
+ return len(item['image_urls'])
+
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+ images = list(
+ tqdm(
+ executor.map(img_counter, dataset),
+ desc='Count Images',
+ total=len(dataset)))
+
+ samples = list(
+ tqdm(
+ executor.map(img_sample_counter, dataset),
+ desc='Count Contain Image Samples',
+ total=len(dataset)))
+
+ num_images = sum(images)
+ num_samples = sum(samples)
+ print_log(
+ f'There are a total of {num_samples} samples with images, '
+ f'amounting to {num_images} images.',
+ logger='current')
+
+ def analysis_tokens_labels(self, dataset):
+
+ def label_counter(item):
+ return sum([1 for i in item['labels'] if i >= 0])
+
+ def token_counter(item):
+ return len(item['input_ids'])
+
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+ tokens = list(
+ tqdm(
+ executor.map(token_counter, dataset),
+ desc='Count Tokens',
+ total=len(dataset)))
+
+ labels = list(
+ tqdm(
+ executor.map(label_counter, dataset),
+ desc='Count Labels',
+ total=len(dataset)))
+
+ num_tokens = sum(tokens)
+ num_labels = sum(labels)
+ print_log(
+ f'There are a total of {num_tokens} tokens, '
+ f'of which {num_labels} tokens need loss calculation.',
+ logger='current')
+
+ def cache(self, cache_dir: str):
+ cache_dir = Path(cache_dir)
+
+ if self.pack_to_max_length:
+ hf_dataset = Dataset.from_list(self.dataset.dataset)
+ else:
+ hf_dataset = Dataset.from_list(self.dataset)
+
+ hf_dataset.save_to_disk(cache_dir)
+
+ dset_conf = {
+ 'image_dir': str(self.image_dir),
+ 'data_files': self.data_files,
+ 'max_length': self.max_length,
+ 'chat_template': self.chat_template.model_dump(),
+ 'pack_to_max_length': self.pack_to_max_length,
+ 'tokenizer': type(self.tokenizer).__name__,
+ }
+
+ with open(cache_dir / 'dataset_configuration.json', 'w') as f:
+ json.dump(dset_conf, f)
+
+ self.tokenizer.save_pretrained(cache_dir / 'tokenizer')
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(cache_dir / 'image_processor')
+
+ def load_image(self, url):
+ image_file = self.image_dir / url
+ image = Image.open(image_file).convert('RGB')
+
+ if self.pad_img_to_squared:
+ background = tuple(
+ int(x * 255) for x in self.image_processor.image_mean)
+ image = expand2square(image, background)
+
+ image = self.image_processor.preprocess(
+ image, return_tensors='pt')['pixel_values'][0]
+
+ return image
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, item: int) -> Dict[str, List]:
+
+ data = self.dataset[item]
+
+ pixel_values = []
+ for url in data['image_urls']:
+ image = self.load_image(url)
+
+ pixel_values.append(image)
+
+ data['pixel_values'] = pixel_values
+
+ return data
+
+
+if __name__ == '__main__':
+
+ from transformers import CLIPImageProcessor
+
+ chat_template = HybridChatTemplate(
+ system='<|im_start|>system\n{system}<|im_end|>\n',
+ user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+ assistant='{assistant}<|im_end|>\n',
+ stop_words=['<|im_end|>'],
+        image_token='<image>',
+ function_call=
+ '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+ function_result=
+ '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+ functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n'
+ )
+
+ processor = CLIPImageProcessor.from_pretrained(
+ 'openai/clip-vit-large-patch14-336',
+ trust_remote_code=True,
+ )
+
+ from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens,
+ llava_to_openai,
+ openai_to_raw_training)
+
+ data_dir = './llava_data/LLaVA-Instruct-150K/'
+ image_dir = './llava_data/llava_images/'
+ data_files = 'llava_v1_5_mix665k.json'
+
+ dataset = HybridDataset(
+ 'internlm/internlm2-chat-1_8b',
+ chat_template,
+ sample_ratio=1,
+ max_length=32 * 1024,
+ data_dir=data_dir,
+ data_files=data_files,
+ image_dir=image_dir,
+ image_processor=processor,
+ pack_to_max_length=True,
+ mappings=[
+ llava_to_openai,
+ openai_to_raw_training,
+ insert_img_pad_tokens,
+ ],
+ num_workers=4)
+
+ print(dataset[0])
+
+ dataset.cache('cached_llava')
+ dataset = HybridDataset(
+ 'internlm/internlm2-chat-1_8b',
+ chat_template,
+ sample_ratio=1,
+ max_length=32 * 1024,
+ data_cached='cached_llava',
+ image_dir=image_dir,
+ image_processor=processor,
+ pack_to_max_length=True,
+ mappings=[
+ llava_to_openai,
+ openai_to_raw_training,
+ insert_img_pad_tokens,
+ ],
+ num_workers=4)
+ print(dataset[0])
+
+ from mmengine.dataset import DefaultSampler
+ from torch.utils.data import DataLoader
+
+ from xtuner.dataset.hybrid.collate import hybrid_collate_fn
+ loader = DataLoader(
+ dataset,
+ 4,
+ num_workers=0,
+ collate_fn=hybrid_collate_fn,
+ sampler=DefaultSampler(dataset, shuffle=True))
+
+ for data in tqdm(loader):
+ continue
diff --git a/xtuner/dataset/hybrid/mappings.py b/xtuner/dataset/hybrid/mappings.py
new file mode 100644
index 000000000..e104885ff
--- /dev/null
+++ b/xtuner/dataset/hybrid/mappings.py
@@ -0,0 +1,172 @@
+import re
+from typing import Callable, Dict, List, Type
+
+from mmengine.config.lazy import LazyObject
+
+from xtuner.types import TrainingHybridChatMessages
+
+
+def map_protocol(
+ input_keys: Dict[str, Type] = {},
+ output_keys: Dict[str, Type] = {},
+ added_keys: Dict[str, Type] = {},
+) -> Callable:
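+    """Decorator that validates what a mapping function consumes and emits.
+
+    ``input_keys`` are checked on the incoming sample and ``output_keys`` on
+    the returned one; ``added_keys`` only documents the keys a mapping adds.
+    """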
+
+ def decorator(func):
+
+ def wrapper(data, *args, **kwargs):
+
+ for key, _type in input_keys.items():
+ assert key in data
+                assert isinstance(data[key], _type), \
+                    f'`{key}` should be a {_type}, got {type(data[key])}'
+
+ result = func(data, *args, **kwargs)
+
+ for key, _type in output_keys.items():
+ assert key in result
+ assert isinstance(result[key], _type)
+
+ return result
+
+ return wrapper
+
+ setattr(decorator, 'input_keys', input_keys)
+ setattr(decorator, 'output_keys', output_keys)
+ setattr(decorator, 'added_keys', added_keys)
+
+ return decorator
+
+
+def map_sequential(mappings: List[Callable]):
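+    """Compose multiple mapping functions into a single one.
+
+    Each mapping is called as ``fn(item, tokenizer, chat_template)`` and its
+    output is fed to the next; ``LazyObject`` entries from mmengine configs
+    are built before use.
+    """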
+
+    if not isinstance(mappings, list):
+        # Allow passing a single mapping function.
+        mappings = [mappings]
+
+ for i in range(len(mappings)):
+ if isinstance(mappings[i], LazyObject):
+ mappings[i] = mappings[i].build()
+
+ def _sequential(item, tokenizer, chat_template):
+
+ for func in mappings:
+ item = func(item, tokenizer, chat_template)
+
+ return item
+
+ return _sequential
+
+
+@map_protocol(
+ input_keys=dict(input_ids=list, labels=list, image_urls=list),
+ output_keys=dict(
+ input_ids=list, labels=list, image_urls=list, image_ranges=list),
+)
+def insert_img_pad_tokens(data, tokenizer, chat_template) -> Dict:
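+    """Expand every image placeholder token into a block of pad tokens.
+
+    Each image token in ``input_ids`` is replaced by 576 pad tokens (the
+    24 x 24 patch grid of CLIP ViT-L/14 at 336 px), and the covered spans are
+    recorded in ``image_ranges`` so the projected visual features can later
+    overwrite the corresponding embeddings.
+    """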
+
+ image_urls = data['image_urls']
+ if len(image_urls) == 0:
+ data['image_ranges'] = []
+ return data
+
+ input_ids = data['input_ids']
+ labels = data['labels']
+
+ img_token = chat_template.image_token_index
+ img_token_inds = [i for i, t in enumerate(input_ids) if t == img_token]
+ assert len(img_token_inds) == len(
+ image_urls), f'{img_token_inds} {image_urls}'
+
+ for url, ind in zip(image_urls, img_token_inds):
+        # Hard-coded patch grid of CLIP ViT-L/14 at 336 px: 24 x 24 patches,
+        # i.e. 576 visual tokens per image.
+        h, w = 336 // 14, 336 // 14
+
+ pad_tokens = [tokenizer.pad_token_id] * (h * w)
+ pad_labels = [labels[ind]] * (h * w)
+
+ input_ids[ind] = pad_tokens
+ labels[ind] = pad_labels
+
+ new_ids = []
+ new_labels = []
+ assert len(input_ids) == len(labels)
+
+ img_ranges = []
+ for i, _ in enumerate(zip(input_ids, labels)):
+ if isinstance(input_ids[i], list):
+ assert isinstance(labels[i], list)
+ assert len(input_ids[i]) == len(labels[i])
+
+ img_begin = len(new_ids)
+ img_end = img_begin + len(input_ids[i])
+ img_ranges.append([img_begin, img_end])
+
+ new_ids.extend(input_ids[i])
+ new_labels.extend(labels[i])
+
+ else:
+ new_ids.append(input_ids[i])
+ new_labels.append(labels[i])
+
+ data['input_ids'] = new_ids
+ data['labels'] = new_labels
+ data['image_ranges'] = img_ranges
+
+ return data
+
+
+@map_protocol(
+ input_keys=dict(messages=list),
+ output_keys=dict(input_ids=list, labels=list, image_urls=list),
+)
+def openai_to_raw_training(item: dict, tokenizer, chat_template) -> Dict:
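+    """Tokenize OpenAI-style messages into ``input_ids``/``labels``."""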
+
+ data = TrainingHybridChatMessages.from_dict(item)
+ data = data.tokenize(tokenizer, chat_template)
+
+ return data
+
+
+@map_protocol(
+ input_keys=dict(conversations=list),
+ output_keys=dict(messages=list),
+)
+def llava_to_openai(data, tokenizer=None, chat_template=None):
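+    """Convert a LLaVA-format sample into OpenAI-style messages.
+
+    Human turns are split on the image token into text / image_url content
+    items; gpt turns become plain assistant messages.
+    """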
+
+    image_token = '<image>'
+ conversations = data['conversations']
+ messages = []
+
+ if 'image' in data:
+ image_url = data['image']
+ else:
+ image_url = None
+
+ while conversations and conversations[0]['from'] == 'gpt':
+ # Skip the first one if it is from gpt
+ conversations = conversations[1:]
+
+ for convs in conversations:
+ if convs['from'] == 'human':
+ pattern = f'({image_token})'
+ chunks = re.split(pattern, convs['value'])
+
+ content = []
+ for chunk in chunks:
+ if chunk == image_token:
+ assert isinstance(image_url, str), image_url
+ item = dict(type='image_url', image_url=image_url)
+ content.append(item)
+ elif len(chunk.strip()):
+ item = dict(type='text', text=chunk.strip())
+ content.append(item)
+
+ msg = {'role': 'user', 'content': content}
+ messages.append(msg)
+
+ elif convs['from'] == 'gpt':
+ msg = {'role': 'assistant', 'content': convs['value']}
+ messages.append(msg)
+ else:
+ raise NotImplementedError
+ return {'messages': messages}
diff --git a/xtuner/model/__init__.py b/xtuner/model/__init__.py
index 39547b2d7..6ff30761d 100644
--- a/xtuner/model/__init__.py
+++ b/xtuner/model/__init__.py
@@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .agent import AgentFinetune
+from .auto import AutoModelForCausalLM
+from .hybrid import HybridFinetune
from .llava import LLaVAModel
from .sft import SupervisedFinetune
-__all__ = ['SupervisedFinetune', 'LLaVAModel']
+__all__ = [
+ 'HybridFinetune', 'SupervisedFinetune', 'LLaVAModel', 'AgentFinetune'
+]
diff --git a/xtuner/model/agent.py b/xtuner/model/agent.py
new file mode 100644
index 000000000..dd3f1bde6
--- /dev/null
+++ b/xtuner/model/agent.py
@@ -0,0 +1,138 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from mmengine.model import BaseModel
+from peft import LoraConfig
+from torch import nn
+
+from xtuner.registry import BUILDER
+from xtuner.utils.config import build_from_cfg_or_obj
+from .modules import dispatch_modules
+from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
+                    get_peft_model_state_dict, prepare_for_llm_lora,
+                    smart_tokenizer_and_embedding_resize)
+
+
+class AgentFinetune(BaseModel):
+
+ def __init__(
+ self,
+ llm,
+ tokenizer=None,
+ llm_lora=None,
+ use_activation_checkpointing=True,
+ use_varlen_attn=False,
+ ):
+ super().__init__()
+
+ # Build the base language model without initialization.
+ # This will greatly reduce the time to build the model.
+ with LoadWoInit():
+ self.llm = build_from_cfg_or_obj(llm, nn.Module)
+ self.llm.config.use_cache = False
+ dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
+
+ if tokenizer is not None:
+ if isinstance(tokenizer, dict):
+ tokenizer = BUILDER.build(tokenizer)
+ smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
+
+ if use_activation_checkpointing:
+ # For backward compatibility
+ enable_hf_model_gradient_checkpointing(self.llm)
+
+ self.use_llm_lora = llm_lora is not None
+
+ # Prepare the model for LoRA if specified
+ if self.use_llm_lora:
+ lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig)
+ self.llm = prepare_for_llm_lora(self.llm, lora_conf,
+ use_activation_checkpointing)
+
+ self._is_init = True
+
+ # Determines whether to calculate attention based on the
+ # seq_len dimension (use_varlen_attn = False) or the actual length of
+ # the sequence.
+ self.use_varlen_attn = use_varlen_attn
+
+ def init_weights(self):
+ """Parent class method.
+
+ To avoid overwriting the loaded weights, overload it to an empty
+ function.
+ """
+ pass
+
+ def forward(self, data, data_samples=None, mode='loss'):
+ """Overload parent class method, only support training."""
+
+ if mode == 'loss':
+ return self.compute_loss(data)
+ else:
+ raise NotImplementedError(
+ f"{type(self)}'s forward is only supported for use during "
+ 'training. If you want to get predictions or chat, please '
+ "directly use `llm`'s forward.")
+
+ def compute_loss(self, data):
+
+ input_ids = data['input_ids']
+ labels = data['labels']
+ # position_ids = data['position_ids']
+ attention_mask = data['attention_mask']
+ bs, tokens = input_ids.shape
+ if self.use_varlen_attn:
+ assert bs == 1
+
+ cumulative_len = data['cumulative_len'][0]
+ max_seqlen = (cumulative_len[1:] - cumulative_len[:-1]).max()
+
+ position_ids = []
+ for i in range(1, len(cumulative_len)):
+ chunk_tokens = cumulative_len[i] - cumulative_len[i - 1]
+ position_ids.append(torch.arange(chunk_tokens))
+ position_ids = torch.cat(position_ids, dim=0).unsqueeze(0)
+
+ from mmengine import MessageHub
+ rank = dist.get_rank()
+ message_hub = MessageHub.get_instance('varlen_attn_args')
+ message_hub.update_info(f'cumulative_len_rank_{rank}',
+ cumulative_len)
+ message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
+ else:
+
+ position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1)
+
+ outputs = self.llm(
+ input_ids=input_ids,
+ # position_ids=position_ids,
+ attention_mask=attention_mask,
+ labels=labels)
+
+ loss_dict = {'loss': outputs.loss}
+ return loss_dict
+
+ def state_dict(self, *args, **kwargs):
+ state_dict = super().state_dict(*args, **kwargs)
+ to_return = OrderedDict()
+ # Step 1. LLM
+ if self.use_llm_lora:
+ to_return.update(
+ get_peft_model_state_dict(self.llm, state_dict=state_dict))
+ else:
+ to_return.update(
+ {k: v
+ for k, v in state_dict.items() if 'llm.' in k})
+
+ return to_return
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.llm, name)
diff --git a/xtuner/model/auto.py b/xtuner/model/auto.py
new file mode 100644
index 000000000..5546b1a5f
--- /dev/null
+++ b/xtuner/model/auto.py
@@ -0,0 +1,73 @@
+import os
+from typing import Dict, Optional, Union
+
+import torch
+
+from transformers import AutoConfig as HfAutoConfig
+from transformers import AutoModelForCausalLM as HfAutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+
+from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
+
+
+
+
+class AutoModelForCausalLM:
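+    """Thin wrapper around ``transformers.AutoModelForCausalLM``.
+
+    It selects bf16/fp16 according to hardware support and enables
+    flash-attention (or SDPA) for architectures known to support it.
+    """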
+
+ @classmethod
+ def from_config(cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: bool = True,
+ **kwargs):
+ return HfAutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: bool = True,
+ quantization_config: Optional[BitsAndBytesConfig] = None,
+ **kwargs):
+
+ config = cls.from_config(
+ pretrained_model_name_or_path, trust_remote_code=True)
+ attn_kwargs = cls._flash_attn_kwargs(config)
+ kwargs.update(attn_kwargs)
+
+ if torch.cuda.is_bf16_supported():
+ kwargs.update(torch_dtype=torch.bfloat16)
+ else:
+ kwargs.update(torch_dtype=torch.float16)
+
+ model = HfAutoModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ quantization_config=quantization_config,
+ **kwargs)
+
+ return model
+
+ @staticmethod
+ def _flash_attn_kwargs(config):
+ cls_name = type(config).__name__
+ _built_in_flash_attn_1 = ('LlamaConfig', 'GemmaConfig',
+ 'MistralConfig', 'MixtralConfig',
+                                  'Qwen2Config', 'Starcoder2Config')
+
+ _built_in_flash_attn_2 = ('InternLMConfig', 'InternLM2Config',
+ 'LlamaConfig', 'GemmaConfig',
+ 'MistralConfig', 'MixtralConfig',
+                                  'Qwen2Config', 'Starcoder2Config')
+
+ attn_kwargs = {}
+ if SUPPORT_FLASH2 and cls_name in _built_in_flash_attn_2:
+ attn_kwargs.update(attn_implementation='flash_attention_2')
+ elif SUPPORT_FLASH1 and cls_name in _built_in_flash_attn_1:
+ attn_kwargs.update(attn_implementation='sdpa')
+
+ return attn_kwargs
\ No newline at end of file
diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py
new file mode 100644
index 000000000..0f0fc7e76
--- /dev/null
+++ b/xtuner/model/hybrid.py
@@ -0,0 +1,260 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from mmengine.model import BaseModel
+from peft import LoraConfig
+from torch import nn
+
+from xtuner.registry import BUILDER
+from xtuner.utils.config import build_from_cfg_or_obj
+from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
+from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
+ get_peft_model_state_dict, prepare_for_llm_lora,
+ prepare_for_vision_lora,
+ smart_tokenizer_and_embedding_resize)
+
+
+class HybridFinetune(BaseModel):
+
+ def __init__(
+ self,
+ llm,
+ visual_encoder=None,
+ visual_select_layer=-2,
+ projector_depth=2,
+ pretrained_pth=None,
+ tokenizer=None,
+ llm_lora=None,
+ visual_encoder_lora=None,
+ freeze_llm=False,
+ freeze_visual_encoder=False,
+ use_activation_checkpointing=True,
+ use_varlen_attn=False,
+ ):
+ super().__init__()
+
+ # Build the base language model without initialization.
+ # This will greatly reduce the time to build the model.
+ with LoadWoInit():
+ self.llm = build_from_cfg_or_obj(llm, nn.Module)
+ if visual_encoder:
+ visual_encoder = build_from_cfg_or_obj(visual_encoder,
+ nn.Module)
+ self.visual_encoder = visual_encoder
+ self.visual_select_layer = visual_select_layer
+ self.llm.config.use_cache = False
+ dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
+
+ if tokenizer is not None:
+ if isinstance(tokenizer, dict):
+ tokenizer = BUILDER.build(tokenizer)
+ smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
+
+ projector_config = ProjectorConfig(
+ visual_hidden_size=self.visual_encoder.config.hidden_size,
+ llm_hidden_size=self.llm.config.hidden_size,
+ depth=projector_depth)
+ self.projector = ProjectorModel(projector_config).to(
+ self.visual_encoder.dtype)
+
+ self.freeze_llm = freeze_llm
+ self.freeze_visual_encoder = freeze_visual_encoder
+ if self.freeze_llm:
+ self.llm.requires_grad_(False)
+ if self.freeze_visual_encoder:
+ self.visual_encoder.requires_grad_(False)
+
+ if use_activation_checkpointing:
+ # For backward compatibility
+ enable_hf_model_gradient_checkpointing(self.llm)
+ enable_hf_model_gradient_checkpointing(self.visual_encoder)
+
+ self.projector.enable_input_require_grads()
+ self.projector.gradient_checkpointing_enable()
+
+ self.use_llm_lora = llm_lora is not None
+ self.use_visual_encoder_lora = visual_encoder_lora is not None
+
+ # Prepare the model for LoRA if specified
+ if self.use_llm_lora:
+ lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig)
+ self.llm = prepare_for_llm_lora(self.llm, lora_conf,
+ use_activation_checkpointing)
+
+ if self.use_visual_encoder_lora:
+ lora_conf = build_from_cfg_or_obj(
+ visual_encoder_lora, accept=LoraConfig)
+ self.visual_encoder = prepare_for_vision_lora(
+ self.visual_encoder, lora_conf, use_activation_checkpointing)
+ self._is_init = True
+
+ # Determines whether to calculate attention based on the
+ # seq_len dimension (use_varlen_attn = False) or the actual length of
+ # the sequence.
+ self.use_varlen_attn = use_varlen_attn
+
+    def init_weights(self):
+        """Override the parent class method with a no-op.
+
+        This avoids re-initializing, and thus overwriting, the pretrained
+        weights that have already been loaded.
+        """
+        pass
+
+    def forward(self, data, data_samples=None, mode='loss'):
+        """Override the parent class method; only training (``mode='loss'``)
+        is supported."""
+
+ if mode == 'loss':
+ return self.compute_loss(data)
+ else:
+ raise NotImplementedError(
+ f"{type(self)}'s forward is only supported for use during "
+ 'training. If you want to get predictions or chat, please '
+ "directly use `llm`'s forward.")
+
+ def _get_vision_embeds_and_ranges(self, data):
+
+ input_ids = data['input_ids']
+ pixel_values = data['pixel_values']
+ img_rngs = data['image_ranges']
+ img_belongs = data['image_belongs']
+
+ bs, tokens = input_ids.shape
+
+ img_embeds = []
+ ranges_in_flat_batch = []
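+        # Each range is expressed as (start, end) offsets into the flattened
+        # (bs * tokens) sequence so that all image embeddings can be
+        # scattered into the batch in a single pass.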
+
+ if pixel_values is not None:
+ assert isinstance(pixel_values, torch.Tensor)
+ assert len(img_rngs) == len(img_belongs) == pixel_values.size(0)
+
+ batch_total_imgs = len(img_rngs)
+
+ visual_outputs = self.visual_encoder(
+ pixel_values, output_hidden_states=True)
+ features = self.projector(
+ visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
+ batch_total_imgs, real_img_tokens, _ = features.shape
+
+ for i in range(batch_total_imgs):
+ img_start, img_end = img_rngs[i]
+ exp_img_tokens = img_end - img_start
+ img_emb = features[i]
+ img_bs_ind = img_belongs[i]
+
+                if real_img_tokens == exp_img_tokens:
+                    img_embeds.append(img_emb)
+                elif real_img_tokens != exp_img_tokens and img_start == 0:
+                    # The image is truncated at the start of the sequence;
+                    # keep only the trailing visual tokens that fit.
+                    img_embeds.append(img_emb[real_img_tokens - img_end:])
+                elif real_img_tokens != exp_img_tokens and img_end == tokens:
+                    # The image is truncated at the end of the sequence;
+                    # keep only the leading visual tokens that fit.
+                    img_embeds.append(img_emb[:exp_img_tokens])
+                else:
+                    raise RuntimeError(
+                        f'Expected {exp_img_tokens} image tokens in range '
+                        f'({img_start}, {img_end}), but the visual encoder '
+                        f'produced {real_img_tokens}.')
+
+ flat_offset = tokens * img_bs_ind
+
+ left = flat_offset + img_start
+ right = flat_offset + img_end
+ ranges_in_flat_batch.append((left, right))
+
+ return img_embeds, ranges_in_flat_batch
+
+ def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):
+
+ assert len(mm_embeds) == len(ranges)
+ if len(mm_embeds) == 0:
+ return flat_embeds
+
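+        # Scatter the image embeddings into a zero tensor, zero out the text
+        # embeddings at the image positions, then merge the two tensors.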
+ _empty_embeds = torch.zeros_like(flat_embeds)
+ for (start, end), emb in zip(ranges, mm_embeds):
+ _empty_embeds[start:end] += emb
+
+ flat_embeds = flat_embeds * (_empty_embeds == 0)
+
+ return flat_embeds + _empty_embeds
+
+ def compute_loss(self, data):
+
+ input_ids = data['input_ids']
+ labels = data['labels']
+        attention_mask = data['attention_mask']
+ bs, tokens = input_ids.shape
+ if self.use_varlen_attn:
+ assert bs == 1
+
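+            # With variable-length (packed) attention, several samples share
+            # a single row; `cumulative_len` marks each sample's boundary.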
+ cumulative_len = data['cumulative_len'][0]
+ max_seqlen = (cumulative_len[1:] - cumulative_len[:-1]).max()
+
+ position_ids = []
+ for i in range(1, len(cumulative_len)):
+ chunk_tokens = cumulative_len[i] - cumulative_len[i - 1]
+ position_ids.append(torch.arange(chunk_tokens))
+ position_ids = torch.cat(position_ids, dim=0).unsqueeze(0)
+
+ from mmengine import MessageHub
+ rank = dist.get_rank()
+ message_hub = MessageHub.get_instance('varlen_attn_args')
+ message_hub.update_info(f'cumulative_len_rank_{rank}',
+ cumulative_len)
+ message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
+ else:
+
+ position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1)
+
+ input_embeds = self.llm.get_input_embeddings()(input_ids)
+
+ bs, tokens, dim = input_embeds.shape
+ flat_embeds = input_embeds.flatten(0, 1)
+
+ img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data)
+ flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs,
+ flat_bs_img_rngs)
+ input_embeds = flat_embeds.reshape(bs, tokens, dim)
+
+ outputs = self.llm(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=input_embeds,
+ labels=labels)
+
+ loss_dict = {'loss': outputs.loss}
+ return loss_dict
+
+ def state_dict(self, *args, **kwargs):
+ state_dict = super().state_dict(*args, **kwargs)
+ to_return = OrderedDict()
+ # Step 1. visual_encoder
+ if self.use_visual_encoder_lora:
+ to_return.update(
+ get_peft_model_state_dict(
+ self.visual_encoder, state_dict=state_dict))
+ elif not self.freeze_visual_encoder:
+ to_return.update({
+ k: v
+ for k, v in state_dict.items() if 'visual_encoder.' in k
+ })
+ # Step 2. LLM
+ if self.use_llm_lora:
+ to_return.update(
+ get_peft_model_state_dict(self.llm, state_dict=state_dict))
+ elif not self.freeze_llm:
+ to_return.update(
+ {k: v
+ for k, v in state_dict.items() if 'llm.' in k})
+ # Step 3. Projector
+ to_return.update(
+ {k: v
+ for k, v in state_dict.items() if 'projector.' in k})
+ return to_return
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.llm, name)
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index a166e8bae..8aa664ad5 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -173,7 +173,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
max_seqlen):
q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
+
attn_output = flash_attn_varlen_func(
q_unpad,
k_unpad,
diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py
index 4cfa26cd1..5355bce74 100644
--- a/xtuner/model/modules/dispatch/utils.py
+++ b/xtuner/model/modules/dispatch/utils.py
@@ -25,7 +25,6 @@ def upad_qkv(query_layer, key_layer, value_layer, attention_mask,
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
head_dim), indices_k)
diff --git a/xtuner/model/utils.py b/xtuner/model/utils.py
index dce86315d..0e5dfb826 100644
--- a/xtuner/model/utils.py
+++ b/xtuner/model/utils.py
@@ -1,13 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
+from contextlib import nullcontext
from typing import List, Optional
import torch
from mmengine import print_log
from mmengine.utils.misc import get_object_from_string
-from peft import PeftType
+from peft import (LoraConfig, PeftModel, PeftType, get_peft_model,
+ prepare_model_for_kbit_training)
from torch import nn
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.integrations import is_deepspeed_zero3_enabled
from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX
@@ -50,6 +53,34 @@ def find_all_linear_names(model):
return list(lora_module_names)
+def collect_linear_suffix_names(model: torch.nn.Module,
+                                exclude: Optional[List[str]] = None
+                                ) -> List[str]:
+ """Collect suffix names of nn.Linear modules from a PyTorch model.
+
+ Args:
+ model: The PyTorch model.
+        exclude: A list of suffix names to drop from the result.
+            Default: None, which keeps all collected names.
+
+ Returns:
+ A list of collected suffix names after excluding specified keys.
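+
+    Example (illustrative)::
+
+        class Toy(torch.nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.q_proj = torch.nn.Linear(4, 4)
+                self.lm_head = torch.nn.Linear(4, 2)
+
+        collect_linear_suffix_names(Toy(), exclude=['lm_head'])
+        # -> ['q_proj']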
+ """
+ suffix_names = set()
+
+ # Iterate through all named modules in the model
+ for name, module in model.named_modules():
+ # Check if the module is an instance of nn.Linear
+ if isinstance(module, torch.nn.Linear):
+ names = name.split('.')
+ suffix_names.add(names[0] if len(names) == 1 else names[-1])
+
+    # Drop excluded keys from the collected suffix names
+    for key in exclude or []:
+        suffix_names.discard(key)
+
+ return list(suffix_names)
+
+
class LoadWoInit:
"""Context manager that disable parameter initialization."""
@@ -286,6 +317,73 @@ def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
+def prepare_for_llm_lora(model: PreTrainedModel,
+ lora_config: LoraConfig,
+ gradient_checkpointing: bool = True) -> PeftModel:
+ model = prepare_model_for_kbit_training(model, gradient_checkpointing)
+ if lora_config.target_modules is None:
+ modules = collect_linear_suffix_names(model, exclude=['output'])
+ lora_config.target_modules = modules
+
+ model = get_peft_model(model, lora_config)
+ return model
+
+
+def prepare_for_vision_lora(model: PreTrainedModel,
+ lora_config: LoraConfig,
+ gradient_checkpointing: bool = True) -> PeftModel:
+
+ if lora_config.target_modules is None:
+ modules = collect_linear_suffix_names(model)
+ lora_config.target_modules = modules
+
+ model = get_peft_model(model, lora_config)
+ return model
+
+
+def smart_tokenizer_and_embedding_resize(
+ tokenizer: PreTrainedTokenizer,
+ model: PreTrainedModel,
+):
+ """Resize embedding."""
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ params = [model.get_input_embeddings().weight]
+ if model.get_output_embeddings(
+ ) is not None and not model.config.tie_word_embeddings:
+ params.append(model.get_output_embeddings().weight)
+
+ context_maybe_zero3 = deepspeed.zero.GatheredParameters(
+ params, modifier_rank=0)
+ else:
+ context_maybe_zero3 = nullcontext()
+
+ with context_maybe_zero3:
+ current_embedding_size = model.get_input_embeddings().weight.size(0)
+
+ if len(tokenizer) > current_embedding_size:
+ assert isinstance(model.get_output_embeddings(), nn.Linear)
+
+ model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+ with context_maybe_zero3:
+ num_new_tokens = len(tokenizer) - current_embedding_size
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ print_log(
+ f'Resized token embeddings from {current_embedding_size} to '
+ f'{len(tokenizer)}.', 'current')
+
+
def guess_load_checkpoint(pth_model):
if osp.isfile(pth_model):
state_dict = torch.load(pth_model, map_location='cpu')
@@ -307,3 +405,19 @@ def guess_load_checkpoint(pth_model):
else:
raise FileNotFoundError(f'Cannot find {pth_model}')
return state_dict
+
+
+def enable_hf_model_gradient_checkpointing(model: PreTrainedModel) -> None:
+ # For backward compatibility
+ if hasattr(model, 'enable_input_require_grads'):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad)
+
+ # enable gradient checkpointing for memory efficiency
+ model.gradient_checkpointing_enable()
diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py
new file mode 100644
index 000000000..cc230e8f8
--- /dev/null
+++ b/xtuner/types/__init__.py
@@ -0,0 +1,6 @@
+from .chat_template import HybridChatTemplate
+from .train import RawTrainingData, TrainingHybridChatMessages
+
+__all__ = [
+ 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages'
+]
diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py
new file mode 100644
index 000000000..cd0a4d4a7
--- /dev/null
+++ b/xtuner/types/chat.py
@@ -0,0 +1,173 @@
+from typing import Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+from .chat_template import HybridChatTemplate
+
+
+class TextContentItem(BaseModel):
+ type: Literal['text']
+ text: str
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return self.text
+
+
+class ImageContentItem(BaseModel):
+ type: Literal['image_url']
+ image_url: str
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return chat_template.image_token
+
+
+class FileContentItem(BaseModel):
+ type: Literal['file_url']
+ file_url: str
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return self.file_url
+
+
+MultModalContentType = Union[TextContentItem, ImageContentItem]
+ContentType = Union[str, List[MultModalContentType]]
+
+
+class ChatMsg(BaseModel):
+ role: Literal['assistant', 'user', 'system']
+ content: ContentType
+ files: List[Union[str, Dict]] = []
+
+ def collect_img_urls(self) -> List[str]:
+ img_urls = []
+ if isinstance(self.content, list):
+ for item in self.content:
+ if isinstance(item, ImageContentItem):
+ img_urls.append(item.image_url)
+ return img_urls
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+ if isinstance(self.content, str):
+ text = self.content
+ elif isinstance(self.content, list):
+ text = ''
+ for i, item in enumerate(self.content):
+ if i == 0:
+ text += item.apply_chat_template(chat_template)
+ else:
+ text += '\n' + item.apply_chat_template(chat_template)
+ else:
+ raise NotImplementedError
+
+ if self.role == 'system':
+ prompt = chat_template.decorate_system(text)
+ elif self.role == 'user':
+            if len(self.files) > 0:
+                stop_word = chat_template.stop_words[0]
+                files_prompt = chat_template.decorate_files(self.files)
+                text += f'\n{stop_word}\n{files_prompt}'
+            prompt = chat_template.decorate_user(text)
+
+ elif self.role == 'assistant':
+ prompt = chat_template.decorate_assistant(text)
+ else:
+ raise NotImplementedError
+
+ return prompt
+
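+
+# A usage sketch (hypothetical values): a user turn mixing an image and text.
+# When rendered with ``apply_chat_template``, the image item is replaced by
+# ``chat_template.image_token``.
+#
+#   msg = ChatMsg(
+#       role='user',
+#       content=[
+#           ImageContentItem(type='image_url', image_url='demo.png'),
+#           TextContentItem(type='text', text='What is in the image?'),
+#       ])
+#   msg.collect_img_urls()  # ['demo.png']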
+
+# Function Call
+
+
+class FunctionCall(BaseModel):
+ name: str
+ arguments: Dict
+
+
+class FunctionCallMsg(BaseModel):
+
+ role: Literal['assistant']
+ content: str
+ function_call: Union[str, Dict]
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+ return chat_template.decorate_function_call(self.content,
+ self.function_call)
+
+
+class FunctionResultMsg(BaseModel):
+ role: Literal['function']
+ name: str
+ content: Union[str, Dict]
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return chat_template.decorate_function_result(self.content)
+
+
+class CodeInterpreterCallMsg(BaseModel):
+
+    role: Literal['assistant']
+    content: str
+    code_interpreter_call: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        return chat_template.decorate_code_interpreter_call(
+            self.content, self.code_interpreter_call)
+
+
+class CodeInterpreterResultMsg(BaseModel):
+ role: Literal['code_interpreter']
+ content: Union[str, Dict]
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+ return chat_template.decorate_code_interpreter_result(self.content)
+
+
+class Functions(BaseModel):
+
+ name: str
+ description: Union[str, Dict]
+ parameters: Union[str, Dict]
+
+
+HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg,
+ CodeInterpreterCallMsg, CodeInterpreterResultMsg]
+
+
+class HybridChatMessages(BaseModel):
+
+ messages: List[HybridChatMsgType] = []
+ # images: List[Image.Image] = []
+ functions: List[Functions] = []
+ code_interpreter: Optional[str] = None
+
+ # TODO (pppppM) add audio and video
+
+    def collect_img_urls(self) -> List[str]:
+        img_urls = []
+        for msg in self.messages:
+            if isinstance(msg, ChatMsg):
+                img_urls.extend(msg.collect_img_urls())
+        return img_urls
+
+ def pop_latest_msg(self):
+ return self.messages.pop()
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+ prompt = ''
+
+        if self.code_interpreter:
+            prompt += chat_template.decorate_code_interpreter(
+                self.code_interpreter)
+
+ if len(self.functions) > 0:
+
+ functions = [func.model_dump() for func in self.functions]
+
+ prompt += chat_template.decorate_functions(functions)
+
+ for msg in self.messages:
+ prompt += msg.apply_chat_template(chat_template)
+
+ return prompt
diff --git a/xtuner/types/chat_template.py b/xtuner/types/chat_template.py
new file mode 100644
index 000000000..4318c6104
--- /dev/null
+++ b/xtuner/types/chat_template.py
@@ -0,0 +1,195 @@
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, field_validator
+
+
+class HybridChatTemplate(BaseModel):
+ """Define a Pydantic data model for a hybrid chat with attributes for
+ system, user and assistant chat as well as function and interpreter calls
+ and results."""
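+
+    # A minimal usage sketch (illustrative template strings, not those of
+    # any particular model):
+    #
+    #   template = HybridChatTemplate(
+    #       system='<|system|>\n{system}\n',
+    #       user='<|user|>\n{user}\n<|assistant|>\n',
+    #       assistant='{assistant}</s>\n',
+    #       stop_words=['</s>'])
+    #   template.decorate_user('Hi!')  # '<|user|>\nHi!\n<|assistant|>\n'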
+
+ # Normal Chat
+ system: str # System message format
+ user: str # User message format
+ assistant: str # Assistant message format
+ stop_words: List[str] # List of stop words
+
+ # Multimodal Chat
+ # Predefined token and index for images
+ image_token: str = ''
+ image_token_index: int = -100
+
+ # Agent Chat
+
+ # Interpreter and function related strings
+ files: Optional[str] = None
+
+ functions: Optional[str] = None # Function description format
+ function_call: Optional[str] = None # Function call format
+ function_result: Optional[str] = None # Function result format
+
+ code_interpreter: Optional[str] = None
+ code_interpreter_call: Optional[str] = None # Interpreter call format
+ code_interpreter_result: Optional[str] = None # Interpreter result format
+
+ function_token: Optional[str] = None
+ code_interpreter_token: Optional[str] = None
+ action_start_token: Optional[str] = None
+ action_end_token: Optional[str] = None
+
+ @property
+ def mm_token_maps(self) -> Dict[str, int]:
+ """Return a dictionary that maps multimodal tokens to corresponding
+ token indexes."""
+ return {self.image_token: self.image_token_index}
+
+ def decorate_system(self, text: str) -> str:
+ """Decorate text with the `system` template."""
+ return self.system.format(system=text)
+
+ def decorate_assistant(self, text: str) -> str:
+ """Decorate text with the `assistant` template."""
+ return self.assistant.format(assistant=text)
+
+ def decorate_user(self, text: str) -> str:
+ """Decorate text with the `user` template."""
+ return self.user.format(user=text)
+
+ def decorate_files(self, text: str) -> str:
+ """Decorate text with the `functions` template."""
+ return self.files.format(files=text)
+
+ def decorate_functions(self, text: str) -> str:
+ """Decorate text with the `functions` template."""
+ return self.functions.format(functions=text)
+
+ def decorate_function_call(self, text: str, func: str) -> str:
+ """Decorate text with the `function_call` template."""
+ return self.function_call.format(assistant=text, function_call=func)
+
+ def decorate_function_result(self, text: str) -> str:
+ """Decorate text with the `function_result` template."""
+ return self.function_result.format(function_result=text)
+
+ def decorate_code_interpreter(self, text: str) -> str:
+ """Decorate text with the `code_interpreter` template."""
+ return self.code_interpreter.format(code_interpreter=text)
+
+ def decorate_code_interpreter_call(self, text: str, func: str) -> str:
+ """Decorate text with the `code_interpreter_call` template."""
+ return self.code_interpreter_call.format(
+ assistant=text, code_interpreter_call=func)
+
+ def decorate_code_interpreter_result(self, text: str) -> str:
+ """Decorate text with the `code_interpreter_result` template."""
+ return self.code_interpreter_result.format(
+ code_interpreter_result=text)
+
+ @field_validator('system')
+ def check_system(cls, v: str) -> str:
+ """Validate that `system` contains '{system}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{system}' not in v:
+ raise ValueError("system must contain the keyword '{system}'")
+ return v
+
+ @field_validator('user')
+ def check_user(cls, v: str) -> str:
+ """Validate that `user` contains '{user}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{user}' not in v:
+ raise ValueError("user must contain the keyword '{user}'")
+ return v
+
+ @field_validator('assistant')
+ def check_assistant(cls, v: str) -> str:
+ """Validate that `assistant` contains '{assistant}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{assistant}' not in v:
+ raise ValueError(
+ "assistant must contain the keyword '{assistant}'")
+ return v
+
+    @field_validator('function_call')
+    def check_function_call(cls, v: str) -> str:
+        """Validate that `function_call` contains both '{assistant}' and
+        '{function_call}'.
+
+        If not, raises a ValueError.
+        """
+        if v is not None and ('{function_call}' not in v
+                              or '{assistant}' not in v):
+            raise ValueError(
+                "function_call must contain the keywords '{assistant}' and "
+                "'{function_call}'")
+        return v
+
+ @field_validator('function_result')
+ def check_function_result(cls, v: str) -> str:
+ """Validate that `function_result` contains '{function_result}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{function_result}' not in v:
+ raise ValueError(
+ "function_result must contain the keyword '{function_result}'")
+ return v
+
+ @field_validator('functions')
+ def check_functions(cls, v: str) -> str:
+ """Validate that `functions` contains '{functions}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{functions}' not in v:
+ raise ValueError(
+ "functions must contain the keyword '{functions}'")
+ return v
+
+ @field_validator('code_interpreter')
+ def check_code_interpreter(cls, v: str) -> str:
+ """Validate that `code_interpreter` contains '{code_interpreter}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{code_interpreter}' not in v:
+ raise ValueError('code_interpreter must contain the keyword '
+ "'{code_interpreter}'")
+ return v
+
+    @field_validator('code_interpreter_call')
+    def check_code_interpreter_call(cls, v: str) -> str:
+        """Validate that `code_interpreter_call` contains both '{assistant}'
+        and '{code_interpreter_call}'.
+
+        If not, raises a ValueError.
+        """
+        if v is not None and ('{code_interpreter_call}' not in v
+                              or '{assistant}' not in v):
+            raise ValueError(
+                "code_interpreter_call must contain the keywords "
+                "'{assistant}' and '{code_interpreter_call}'")
+        return v
+
+ @field_validator('code_interpreter_result')
+ def check_code_interpreter_result(cls, v: str) -> str:
+ """Validate that `code_interpreter_result` contains
+ '{code_interpreter_result}'.
+
+ If not, raises a ValueError.
+ """
+ if v is not None and '{code_interpreter_result}' not in v:
+ raise ValueError(
+ 'code_interpreter_result must contain the keyword '
+ "'{code_interpreter_result}'")
+ return v
diff --git a/xtuner/types/train.py b/xtuner/types/train.py
new file mode 100644
index 000000000..9971308cd
--- /dev/null
+++ b/xtuner/types/train.py
@@ -0,0 +1,364 @@
+import copy
+import re
+from typing import Dict, List, Optional, Union
+
+import torch
+from pydantic import BaseModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from xtuner.utils import IGNORE_INDEX
+from xtuner.utils.tokenizer import get_bos_token_ids
+from .chat import (ChatMsg, CodeInterpreterCallMsg, CodeInterpreterResultMsg,
+ FileContentItem, FunctionCallMsg, FunctionResultMsg,
+ Functions, ImageContentItem, TextContentItem)
+from .chat_template import HybridChatTemplate
+
+
+class TrainingChatMsg(ChatMsg):
+ loss: Optional[bool] = None
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if self.loss is None:
+ if self.role == 'system':
+ self.loss = False
+ elif self.role == 'user':
+ self.loss = False
+ elif self.role == 'assistant':
+ self.loss = True
+ else:
+ raise NotImplementedError
+
+ def _encode_mm_content(self, text: str, mm_token_maps: Dict[str, int],
+ tokenizer: PreTrainedTokenizer):
+
+ mm_tokens = mm_token_maps.keys()
+
+ pattern = r'(' + '|'.join(mm_tokens) + r')'
+ chunks = re.split(pattern, text)
+
+ assert len(chunks) > 1
+
+ token_ids = []
+ for c in chunks:
+ if c in mm_tokens:
+ token_ids.append(mm_token_maps[c])
+ else:
+ token_ids.extend(tokenizer.encode(c, add_special_tokens=False))
+
+ return token_ids
+
+ def _with_multi_modal_content(self):
+ flag = False
+
+ if isinstance(self.content, list):
+ for item in self.content:
+ # TODO (pppppM) support video and audio
+ if isinstance(item, ImageContentItem):
+ flag = True
+ break
+ return flag
+
+ def tokenize(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ chat_template: HybridChatTemplate,
+ ):
+
+ decorated = self.apply_chat_template(chat_template)
+
+ if self._with_multi_modal_content():
+ token_maps = chat_template.mm_token_maps
+ token_ids = self._encode_mm_content(decorated, token_maps,
+ tokenizer)
+ else:
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ image_urls = self.collect_img_urls()
+
+ return {
+ 'input_ids': token_ids,
+ 'labels': label_ids,
+ 'image_urls': image_urls
+ }
+
+
+class TrainingFunctionCallMsg(FunctionCallMsg):
+ loss: bool = True
+
+ def tokenize(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ chat_template: HybridChatTemplate,
+ ):
+
+ decorated = self.apply_chat_template(chat_template)
+
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ return {'input_ids': token_ids, 'labels': label_ids}
+
+
+class TrainingFunctionResultMsg(FunctionResultMsg):
+ loss: bool = False
+
+ def tokenize(self, tokenizer, chat_template: HybridChatTemplate):
+
+ decorated = self.apply_chat_template(chat_template)
+
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ return {'input_ids': token_ids, 'labels': label_ids}
+
+
+class TrainingCodeInterpreterCallMsg(CodeInterpreterCallMsg):
+ loss: bool = True
+
+ def tokenize(self, tokenizer, chat_template: HybridChatTemplate):
+
+ decorated = self.apply_chat_template(chat_template)
+
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ return {'input_ids': token_ids, 'labels': label_ids}
+
+
+class TrainingCodeInterpreterResultMsg(CodeInterpreterResultMsg):
+ loss: bool = False
+
+ def tokenize(self, tokenizer, chat_template: HybridChatTemplate):
+
+ decorated = self.apply_chat_template(chat_template)
+
+ token_ids = tokenizer.encode(decorated, add_special_tokens=False)
+
+ if self.loss:
+ label_ids = copy.deepcopy(token_ids)
+ else:
+ label_ids = [IGNORE_INDEX] * len(token_ids)
+
+ return {'input_ids': token_ids, 'labels': label_ids}
+
+
+class RawTrainingData(BaseModel):
+
+ input_ids: List[int]
+ labels: List[int]
+ image_urls: List[str] = []
+
+
+class ProcessedTrainingData(BaseModel):
+
+ input_ids: List[int]
+ labels: List[int]
+ length: int
+ cumulative_len: List[int]
+ position_ids: List[int]
+ image_urls: List[str] = []
+ pixel_values: List[torch.Tensor] = []
+ image_ranges: List[tuple] = []
+
+ class Config:
+ arbitrary_types_allowed = True
+
+
+TrainingHybridMessageType = Union[TrainingChatMsg, TrainingFunctionCallMsg,
+                                  TrainingFunctionResultMsg,
+                                  TrainingCodeInterpreterCallMsg,
+                                  TrainingCodeInterpreterResultMsg]
+
+
+class TrainingHybridChatMessages(BaseModel):
+    messages: List[TrainingHybridMessageType]
+ functions: Optional[List[Functions]] = None
+ code_interpreter: Optional[str] = None
+
+ @classmethod
+ def from_dict(cls, item) -> 'TrainingHybridChatMessages':
+        """Build a ``TrainingHybridChatMessages`` instance from a dict.
+
+        Expected layout of ``item``::
+
+            {
+                'messages': [
+                    {'role': 'user', 'content': 'hello'},
+                    {'role': 'assistant', 'content': 'hello!'},
+                ],
+                'functions': [],
+            }
+        """
+
+ assert 'messages' in item, item
+
+ _messages = item['messages']
+ messages = []
+ functions = None
+ code_interpreter = None
+
+ for _msg in _messages:
+ assert 'role' in _msg and 'content' in _msg
+ _role = _msg['role']
+ _content = _msg['content']
+ if _role == 'function':
+ msg_factory = TrainingFunctionResultMsg
+ assert 'name' in _msg
+ name = _msg['name']
+ msg = msg_factory(role=_role, name=name, content=_content)
+ messages.append(msg)
+ continue
+
+ if _role == 'code_interpreter':
+ msg_factory = TrainingCodeInterpreterResultMsg
+ msg = msg_factory(role=_role, content=_content)
+ messages.append(msg)
+ continue
+
+ if isinstance(_content, list):
+
+ content = []
+ for c_item in _content:
+ assert 'type' in c_item
+ _type = c_item['type']
+ if _type == 'text':
+ assert 'text' in c_item
+ _text = c_item['text']
+ content.append(TextContentItem(type=_type, text=_text))
+ elif _type == 'image_url':
+ assert 'image_url' in c_item
+ _url = c_item['image_url']
+ content.append(
+ ImageContentItem(type=_type, image_url=_url))
+ elif _type == 'file_url':
+ assert 'file_url' in c_item
+ _url = c_item['file_url']
+ content.append(FileContentItem(file_url=_url))
+ else:
+ raise NotImplementedError
+
+ msg = TrainingChatMsg(role=_role, content=content)
+ messages.append(msg)
+ continue
+
+ if isinstance(_content, str) and 'function_call' in _msg:
+ _call = _msg['function_call']
+ msg = TrainingFunctionCallMsg(
+ role=_role, content=_content, function_call=_call)
+ messages.append(msg)
+ continue
+            elif (isinstance(_content, str)
+                  and 'code_interpreter_call' in _msg):
+                _call = _msg['code_interpreter_call']
+ msg = TrainingCodeInterpreterCallMsg(
+ role=_role, content=_content, code_interpreter_call=_call)
+ messages.append(msg)
+ continue
+
+ if isinstance(_content, str):
+ msg = TrainingChatMsg(role=_role, content=_content)
+ messages.append(msg)
+
+ # TODO (pppppM) add format warning
+
+ if 'code_interpreter' in item:
+
+ assert isinstance(item['code_interpreter'], str)
+ code_interpreter = item['code_interpreter']
+
+ if 'functions' in item:
+ _functions = item['functions']
+ assert isinstance(_functions, list)
+ functions = []
+
+ for _func in _functions:
+ assert 'name' in _func
+ assert 'description' in _func
+ assert 'parameters' in _func
+
+ _name = _func['name']
+ _desc = _func['description']
+ _params = _func['parameters']
+
+ func = Functions(
+ name=_name, description=_desc, parameters=_params)
+ functions.append(func)
+
+ return cls(
+ messages=messages,
+ functions=functions,
+ code_interpreter=code_interpreter)
+
+    def collect_img_urls(self) -> List[str]:
+        img_urls = []
+        for msg in self.messages:
+            if isinstance(msg, TrainingChatMsg):
+                img_urls.extend(msg.collect_img_urls())
+        return img_urls
+
+ def pop_latest_msg(self):
+ return self.messages.pop()
+
+ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+ prompt = ''
+
+ if isinstance(self.functions, list) and len(self.functions) > 0:
+
+ functions = [func.model_dump() for func in self.functions]
+
+ prompt += chat_template.decorate_functions(functions)
+
+ for msg in self.messages:
+ if msg.role == 'system':
+ prompt = msg.apply_chat_template(chat_template) + prompt
+ else:
+ prompt += msg.apply_chat_template(chat_template)
+
+ return prompt
+
+    def tokenize(self, tokenizer: PreTrainedTokenizer,
+                 chat_template: HybridChatTemplate) -> Dict:
+
+ input_ids = []
+ labels = []
+ image_urls = []
+
+ bos_token_ids = get_bos_token_ids(tokenizer)
+ input_ids.extend(bos_token_ids)
+ labels.extend([IGNORE_INDEX] * len(bos_token_ids))
+
+ for msg in self.messages:
+ res = msg.tokenize(tokenizer, chat_template)
+ token_ids, label_ids = res['input_ids'], res['labels']
+
+ input_ids.extend(token_ids)
+ labels.extend(label_ids)
+
+ if 'image_urls' in res:
+ image_urls.extend(res['image_urls'])
+
+ # TODO (pppppM) Verify whether sep and suffix_as_eos are necessary
+
+ training_data = {
+ 'input_ids': input_ids,
+ 'labels': labels,
+ 'image_urls': image_urls
+ }
+ return training_data
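+
+
+# A minimal end-to-end sketch (hypothetical model path and template; the dict
+# layout mirrors the hybrid chat JSON format):
+#
+#   from xtuner.utils.tokenizer import build_tokenizer
+#
+#   msgs = TrainingHybridChatMessages.from_dict({
+#       'messages': [
+#           {'role': 'user', 'content': 'hello'},
+#           {'role': 'assistant', 'content': 'hello!'},
+#       ],
+#   })
+#   tokenizer = build_tokenizer('internlm/internlm2-chat-1_8b')
+#   data = msgs.tokenize(tokenizer, chat_template)  # a HybridChatTemplate
+#   data['input_ids'], data['labels']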
diff --git a/xtuner/utils/__init__.py b/xtuner/utils/__init__.py
index 6bc9a1173..75bcad2bf 100644
--- a/xtuner/utils/__init__.py
+++ b/xtuner/utils/__init__.py
@@ -3,9 +3,10 @@
IGNORE_INDEX, IMAGE_TOKEN_INDEX)
from .stop_criteria import StopWordStoppingCriteria
from .templates import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
+from .tokenizer import build_tokenizer
__all__ = [
'IGNORE_INDEX', 'DEFAULT_PAD_TOKEN_INDEX', 'PROMPT_TEMPLATE',
'DEFAULT_IMAGE_TOKEN', 'SYSTEM_TEMPLATE', 'StopWordStoppingCriteria',
- 'IMAGE_TOKEN_INDEX'
+ 'IMAGE_TOKEN_INDEX', 'build_tokenizer'
]
diff --git a/xtuner/utils/config.py b/xtuner/utils/config.py
new file mode 100644
index 000000000..0514dd8bf
--- /dev/null
+++ b/xtuner/utils/config.py
@@ -0,0 +1,131 @@
+import dataclasses
+from typing import TypeVar, Union
+
+import torch
+from mmengine.config import Config
+from mmengine.logging import print_log
+from mmengine.utils import get_object_from_string
+
+from xtuner.registry import BUILDER
+
+
+def convert_dtype_cfg_to_obj(config: Union[dict, list[dict]]) -> None:
+ """Convert dtype related config to python object.
+
+ When MMEngine Runner is training, it will save the config file for
+ resuming training.
+ But in the saved config file, python objects of type torch.dtype are
+ converted to strings like 'torch.float16'. In order to accommodate this,
+ after loading the config, all dtype strings need to be converted into
+ python objects.
+
+ Args:
+ config: A dict or list that potentially contains dtypes as strings.
+
+ Returns:
+ None. The input 'config' is modified in-place.
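+
+    Example (illustrative)::
+
+        cfg = {'model': {'torch_dtype': 'torch.float16'}}
+        convert_dtype_cfg_to_obj(cfg)
+        # cfg['model']['torch_dtype'] is now torch.float16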
+ """
+ # If the config is a dictionary
+ if isinstance(config, dict):
+ for key, value in config.items():
+ # Recursively call the function if the value is also a dict
+ if isinstance(value, dict):
+ convert_dtype_cfg_to_obj(value)
+
+ # Replace the string with the corresponding dtype object
+ # if it's a recognized dtype string
+ elif value in ['torch.float16', 'torch.float32', 'torch.bfloat16']:
+ config[key] = getattr(torch, value.split('.')[-1])
+
+ # If the config is a list
+ elif isinstance(config, list):
+ for item in config:
+ convert_dtype_cfg_to_obj(item)
+
+
+def convert_dataclass_cfg_to_obj(config: Union[dict, list[dict]]) -> None:
+ """Convert dataclass related config to python object.
+
+ Huggingface's code uses dataclasses extensively.
+ In order to use Huggingface's interfaces in the MMEngine config,
+ we need to specifically handle these configurations.
+
+ Note:
+ Before executing this function, you must first run
+ `convert_dtype_cfg_to_obj`, otherwise the dataclass config containing
+ dtype cannot be properly converted !
+
+ Args:
+ config: A dictionary or list that potentially contains configurations
+ as dataclasses.
+
+ Returns:
+ None. The input 'config' is modified in-place.
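+
+    Example (illustrative; assumes ``peft.LoraConfig``, which is a
+    dataclass)::
+
+        from peft import LoraConfig
+
+        cfg = {'llm_lora': {'type': LoraConfig, 'r': 16}}
+        convert_dataclass_cfg_to_obj(cfg)
+        # cfg['llm_lora'] is now a LoraConfig(r=16) instance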
+ """
+ # If the config is a dictionary
+ if isinstance(config, dict):
+ for key, value in config.items():
+ # Recursively call the function if the value is also a dict
+ if isinstance(value, dict):
+ convert_dataclass_cfg_to_obj(value)
+
+ # Check if the type of value is a dataclass
+ if 'type' in value and dataclasses.is_dataclass(value['type']):
+ builder = value.pop(
+ 'type') # remove 'type' from value and get its content
+
+ # Convert the builder to an object if it is a string
+ if isinstance(builder, str):
+ builder = get_object_from_string(builder)
+
+ # Build a new_value using the remaining items in value
+ new_value = builder(**value)
+ # replace the original value with new_value
+ config[key] = new_value
+ print_log(f'{key} convert to {builder}')
+
+ # If the config is a list
+ elif isinstance(config, list):
+ for item in config:
+ convert_dataclass_cfg_to_obj(item)
+
+
+OBJ_T = TypeVar('OBJ_T')
+
+
+def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T],
+ accept: OBJ_T) -> OBJ_T:
+ """Build a python object from a config or return an existed object.
+
+ Args:
+ cfg_or_obj (dict, OBJ_T]): an object of a type specified in
+ `accept_obj_types`, or a dict.
+ accept_obj (OBJ_T): the type of object that return without any
+ modification.
+
+ Returns:
+ If 'cfg_or_obj' is an object of `accept_obj` , it is returned directly.
+ If 'cfg_or_obj' is a dict, it is built into an object using
+ `build_from_cfg()`.
+
+ Raises:
+ TypeError: If `cfg_or_obj` is not dict or `accept_obj`.
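+
+    Example (a minimal sketch; ``nn.Linear`` stands in for any type that
+    ``BUILDER`` can construct)::
+
+        from torch import nn
+
+        # An already-built object is returned unchanged.
+        layer = build_from_cfg_or_obj(nn.Linear(2, 2), accept=nn.Module)
+
+        # A config dict is built through the registry.
+        cfg = dict(type=nn.Linear, in_features=2, out_features=2)
+        layer = build_from_cfg_or_obj(cfg, accept=nn.Module)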
+ """
+
+ if isinstance(cfg_or_obj, accept):
+ return cfg_or_obj
+
+ elif isinstance(cfg_or_obj, (dict, Config)):
+ convert_dtype_cfg_to_obj(cfg_or_obj)
+ convert_dataclass_cfg_to_obj(cfg_or_obj)
+ obj = BUILDER.build(cfg_or_obj)
+
+ if not isinstance(obj, accept):
+ raise TypeError(
+ f'Expect an object of {accept}, but there is an object of '
+ f'{type(obj)}.')
+ return obj
+
+ else:
+ raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got '
+ f'{type(cfg_or_obj)}')
diff --git a/xtuner/utils/tokenizer.py b/xtuner/utils/tokenizer.py
new file mode 100644
index 000000000..7c79b9fba
--- /dev/null
+++ b/xtuner/utils/tokenizer.py
@@ -0,0 +1,46 @@
+from typing import List, Union
+
+from transformers import AutoTokenizer
+
+from xtuner.registry import BUILDER
+
+
+def build_tokenizer(tokenizer: Union[str, dict]):
+    """Build a tokenizer from a model name/path or an MMEngine-style config."""
+    if isinstance(tokenizer, str):
+        return AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
+    elif isinstance(tokenizer, dict):
+        return BUILDER.build(tokenizer)
+    else:
+        raise TypeError('tokenizer must be a model name/path (str) or a '
+                        f'config dict, but got {type(tokenizer)}')
+
+
+def get_bos_token_ids(tokenizer) -> List[int]:
+
+ if tokenizer.__class__.__name__ == 'QWenTokenizer':
+ bos_token_ids = []
+ elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
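+        # ChatGLM tokenizers prepend two fixed special tokens; these ids are
+        # hard-coded here (assumed to match the official ChatGLM tokenizer).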
+ bos_token_ids = [64790, 64792]
+ else:
+ bos_token_ids = tokenizer.bos_token_id
+
+ if isinstance(bos_token_ids, int):
+ bos_token_ids = [bos_token_ids]
+
+ return bos_token_ids
+
+
+def get_eos_token_ids(tokenizer) -> List[int]:
+    eos_token_ids = tokenizer.eos_token_id
+    if tokenizer.__class__.__name__ == 'QWenTokenizer':
+        assert eos_token_ids is not None, \
+            'Please set eos_token for Qwen tokenizer!'
+
+ if isinstance(eos_token_ids, int):
+ eos_token_ids = [eos_token_ids]
+
+ return eos_token_ids