diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json new file mode 100644 index 000000000..e7d8966e9 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json @@ -0,0 +1,62 @@ +{ + "messages": [ + {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"}, + { + "role": "user", + "content": "Please help me process and visualize this dataset.", + "files": [{"path": "data.csv", "size": "10K"}] + }, + { + "role": "assistant", + "content": "I have processed the data and visualized it for you.", + "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': '\n\nRainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='Date: %{text}<br>Wind Direction 9am: %{x}<br>Rainfall: %{y}<br>Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```" + }, + { + "role": "code_interpreter", + "content": "![image](xxx.png)" + }, + { + "role": "assistant", + "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background." + }, + { + "role": "user", + "content": "I want to know today's weather in Shanghai" + }, + { + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", + "function_call": { + "name": "get_current_weather", + "parameters": {"location": "Shanghai"} + } + }, + { + "role": "function", + "name": "get_current_weather", + "content": "{'temperature': 22}" + }, + { + "role": "assistant", + "content": "The weather in Shanghai is 22 celsius" + } + ], + + "functions": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + "unit": {"type": "string"}}, + "required": ["location"] + } + } + } + ], + + "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. 
formats)"} diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py new file mode 100644 index 000000000..e9d5796bc --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py @@ -0,0 +1,33 @@ +import json + +from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages + +chat_template = HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + files='<|im_start|>user name=file\n{files}<|im_end|>\n', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n', + code_interpreter_call= + '{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + code_interpreter_result= + '<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + code_interpreter= + '<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n') + +agent_data = json.load(open('agent.json')) + +msg = TrainingHybridChatMessages.from_dict(agent_data) +print(msg.apply_chat_template(chat_template)) + +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained( + 'internlm/internlm2-chat-7b', trust_remote_code=True) +print(msg.tokenize(tokenizer, chat_template)) diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json new file mode 100644 index 000000000..a7a1abbf3 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json @@ -0,0 +1,52 @@ +[ + { + "messages": [ + { + "role": "user", + "content": "I want to know today's weather in Shanghai" + }, + + { + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", + "function_call": { + "name": "get_current_weather", + "parameters": { + "location": "Shanghai" + } + } + }, + + { + "role": "function", + "name": "get_current_weather", + "content": "{'temperature': 22}" + }, + { + "role": "assistant", + "content": "The weather in Shanghai is 22 celsius" + } + + + ], + + "functions": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + "unit": {"type": "string"} + }, + "required": ["location"] + } + } + } + ] + } + +] diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py new file mode 100644 index 000000000..481b22620 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import AutoTokenizer + +from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn +from xtuner.dataset.hybrid.mappings import openai_to_raw_training +from xtuner.engine.hooks import DatasetInfoHook +from xtuner.engine.runner import TrainLoop +from xtuner.model import AgentFinetune, AutoModelForCausalLM +from xtuner.types import HybridChatTemplate + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/petrelfs/share_data/basemodel/checkpoints/llm/hf_hub/models--internlm--internlm2-chat-1_8b/snapshots/aa8a7450c2227a3b6733b3c6fe33fefbb2ca54f9/' + +# Data +data_dir = './' +data_files = ['agentlego.json'] +max_length = 2048 + +# Chat Template +chat_template = dict( + type=HybridChatTemplate, + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n') + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 1 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=AgentFinetune, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +dataset = dict( + type=HybridDataset, + data_dir=data_dir, + data_files=data_files, + sample_ratio=1, + tokenizer=tokenizer, + chat_template=chat_template, + max_length=max_length, + pack_to_max_length=True, + num_workers=4, + mappings=[openai_to_raw_training]) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=hybrid_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, 
weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + # dict( + # type=EvaluateChatHook, + # tokenizer=tokenizer, + # image_processor=image_processor, + # every_n_iters=evaluation_freq, + # evaluation_inputs=evaluation_inputs, + # evaluation_images=evaluation_images, + # system=SYSTEM, + # prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py new file mode 100644 index 000000000..a010f44b5 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) + +from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn +from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training) +from xtuner.engine.hooks import DatasetInfoHook +from xtuner.engine.runner import TrainLoop +from xtuner.model import HybridFinetune +from xtuner.types import HybridChatTemplate + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +# llm_name_or_path = '/mnt/petrelfs/share_data/basemodel/checkpoints/llm/hf_hub/models--internlm--internlm2-chat-1_8b/snapshots/aa8a7450c2227a3b6733b3c6fe33fefbb2ca54f9/' +llm_name_or_path = '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f/' +visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336' +use_varlen_attn = False +# Specify the pretrained pth +pretrained_pth = None +# Data +data_dir = './llava_data/' +data_files = ['LLaVA-Instruct-150K/llava_v1_5_mix665k.json'] +image_dir = data_dir + 'llava_images' +max_length = 1024 * 2 + +# Chat Template +chat_template = dict( + type=HybridChatTemplate, + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n') + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 1 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 0 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor.from_pretrained, + pretrained_model_name_or_path=visual_encoder_name_or_path, + trust_remote_code=True) + +model = dict( + type=HybridFinetune, + freeze_llm=False, + freeze_visual_encoder=True, + pretrained_pth=pretrained_pth, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + 
pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + attn_implementation='flash_attention_2', + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), + visual_encoder=dict( + type=CLIPVisionModel.from_pretrained, + pretrained_model_name_or_path=visual_encoder_name_or_path), + visual_encoder_lora=dict( + type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=HybridDataset, + data_dir=data_dir, + data_files=data_files, + # data_cached='cached_llava', + image_dir=image_dir, + sample_ratio=0.1, + tokenizer=tokenizer, + chat_template=chat_template, + image_processor=image_processor, + pad_img_to_squared=True, + max_length=max_length, + pack_to_max_length=False, + num_workers=4, + mappings=[ + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, + ]) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=4, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=hybrid_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + # dict( + # type=EvaluateChatHook, + # tokenizer=tokenizer, + # image_processor=image_processor, + # every_n_iters=evaluation_freq, + # evaluation_inputs=evaluation_inputs, + # evaluation_images=evaluation_images, + # system=SYSTEM, + # prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. 
+ checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json new file mode 100644 index 000000000..ebe5cf457 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json @@ -0,0 +1,39 @@ +[ + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": "image1.jpg" + }, + { + "type": "image_url", + "image_url": "image2.jpg" + }, + { + "type": "text", + "text": "What are the colors of the bus in the first image?" + } + ] + }, + + { + "role": "assistant", + "content": "The bus in the image is white and red." + }, + + { + "role": "user", + "content": "Where is the cat positioned in the second image?" + }, + + { + "role": "assistant", + "content": "The cat is positioned on top of the back of the couch in the living room." + } + + ] + } +] diff --git a/xtuner/dataset/hybrid/__init__.py b/xtuner/dataset/hybrid/__init__.py new file mode 100644 index 000000000..febf1a497 --- /dev/null +++ b/xtuner/dataset/hybrid/__init__.py @@ -0,0 +1,14 @@ +from .collate import hybrid_collate_fn +from .dataset import HybridDataset +from .mappings import (insert_img_pad_tokens, llava_to_openai, map_protocol, + map_sequential, openai_to_raw_training) + +__all__ = [ + 'hybrid_collate_fn', + 'HybridDataset', + 'insert_img_pad_tokens', + 'llava_to_openai', + 'map_protocol', + 'map_sequential', + 'openai_to_raw_training', +] diff --git a/xtuner/dataset/hybrid/_pack.py b/xtuner/dataset/hybrid/_pack.py new file mode 100644 index 000000000..12b29fbca --- /dev/null +++ b/xtuner/dataset/hybrid/_pack.py @@ -0,0 +1,133 @@ +import bisect +import itertools +import random + +import torch + + +class _PackDataset(torch.utils.data.Dataset): + + def __init__(self, dataset, max_length=2048): + super().__init__() + + self.max_length = max_length + + # unpack dataset + self.dataset = dataset + + self._ori_img_urls = dataset['image_urls'] + self._ori_img_rngs = dataset['image_ranges'] + self._ori_lens = dataset['tokens'] + + self._num_packed_samples = sum(self._ori_lens) // self.max_length + + inds = [i for i in range(len(self.dataset))] + random.shuffle(inds) + self.shfl_inds = inds + + shfl_lens = [self._ori_lens[i] for i in inds] + # shuffled cumulative lengths + shfl_acc_lens = list(itertools.accumulate(shfl_lens)) + + self._shfl_item_rngs_left = [0] + shfl_acc_lens[:-1] + self._shfl_item_rngs_right = shfl_acc_lens + + shfl_img_urls = [self._ori_img_urls[i] for i in inds] + self._flat_shfl_img_urls = list(itertools.chain(*shfl_img_urls)) + + flat_shfl_acc_img_rngs = [] + flat_shfl_acc_img_rngs_left = [] + flat_shfl_acc_img_rngs_right = [] 
+ for i in range(len(self.dataset)): + shfl_i = self.shfl_inds[i] + img_rngs = self._ori_img_rngs[shfl_i] + for left, right in img_rngs: + acc_left = left + self._shfl_item_rngs_left[i] + acc_right = right + self._shfl_item_rngs_left[i] + + flat_shfl_acc_img_rngs_left.append(acc_left) + flat_shfl_acc_img_rngs_right.append(acc_right) + flat_shfl_acc_img_rngs.append([acc_left, acc_right]) + assert len(flat_shfl_acc_img_rngs) == len(self._flat_shfl_img_urls) + + self._flat_shfl_acc_img_rngs = flat_shfl_acc_img_rngs + self._flat_shfl_acc_img_rngs_left = flat_shfl_acc_img_rngs_left + self._flat_shfl_acc_img_rngs_right = flat_shfl_acc_img_rngs_right + + def _pack_img_urls_and_rngs_in_range(self, begin, end): + + left = bisect.bisect(self._flat_shfl_acc_img_rngs_right, begin) + right = bisect.bisect(self._flat_shfl_acc_img_rngs_left, end) + + filter_urls = self._flat_shfl_img_urls[left:right] + filter_rngs = self._flat_shfl_acc_img_rngs[left:right] + + inner_rngs = [] + inner_urls = [] + for url, rng in zip(filter_urls, filter_rngs): + inner_left = max(begin, rng[0]) - begin + inner_right = min(end, rng[1]) - begin + + if inner_right - inner_left > 0: + inner_rngs.append([inner_left, inner_right]) + inner_urls.append(url) + return inner_urls, inner_rngs + + def _pack_ids_and_labels_in_range(self, begin, end): + + left = bisect.bisect(self._shfl_item_rngs_right, begin) + right = bisect.bisect(self._shfl_item_rngs_left, end) + + trunc_ids = [] + trunc_labels = [] + cumulative_len = [] + position_ids = [] + for i in range(left, right): + cumulative_len.append(len(trunc_ids)) + + item_begin = self._shfl_item_rngs_left[i] + item_end = self._shfl_item_rngs_right[i] + + inner_l = max(begin, item_begin) - item_begin + inner_r = min(end, item_end) - item_begin + position_ids.extend([i for i in range(inner_r - inner_l)]) + + ori_idx = self.shfl_inds[i] + ori_input_ids = self.dataset[ori_idx]['input_ids'] + ori_labels = self.dataset[ori_idx]['labels'] + + trunc_ids.extend(ori_input_ids[inner_l:inner_r]) + trunc_labels.extend(ori_labels[inner_l:inner_r]) + + cumulative_len.append(self.max_length) + return trunc_ids, trunc_labels, cumulative_len, position_ids + + def __len__(self): + return self._num_packed_samples + + def __getitem__(self, item): + + begin = item * self.max_length + end = (item + 1) * self.max_length + + _res = self._pack_ids_and_labels_in_range(begin, end) + packed_ids, packed_labels, cumulative_len, position_ids = _res + assert self.max_length == len(packed_ids) == len(packed_labels) + + _res = self._pack_img_urls_and_rngs_in_range(begin, end) + packed_img_urls, packed_img_rngs = _res + + for left, right in packed_img_rngs: + assert len(set(packed_ids[left:right])) == 1 + + packed = { + 'input_ids': packed_ids, + 'labels': packed_labels, + 'tokens': self.max_length, + 'image_urls': packed_img_urls, + 'image_ranges': packed_img_rngs, + 'cumulative_len': cumulative_len, + 'position_ids': position_ids + } + + return packed diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py new file mode 100644 index 000000000..e6ed2288a --- /dev/null +++ b/xtuner/dataset/hybrid/collate.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Sequence + +import torch +from torch.nn.utils.rnn import pad_sequence + +from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX + + +def hybrid_collate_fn(instances: Sequence[Dict], + pad_index: int = DEFAULT_PAD_TOKEN_INDEX, + return_hf_format: bool = False): + + input_ids = [] + labels = [] + pixel_values = [] + cumulative_len = [] + image_ranges = [] + image_belongs = [] + position_ids = [] + + for i, data in enumerate(instances): + input_ids.append(torch.LongTensor(data['input_ids'])) + labels.append(torch.LongTensor(data['labels'])) + position_ids.append(torch.IntTensor(data['position_ids'])) + + if 'cumulative_len' in data: + cumulative_len.append(torch.IntTensor(data['cumulative_len'])) + + _values = data['pixel_values'] + _ranges = data['image_ranges'] + + assert len(_values) == len(_ranges) + for v, rng in zip(_values, _ranges): + pixel_values.append(v) + image_ranges.append(rng) + image_belongs.append(i) + + if len(pixel_values) > 0: + assert len(image_ranges) > 0 + assert len(image_belongs) > 0 + + pixel_values = torch.stack(pixel_values) + # image_belongs = torch.IntTensor(image_belongs) + else: + pixel_values = None + image_ranges = None + image_belongs = None + + if len(instances) > 1: + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=pad_index) + labels = pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX) + position_ids = pad_sequence( + position_ids, batch_first=True, padding_value=0) + else: + input_ids = torch.stack(input_ids) + labels = torch.stack(labels) + position_ids = torch.stack(position_ids) + + if len(cumulative_len) == 0: + cumulative_len = None + + # breakpoint() + data_dict = { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': input_ids.ne(pad_index), + 'labels': labels, + 'pixel_values': pixel_values, + 'cumulative_len': cumulative_len, + 'image_ranges': image_ranges, + 'image_belongs': image_belongs + } + + if return_hf_format: + return data_dict + else: + return {'data': data_dict, 'data_samples': None} diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py new file mode 100644 index 000000000..b2699e048 --- /dev/null +++ b/xtuner/dataset/hybrid/dataset.py @@ -0,0 +1,474 @@ +import json +import os +import random +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +from datasets import Dataset, load_from_disk +from mmengine import print_log +from PIL import Image +from torch import distributed as dist +from torch import nn +from tqdm import tqdm + +from xtuner.dataset.hybrid._pack import _PackDataset +from xtuner.dataset.hybrid.mappings import map_protocol, map_sequential +from xtuner.dataset.utils import expand2square +from xtuner.registry import BUILDER +from xtuner.types import HybridChatTemplate +from xtuner.utils import build_tokenizer + +os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + +@map_protocol( + input_keys=dict(input_ids=list), + added_keys=dict(tokens=int), +) +def _register_tokens(data, tokenizer=None, chat_template=None): + data['tokens'] = len(data['input_ids']) + return data + + +@map_protocol( + input_keys=dict(input_ids=list), + added_keys=dict(cumulative_len=list), +) +def _register_cumulative_len(data, tokenizer=None, chat_template=None): + data['cumulative_len'] = [0, len(data['input_ids'])] + return data + + +@map_protocol( + 
input_keys=dict(input_ids=list), + added_keys=dict(position_ids=list), +) +def _register_position_ids(data, tokenizer=None, chat_template=None): + data['position_ids'] = [i for i in range(len(data['input_ids']))] + return data + + +@map_protocol( + added_keys=dict(image_ranges=list), ) +def _register_empty_img_ranges(data, tokenizer=None, chat_template=None): + if 'image_ranges' not in data: + data['image_ranges'] = [] + return data + + +@map_protocol( + input_keys=dict( + input_ids=list, + labels=list, + tokens=int, + image_urls=list, + image_ranges=list, + position_ids=list, + cumulative_len=list), + output_keys=dict( + input_ids=list, + labels=list, + tokens=int, + image_urls=list, + image_ranges=list, + position_ids=list, + cumulative_len=list)) +def _check_mapped_data(item, tokenizer=None, chat_template=None): + assert isinstance(item['input_ids'][0], int) + assert isinstance(item['labels'][0], int) + + if len(item['image_urls']) > 0: + assert isinstance(item['image_urls'][0], str) + + if len(item['image_ranges']) > 0: + assert isinstance(item['image_ranges'][0], list) + assert isinstance(item['image_ranges'][0][0], int) + + return item + + +class HybridDataset(torch.utils.data.Dataset): + """ + Args: + tokenizer: The tokenizer processes some raw text as input and outputs + an Encoding. + max_length: Max length of the sequence. + pack_to_max_length: Whether to pack the dataset to the `max_length `. + This usually improves gpu utilization and therefore reduces + training time. + shuffle_before_pack: Whether to shuffle the dataset before + packing them. + use_varlen_attn: If use_varlen_attn is True, we calculate attention + the actual length of the sequence rather than the actual length + of the sequence + """ + + def __init__(self, + tokenizer, + chat_template: Union[Dict, HybridChatTemplate], + sample_ratio: int = 1.0, + max_length: int = 2048, + pack_to_max_length: bool = False, + num_workers: int = 8, + mappings: Union[Callable, List[Callable]] = [], + data_dir: Optional[str] = None, + data_files: Optional[Union[str, List[str]]] = None, + data_cached: Optional[str] = None, + image_dir: Optional[str] = None, + image_processor: Optional[nn.Module] = None, + pad_img_to_squared: bool = True): + super().__init__() + + assert data_dir or data_files or data_cached + + self.tokenizer = build_tokenizer(tokenizer) + + if isinstance(chat_template, HybridChatTemplate): + self.chat_template = chat_template + elif isinstance(chat_template, dict): + self.chat_template = BUILDER.build(chat_template) + else: + raise TypeError + + if isinstance(image_processor, dict): + image_processor = BUILDER.build(image_processor) + self.image_processor = image_processor + + if image_dir: + self.image_dir = Path(image_dir) + else: + self.image_dir = Path('') + + self.pad_img_to_squared = pad_img_to_squared + + self.sample_ratio = sample_ratio + self.max_length = max_length + self.pack_to_max_length = pack_to_max_length + + mappings.append(_register_cumulative_len) + mappings.append(_register_position_ids) + mappings.append(_register_tokens) + mappings.append(_register_empty_img_ranges) + mappings.append(_check_mapped_data) + map_fn = map_sequential(mappings) + self.map_fn = partial( + map_fn, tokenizer=self.tokenizer, chat_template=self.chat_template) + + self.num_workers = num_workers + if data_cached: + self.data_dir = data_dir + self.data_files = data_files + self.data_cached = data_cached + else: + data_dir = Path(data_dir) + if data_files is None: + data_files = [str(f) for f in data_dir.rglob('*.json')] + 
elif isinstance(data_files, list): + data_files = [str(data_dir / Path(f)) for f in data_files] + elif isinstance(data_files, str): + data_files = [str(data_dir / data_files)] + else: + raise TypeError + + self.data_dir = str(data_dir) + self.data_files = data_files + self.data_cached = data_cached + + self.dataset = self.build_dataset() + + def build_dataset(self): + + if not (dist.is_available() and dist.is_initialized()): + return self._build_dataset() + + timeout = timedelta( + minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30))) + print_log(f'xtuner_dataset_timeout = {timeout}', logger='current') + + gloo_group = dist.new_group(backend='gloo', timeout=timeout) + + if dist.get_rank() == 0: + dataset = self._build_dataset() + objects = [dataset] + else: + objects = [None] + + dist.monitored_barrier(group=gloo_group, timeout=timeout) + dist.broadcast_object_list(objects, src=0) + + return objects[0] + + def _build_dataset(self): + + if self.data_cached: + dataset = load_from_disk(self.data_cached) + if self.pack_to_max_length: + dataset = self._pack_dataset(dataset) + return dataset + + dataset = [] + for file in self.data_files: + dataset.extend(json.load(open(file))) + print_log(f'Loaded json data from {file}', logger='current') + + if self.sample_ratio < 1: + num_samples = int(self.sample_ratio * len(dataset)) + dataset = random.sample(dataset, num_samples) + print_log( + f'Randomly selected {num_samples} samples', logger='current') + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + dataset = list( + tqdm( + executor.map(self.map_fn, dataset), + desc='Map Dataset', + total=len(dataset))) + + dataset = self.filter_non_labels_data(dataset) + + self.analysis_tokens_labels(dataset) + self.analysis_image_samples(dataset) + + dataset = Dataset.from_list(dataset) + + if self.pack_to_max_length: + dataset = self._pack_dataset(dataset) + + return dataset + + def _pack_dataset(self, dataset): + + unpacked_samples = len(dataset) + dataset = _PackDataset(dataset, self.max_length) + packed_samples = len(dataset) + print_log( + 'Before pack multi samples to max length: ' + f'{unpacked_samples} samples', + logger='current') + print_log( + 'After pack multi samples to max length: ' + f'{packed_samples} samples', + logger='current') + return dataset + + def filter_non_labels_data(self, dataset): + + def filter_fn(item): + return any(item['labels'][i] >= 0 for i in range(self.max_length)) + + ori_samples = len(dataset) + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + results = list( + tqdm( + executor.map(filter_fn, dataset), + desc='Filter Dataset', + total=len(dataset))) + + new_dataset = [x for x, passed in zip(dataset, results) if passed] + + new_samples = len(new_dataset) + print_log(f'Before filter: {ori_samples} samples', logger='current') + print_log(f'After filter: {new_samples} samples', logger='current') + print_log( + f'Filtered {ori_samples - new_samples} samples ' + '(all labels are ignore)', + logger='current') + return new_dataset + + def analysis_image_samples(self, dataset): + + def img_sample_counter(item): + return len(item['image_urls']) > 0 + + def img_counter(item): + return len(item['image_urls']) + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + images = list( + tqdm( + executor.map(img_counter, dataset), + desc='Count Images', + total=len(dataset))) + + samples = list( + tqdm( + executor.map(img_sample_counter, dataset), + desc='Count Contain Image Samples', + total=len(dataset))) + + num_images = 
sum(images) + num_samples = sum(samples) + print_log( + f'There are a total of {num_samples} samples with images, ' + f'amounting to {num_images} images.', + logger='current') + + def analysis_tokens_labels(self, dataset): + + def label_counter(item): + return sum([1 for i in item['labels'] if i >= 0]) + + def token_counter(item): + return len(item['input_ids']) + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + tokens = list( + tqdm( + executor.map(token_counter, dataset), + desc='Count Tokens', + total=len(dataset))) + + labels = list( + tqdm( + executor.map(label_counter, dataset), + desc='Count Labels', + total=len(dataset))) + + num_tokens = sum(tokens) + num_labels = sum(labels) + print_log( + f'There are a total of {num_tokens} tokens, ' + f'of which {num_labels} tokens need loss calculation.', + logger='current') + + def cache(self, cache_dir: str): + cache_dir = Path(cache_dir) + + if self.pack_to_max_length: + hf_dataset = Dataset.from_list(self.dataset.dataset) + else: + hf_dataset = Dataset.from_list(self.dataset) + + hf_dataset.save_to_disk(cache_dir) + + dset_conf = { + 'image_dir': str(self.image_dir), + 'data_files': self.data_files, + 'max_length': self.max_length, + 'chat_template': self.chat_template.model_dump(), + 'pack_to_max_length': self.pack_to_max_length, + 'tokenizer': type(self.tokenizer).__name__, + } + + with open(cache_dir / 'dataset_configuration.json', 'w') as f: + json.dump(dset_conf, f) + + self.tokenizer.save_pretrained(cache_dir / 'tokenizer') + self.image_processor.save_pretrained(cache_dir / 'image_processor') + + def load_image(self, url): + image_file = self.image_dir / url + image = Image.open(image_file).convert('RGB') + + if self.pad_img_to_squared: + background = tuple( + int(x * 255) for x in self.image_processor.image_mean) + image = expand2square(image, background) + + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + + return image + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item: int) -> Dict[str, List]: + + data = self.dataset[item] + + pixel_values = [] + for url in data['image_urls']: + image = self.load_image(url) + + pixel_values.append(image) + + data['pixel_values'] = pixel_values + + return data + + +if __name__ == '__main__': + + from transformers import CLIPImageProcessor + + chat_template = HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' + ) + + processor = CLIPImageProcessor.from_pretrained( + 'openai/clip-vit-large-patch14-336', + trust_remote_code=True, + ) + + from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training) + + data_dir = './llava_data/LLaVA-Instruct-150K/' + image_dir = './llava_data/llava_images/' + data_files = 'llava_v1_5_mix665k.json' + + dataset = HybridDataset( + 'internlm/internlm2-chat-1_8b', + chat_template, + sample_ratio=1, + max_length=32 * 1024, + data_dir=data_dir, + data_files=data_files, + image_dir=image_dir, + image_processor=processor, + 
pack_to_max_length=True, + mappings=[ + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, + ], + num_workers=4) + + print(dataset[0]) + + dataset.cache('cached_llava') + dataset = HybridDataset( + 'internlm/internlm2-chat-1_8b', + chat_template, + sample_ratio=1, + max_length=32 * 1024, + data_cached='cached_llava', + image_dir=image_dir, + image_processor=processor, + pack_to_max_length=True, + mappings=[ + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, + ], + num_workers=4) + print(dataset[0]) + + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + from xtuner.dataset.hybrid.collate import hybrid_collate_fn + loader = DataLoader( + dataset, + 4, + num_workers=0, + collate_fn=hybrid_collate_fn, + sampler=DefaultSampler(dataset, shuffle=True)) + + for data in tqdm(loader): + continue diff --git a/xtuner/dataset/hybrid/mappings.py b/xtuner/dataset/hybrid/mappings.py new file mode 100644 index 000000000..e104885ff --- /dev/null +++ b/xtuner/dataset/hybrid/mappings.py @@ -0,0 +1,172 @@ +import re +from typing import Callable, Dict, List, Type + +from mmengine.config.lazy import LazyObject + +from xtuner.types import TrainingHybridChatMessages + + +def map_protocol( + input_keys: Dict[str, Type] = {}, + output_keys: Dict[str, Type] = {}, + added_keys: Dict[str, Type] = {}, +) -> Callable: + + def decorator(func): + + def wrapper(data, *args, **kwargs): + + for key, _type in input_keys.items(): + assert key in data + if not isinstance(data[key], _type): + breakpoint() + + result = func(data, *args, **kwargs) + + for key, _type in output_keys.items(): + assert key in result + assert isinstance(result[key], _type) + + return result + + return wrapper + + setattr(decorator, 'input_keys', input_keys) + setattr(decorator, 'output_keys', output_keys) + setattr(decorator, 'added_keys', added_keys) + + return decorator + + +def map_sequential(mappings: List[Callable]): + + if not isinstance(mappings, List): + mappings = list(mappings) + + for i in range(len(mappings)): + if isinstance(mappings[i], LazyObject): + mappings[i] = mappings[i].build() + + def _sequential(item, tokenizer, chat_template): + + for func in mappings: + item = func(item, tokenizer, chat_template) + + return item + + return _sequential + + +@map_protocol( + input_keys=dict(input_ids=list, labels=list, image_urls=list), + output_keys=dict( + input_ids=list, labels=list, image_urls=list, image_ranges=list), +) +def insert_img_pad_tokens(data, tokenizer, chat_template) -> Dict: + + image_urls = data['image_urls'] + if len(image_urls) == 0: + data['image_ranges'] = [] + return data + + input_ids = data['input_ids'] + labels = data['labels'] + + img_token = chat_template.image_token_index + img_token_inds = [i for i, t in enumerate(input_ids) if t == img_token] + assert len(img_token_inds) == len( + image_urls), f'{img_token_inds} {image_urls}' + + for url, ind in zip(image_urls, img_token_inds): + # image = self.load_image(url) + h, w = 336 // 14, 336 // 14 + + pad_tokens = [tokenizer.pad_token_id] * (h * w) + pad_labels = [labels[ind]] * (h * w) + + input_ids[ind] = pad_tokens + labels[ind] = pad_labels + + new_ids = [] + new_labels = [] + assert len(input_ids) == len(labels) + + img_ranges = [] + for i, _ in enumerate(zip(input_ids, labels)): + if isinstance(input_ids[i], list): + assert isinstance(labels[i], list) + assert len(input_ids[i]) == len(labels[i]) + + img_begin = len(new_ids) + img_end = img_begin + len(input_ids[i]) + 
img_ranges.append([img_begin, img_end]) + + new_ids.extend(input_ids[i]) + new_labels.extend(labels[i]) + + else: + new_ids.append(input_ids[i]) + new_labels.append(labels[i]) + + data['input_ids'] = new_ids + data['labels'] = new_labels + data['image_ranges'] = img_ranges + + return data + + +@map_protocol( + input_keys=dict(messages=list), + output_keys=dict(input_ids=list, labels=list, image_urls=list), +) +def openai_to_raw_training(item: dict, tokenizer, chat_template) -> Dict: + + data = TrainingHybridChatMessages.from_dict(item) + data = data.tokenize(tokenizer, chat_template) + + return data + + +@map_protocol( + input_keys=dict(conversations=list), + output_keys=dict(messages=list), +) +def llava_to_openai(data, tokenizer=None, chat_template=None): + + image_token = '' + conversations = data['conversations'] + messages = [] + + if 'image' in data: + image_url = data['image'] + else: + image_url = None + + while conversations and conversations[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + conversations = conversations[1:] + + for convs in conversations: + if convs['from'] == 'human': + pattern = f'({image_token})' + chunks = re.split(pattern, convs['value']) + + content = [] + for chunk in chunks: + if chunk == image_token: + assert isinstance(image_url, str), image_url + item = dict(type='image_url', image_url=image_url) + content.append(item) + elif len(chunk.strip()): + item = dict(type='text', text=chunk.strip()) + content.append(item) + + msg = {'role': 'user', 'content': content} + messages.append(msg) + + elif convs['from'] == 'gpt': + msg = {'role': 'assistant', 'content': convs['value']} + messages.append(msg) + else: + raise NotImplementedError + return {'messages': messages} diff --git a/xtuner/model/__init__.py b/xtuner/model/__init__.py index 39547b2d7..6ff30761d 100644 --- a/xtuner/model/__init__.py +++ b/xtuner/model/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .agent import AgentFinetune +from .auto import AutoModelForCausalLM +from .hybrid import HybridFinetune from .llava import LLaVAModel from .sft import SupervisedFinetune -__all__ = ['SupervisedFinetune', 'LLaVAModel'] +__all__ = [ + 'HybridFinetune', 'SupervisedFinetune', 'LLaVAModel', 'AgentFinetune' +] diff --git a/xtuner/model/agent.py b/xtuner/model/agent.py new file mode 100644 index 000000000..dd3f1bde6 --- /dev/null +++ b/xtuner/model/agent.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmengine.model import BaseModel +from peft import LoraConfig +from torch import nn + +from xtuner.registry import BUILDER +from xtuner.utils.config import build_from_cfg_or_obj +from .modules import ProjectorConfig, ProjectorModel, dispatch_modules +from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing, + get_peft_model_state_dict, prepare_for_llm_lora, + prepare_for_vision_lora, + smart_tokenizer_and_embedding_resize) + + +class AgentFinetune(BaseModel): + + def __init__( + self, + llm, + tokenizer=None, + llm_lora=None, + use_activation_checkpointing=True, + use_varlen_attn=False, + ): + super().__init__() + + # Build the base language model without initialization. + # This will greatly reduce the time to build the model. 
+ with LoadWoInit(): + self.llm = build_from_cfg_or_obj(llm, nn.Module) + self.llm.config.use_cache = False + dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn) + + if tokenizer is not None: + if isinstance(tokenizer, dict): + tokenizer = BUILDER.build(tokenizer) + smart_tokenizer_and_embedding_resize(tokenizer, self.llm) + + if use_activation_checkpointing: + # For backward compatibility + enable_hf_model_gradient_checkpointing(self.llm) + + self.use_llm_lora = llm_lora is not None + + # Prepare the model for LoRA if specified + if self.use_llm_lora: + lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig) + self.llm = prepare_for_llm_lora(self.llm, lora_conf, + use_activation_checkpointing) + + self._is_init = True + + # Determines whether to calculate attention based on the + # seq_len dimension (use_varlen_attn = False) or the actual length of + # the sequence. + self.use_varlen_attn = use_varlen_attn + + def init_weights(self): + """Parent class method. + + To avoid overwriting the loaded weights, overload it to an empty + function. + """ + pass + + def forward(self, data, data_samples=None, mode='loss'): + """Overload parent class method, only support training.""" + + if mode == 'loss': + return self.compute_loss(data) + else: + raise NotImplementedError( + f"{type(self)}'s forward is only supported for use during " + 'training. If you want to get predictions or chat, please ' + "directly use `llm`'s forward.") + + def compute_loss(self, data): + + input_ids = data['input_ids'] + labels = data['labels'] + # position_ids = data['position_ids'] + attention_mask = data['attention_mask'] + # breakpoint() + bs, tokens = input_ids.shape + if self.use_varlen_attn: + assert bs == 1 + + cumulative_len = data['cumulative_len'][0] + max_seqlen = (cumulative_len[1:] - cumulative_len[:-1]).max() + + position_ids = [] + for i in range(1, len(cumulative_len)): + chunk_tokens = cumulative_len[i] - cumulative_len[i - 1] + position_ids.append(torch.arange(chunk_tokens)) + position_ids = torch.cat(position_ids, dim=0).unsqueeze(0) + + from mmengine import MessageHub + rank = dist.get_rank() + message_hub = MessageHub.get_instance('varlen_attn_args') + message_hub.update_info(f'cumulative_len_rank_{rank}', + cumulative_len) + message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen) + else: + + position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1) + + outputs = self.llm( + input_ids=input_ids, + # position_ids=position_ids, + attention_mask=attention_mask, + labels=labels) + + loss_dict = {'loss': outputs.loss} + return loss_dict + + def state_dict(self, *args, **kwargs): + state_dict = super().state_dict(*args, **kwargs) + to_return = OrderedDict() + # Step 1. LLM + if self.use_llm_lora: + to_return.update( + get_peft_model_state_dict(self.llm, state_dict=state_dict)) + else: + to_return.update( + {k: v + for k, v in state_dict.items() if 'llm.' 
in k}) + + return to_return + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.llm, name) diff --git a/xtuner/model/auto.py b/xtuner/model/auto.py new file mode 100644 index 000000000..5546b1a5f --- /dev/null +++ b/xtuner/model/auto.py @@ -0,0 +1,73 @@ +import os +from typing import Dict, Optional, Union + +import torch + +from transformers import AutoConfig as HfAutoConfig +from transformers import AutoModelForCausalLM as HfAutoModelForCausalLM +from transformers import BitsAndBytesConfig + +from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 + + + + +class AutoModelForCausalLM: + + @classmethod + def from_config(cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True, + **kwargs): + return HfAutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True, + quantization_config: Optional[BitsAndBytesConfig] = None, + **kwargs): + + config = cls.from_config( + pretrained_model_name_or_path, trust_remote_code=True) + attn_kwargs = cls._flash_attn_kwargs(config) + kwargs.update(attn_kwargs) + + if torch.cuda.is_bf16_supported(): + kwargs.update(torch_dtype=torch.bfloat16) + else: + kwargs.update(torch_dtype=torch.float16) + + model = HfAutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + quantization_config=quantization_config, + **kwargs) + + return model + + @staticmethod + def _flash_attn_kwargs(config): + cls_name = type(config).__name__ + _built_in_flash_attn_1 = ('LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', + 'Qwen2Config', 'Starcoder2Config', + 'Starcoder2Config') + + _built_in_flash_attn_2 = ('InternLMConfig', 'InternLM2Config', + 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', + 'Qwen2Config', 'Starcoder2Config', + 'Starcoder2Config') + + attn_kwargs = {} + if SUPPORT_FLASH2 and cls_name in _built_in_flash_attn_2: + attn_kwargs.update(attn_implementation='flash_attention_2') + elif SUPPORT_FLASH1 and cls_name in _built_in_flash_attn_1: + attn_kwargs.update(attn_implementation='sdpa') + + return attn_kwargs \ No newline at end of file diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py new file mode 100644 index 000000000..0f0fc7e76 --- /dev/null +++ b/xtuner/model/hybrid.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmengine.model import BaseModel +from peft import LoraConfig +from torch import nn + +from xtuner.registry import BUILDER +from xtuner.utils.config import build_from_cfg_or_obj +from .modules import ProjectorConfig, ProjectorModel, dispatch_modules +from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing, + get_peft_model_state_dict, prepare_for_llm_lora, + prepare_for_vision_lora, + smart_tokenizer_and_embedding_resize) + + +class HybridFinetune(BaseModel): + + def __init__( + self, + llm, + visual_encoder=None, + visual_select_layer=-2, + projector_depth=2, + pretrained_pth=None, + tokenizer=None, + llm_lora=None, + visual_encoder_lora=None, + freeze_llm=False, + freeze_visual_encoder=False, + use_activation_checkpointing=True, + use_varlen_attn=False, + ): + super().__init__() + + # Build the base language model without initialization. 
+ # This will greatly reduce the time to build the model. + with LoadWoInit(): + self.llm = build_from_cfg_or_obj(llm, nn.Module) + if visual_encoder: + visual_encoder = build_from_cfg_or_obj(visual_encoder, + nn.Module) + self.visual_encoder = visual_encoder + self.visual_select_layer = visual_select_layer + self.llm.config.use_cache = False + dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn) + + if tokenizer is not None: + if isinstance(tokenizer, dict): + tokenizer = BUILDER.build(tokenizer) + smart_tokenizer_and_embedding_resize(tokenizer, self.llm) + + projector_config = ProjectorConfig( + visual_hidden_size=self.visual_encoder.config.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=projector_depth) + self.projector = ProjectorModel(projector_config).to( + self.visual_encoder.dtype) + + self.freeze_llm = freeze_llm + self.freeze_visual_encoder = freeze_visual_encoder + if self.freeze_llm: + self.llm.requires_grad_(False) + if self.freeze_visual_encoder: + self.visual_encoder.requires_grad_(False) + + if use_activation_checkpointing: + # For backward compatibility + enable_hf_model_gradient_checkpointing(self.llm) + enable_hf_model_gradient_checkpointing(self.visual_encoder) + + self.projector.enable_input_require_grads() + self.projector.gradient_checkpointing_enable() + + self.use_llm_lora = llm_lora is not None + self.use_visual_encoder_lora = visual_encoder_lora is not None + + # Prepare the model for LoRA if specified + if self.use_llm_lora: + lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig) + self.llm = prepare_for_llm_lora(self.llm, lora_conf, + use_activation_checkpointing) + + if self.use_visual_encoder_lora: + lora_conf = build_from_cfg_or_obj( + visual_encoder_lora, accept=LoraConfig) + self.visual_encoder = prepare_for_vision_lora( + self.visual_encoder, lora_conf, use_activation_checkpointing) + self._is_init = True + + # Determines whether to calculate attention based on the + # seq_len dimension (use_varlen_attn = False) or the actual length of + # the sequence. + self.use_varlen_attn = use_varlen_attn + + def init_weights(self): + """Parent class method. + + To avoid overwriting the loaded weights, overload it to an empty + function. + """ + pass + + def forward(self, data, data_samples=None, mode='loss'): + """Overload parent class method, only support training.""" + + if mode == 'loss': + return self.compute_loss(data) + else: + raise NotImplementedError( + f"{type(self)}'s forward is only supported for use during " + 'training. 
If you want to get predictions or chat, please ' + "directly use `llm`'s forward.") + + def _get_vision_embeds_and_ranges(self, data): + + input_ids = data['input_ids'] + pixel_values = data['pixel_values'] + img_rngs = data['image_ranges'] + img_belongs = data['image_belongs'] + + bs, tokens = input_ids.shape + + img_embeds = [] + ranges_in_flat_batch = [] + + if pixel_values is not None: + assert isinstance(pixel_values, torch.Tensor) + assert len(img_rngs) == len(img_belongs) == pixel_values.size(0) + + batch_total_imgs = len(img_rngs) + + visual_outputs = self.visual_encoder( + pixel_values, output_hidden_states=True) + features = self.projector( + visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) + batch_total_imgs, real_img_tokens, _ = features.shape + + for i in range(batch_total_imgs): + img_start, img_end = img_rngs[i] + exp_img_tokens = img_end - img_start + img_emb = features[i] + img_bs_ind = img_belongs[i] + + if real_img_tokens == exp_img_tokens: + img_embeds.append(img_emb) + elif not real_img_tokens == exp_img_tokens and img_start == 0: + img_embeds.append(img_emb[real_img_tokens - img_end:]) + elif (not real_img_tokens == exp_img_tokens + and img_end == tokens): + img_embeds.append(img_emb[:exp_img_tokens]) + else: + raise RuntimeError + + flat_offset = tokens * img_bs_ind + + left = flat_offset + img_start + right = flat_offset + img_end + ranges_in_flat_batch.append((left, right)) + + return img_embeds, ranges_in_flat_batch + + def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges): + + assert len(mm_embeds) == len(ranges) + if len(mm_embeds) == 0: + return flat_embeds + + _empty_embeds = torch.zeros_like(flat_embeds) + for (start, end), emb in zip(ranges, mm_embeds): + _empty_embeds[start:end] += emb + + flat_embeds = flat_embeds * (_empty_embeds == 0) + + return flat_embeds + _empty_embeds + + def compute_loss(self, data): + + input_ids = data['input_ids'] + labels = data['labels'] + # position_ids = data['position_ids'] + attention_mask = data['attention_mask'] + # breakpoint() + bs, tokens = input_ids.shape + if self.use_varlen_attn: + assert bs == 1 + + cumulative_len = data['cumulative_len'][0] + max_seqlen = (cumulative_len[1:] - cumulative_len[:-1]).max() + + position_ids = [] + for i in range(1, len(cumulative_len)): + chunk_tokens = cumulative_len[i] - cumulative_len[i - 1] + position_ids.append(torch.arange(chunk_tokens)) + position_ids = torch.cat(position_ids, dim=0).unsqueeze(0) + + from mmengine import MessageHub + rank = dist.get_rank() + message_hub = MessageHub.get_instance('varlen_attn_args') + message_hub.update_info(f'cumulative_len_rank_{rank}', + cumulative_len) + message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen) + else: + + position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1) + + input_embeds = self.llm.get_input_embeddings()(input_ids) + + bs, tokens, dim = input_embeds.shape + flat_embeds = input_embeds.flatten(0, 1) + + img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, + flat_bs_img_rngs) + input_embeds = flat_embeds.reshape(bs, tokens, dim) + + outputs = self.llm( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=input_embeds, + labels=labels) + + loss_dict = {'loss': outputs.loss} + return loss_dict + + def state_dict(self, *args, **kwargs): + state_dict = super().state_dict(*args, **kwargs) + to_return = OrderedDict() + # Step 1. 
visual_encoder + if self.use_visual_encoder_lora: + to_return.update( + get_peft_model_state_dict( + self.visual_encoder, state_dict=state_dict)) + elif not self.freeze_visual_encoder: + to_return.update({ + k: v + for k, v in state_dict.items() if 'visual_encoder.' in k + }) + # Step 2. LLM + if self.use_llm_lora: + to_return.update( + get_peft_model_state_dict(self.llm, state_dict=state_dict)) + elif not self.freeze_llm: + to_return.update( + {k: v + for k, v in state_dict.items() if 'llm.' in k}) + # Step 3. Projector + to_return.update( + {k: v + for k, v in state_dict.items() if 'projector.' in k}) + return to_return + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.llm, name) diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index a166e8bae..8aa664ad5 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -173,7 +173,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) + attn_output = flash_attn_varlen_func( q_unpad, k_unpad, diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py index 4cfa26cd1..5355bce74 100644 --- a/xtuner/model/modules/dispatch/utils.py +++ b/xtuner/model/modules/dispatch/utils.py @@ -25,7 +25,6 @@ def upad_qkv(query_layer, key_layer, value_layer, attention_mask, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) diff --git a/xtuner/model/utils.py b/xtuner/model/utils.py index dce86315d..0e5dfb826 100644 --- a/xtuner/model/utils.py +++ b/xtuner/model/utils.py @@ -1,13 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +from contextlib import nullcontext from typing import List, Optional import torch from mmengine import print_log from mmengine.utils.misc import get_object_from_string -from peft import PeftType +from peft import (LoraConfig, PeftModel, PeftType, get_peft_model, + prepare_model_for_kbit_training) from torch import nn -from transformers import PreTrainedModel +from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers.integrations import is_deepspeed_zero3_enabled from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX @@ -50,6 +53,34 @@ def find_all_linear_names(model): return list(lora_module_names) +def collect_linear_suffix_names(model: torch.nn.Module, + exclude: list[str] = []) -> list[str]: + """Collect suffix names of nn.Linear modules from a PyTorch model. + + Args: + model: The PyTorch model. + exclude: A list of keys to be excluded from the collected + suffix names. Default: ['lm_head', 'output_layer']. + + Returns: + A list of collected suffix names after excluding specified keys. 
+ """ + suffix_names = set() + + # Iterate through all named modules in the model + for name, module in model.named_modules(): + # Check if the module is an instance of nn.Linear + if isinstance(module, torch.nn.Linear): + names = name.split('.') + suffix_names.add(names[0] if len(names) == 1 else names[-1]) + + # Remove exclude_keys from the collected suffix_names + for key in exclude: + suffix_names.remove(key) + + return list(suffix_names) + + class LoadWoInit: """Context manager that disable parameter initialization.""" @@ -286,6 +317,73 @@ def make_inputs_require_grad(module, input, output): output.requires_grad_(True) +def prepare_for_llm_lora(model: PreTrainedModel, + lora_config: LoraConfig, + gradient_checkpointing: bool = True) -> PeftModel: + model = prepare_model_for_kbit_training(model, gradient_checkpointing) + if lora_config.target_modules is None: + modules = collect_linear_suffix_names(model, exclude=['output']) + lora_config.target_modules = modules + + model = get_peft_model(model, lora_config) + return model + + +def prepare_for_vision_lora(model: PreTrainedModel, + lora_config: LoraConfig, + gradient_checkpointing: bool = True) -> PeftModel: + + if lora_config.target_modules is None: + modules = collect_linear_suffix_names(model) + lora_config.target_modules = modules + + model = get_peft_model(model, lora_config) + return model + + +def smart_tokenizer_and_embedding_resize( + tokenizer: PreTrainedTokenizer, + model: PreTrainedModel, +): + """Resize embedding.""" + if is_deepspeed_zero3_enabled(): + import deepspeed + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings( + ) is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters( + params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + assert isinstance(model.get_output_embeddings(), nn.Linear) + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + num_new_tokens = len(tokenizer) - current_embedding_size + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + print_log( + f'Resized token embeddings from {current_embedding_size} to ' + f'{len(tokenizer)}.', 'current') + + def guess_load_checkpoint(pth_model): if osp.isfile(pth_model): state_dict = torch.load(pth_model, map_location='cpu') @@ -307,3 +405,19 @@ def guess_load_checkpoint(pth_model): else: raise FileNotFoundError(f'Cannot find {pth_model}') return state_dict + + +def enable_hf_model_gradient_checkpointing(model: PreTrainedModel) -> None: + # For backward compatibility + if hasattr(model, 'enable_input_require_grads'): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + model.gradient_checkpointing_enable() diff 
--git a/xtuner/types/__init__.py b/xtuner/types/__init__.py
new file mode 100644
index 000000000..cc230e8f8
--- /dev/null
+++ b/xtuner/types/__init__.py
@@ -0,0 +1,6 @@
+from .chat_template import HybridChatTemplate
+from .train import RawTrainingData, TrainingHybridChatMessages
+
+__all__ = [
+    'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages'
+]
diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py
new file mode 100644
index 000000000..cd0a4d4a7
--- /dev/null
+++ b/xtuner/types/chat.py
@@ -0,0 +1,173 @@
+from typing import Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+from .chat_template import HybridChatTemplate
+
+
+class TextContentItem(BaseModel):
+    type: Literal['text']
+    text: str
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return self.text
+
+
+class ImageContentItem(BaseModel):
+    type: Literal['image_url']
+    image_url: str
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return chat_template.image_token
+
+
+class FileContentItem(BaseModel):
+    type: Literal['file_url']
+    file_url: str
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return self.file_url
+
+
+MultiModalContentType = Union[TextContentItem, ImageContentItem,
+                              FileContentItem]
+ContentType = Union[str, List[MultiModalContentType]]
+
+
+class ChatMsg(BaseModel):
+    role: Literal['assistant', 'user', 'system']
+    content: ContentType
+    files: List[Union[str, Dict]] = []
+
+    def collect_img_urls(self) -> List[str]:
+        img_urls = []
+        if isinstance(self.content, list):
+            for item in self.content:
+                if isinstance(item, ImageContentItem):
+                    img_urls.append(item.image_url)
+        return img_urls
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        if isinstance(self.content, str):
+            text = self.content
+        elif isinstance(self.content, list):
+            text = ''
+            for i, item in enumerate(self.content):
+                if i == 0:
+                    text += item.apply_chat_template(chat_template)
+                else:
+                    text += '\n' + item.apply_chat_template(chat_template)
+        else:
+            raise NotImplementedError
+
+        if self.role == 'system':
+            prompt = chat_template.decorate_system(text)
+        elif self.role == 'user':
+            if len(self.files) > 0:
+                stop_word = chat_template.stop_words[0]
+                text += f'\n{stop_word}\n{chat_template.decorate_files(self.files)}'
+            prompt = chat_template.decorate_user(text)
+        elif self.role == 'assistant':
+            prompt = chat_template.decorate_assistant(text)
+        else:
+            raise NotImplementedError
+
+        return prompt
+
+
+# Function Call
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: Dict
+
+
+class FunctionCallMsg(BaseModel):
+
+    role: Literal['assistant']
+    content: str
+    function_call: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        return chat_template.decorate_function_call(self.content,
+                                                    self.function_call)
+
+
+class FunctionResultMsg(BaseModel):
+    role: Literal['function']
+    name: str
+    content: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return chat_template.decorate_function_result(self.content)
+
+
+class CodeInterpreterCallMsg(BaseModel):
+
+    role: Literal['assistant']
+    content: str
+    code_interpreter_call: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        return chat_template.decorate_code_interpreter_call(
+            self.content, self.code_interpreter_call)
+
+
+class CodeInterpreterResultMsg(BaseModel):
+    role: Literal['code_interpreter']
+    content: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return chat_template.decorate_code_interpreter_result(self.content)
+
+
+class Functions(BaseModel):
+
+    name: str
+    description: Union[str, Dict]
+    parameters: Union[str, Dict]
+
+
+HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg,
+                          CodeInterpreterCallMsg, CodeInterpreterResultMsg]
+
+
+class HybridChatMessages(BaseModel):
+
+    messages: List[HybridChatMsgType] = []
+    # images: List[Image.Image] = []
+    functions: List[Functions] = []
+    code_interpreter: Optional[str] = None
+
+    # TODO (pppppM) add audio and video
+
+    def collect_img_urls(self) -> List[str]:
+        img_urls = []
+        for msg in self.messages:
+            img_urls.extend(msg.collect_img_urls())
+        return img_urls
+
+    def pop_latest_msg(self):
+        return self.messages.pop()
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        prompt = ''
+
+        if self.code_interpreter:
+            prompt += chat_template.decorate_code_interpreter(
+                self.code_interpreter)
+
+        if len(self.functions) > 0:
+
+            functions = [func.model_dump() for func in self.functions]
+
+            prompt += chat_template.decorate_functions(functions)
+
+        for msg in self.messages:
+            prompt += msg.apply_chat_template(chat_template)
+
+        return prompt
diff --git a/xtuner/types/chat_template.py b/xtuner/types/chat_template.py
new file mode 100644
index 000000000..4318c6104
--- /dev/null
+++ b/xtuner/types/chat_template.py
@@ -0,0 +1,195 @@
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, field_validator
+
+
+class HybridChatTemplate(BaseModel):
+    """Define a Pydantic data model for a hybrid chat with attributes for
+    system, user and assistant chat as well as function and interpreter calls
+    and results."""
+
+    # Normal Chat
+    system: str  # System message format
+    user: str  # User message format
+    assistant: str  # Assistant message format
+    stop_words: List[str]  # List of stop words
+
+    # Multimodal Chat
+    # Predefined token and index for images
+    image_token: str = '<image>'
+    image_token_index: int = -100
+
+    # Agent Chat
+
+    # Interpreter and function related strings
+    files: Optional[str] = None
+
+    functions: Optional[str] = None  # Function description format
+    function_call: Optional[str] = None  # Function call format
+    function_result: Optional[str] = None  # Function result format
+
+    code_interpreter: Optional[str] = None
+    code_interpreter_call: Optional[str] = None  # Interpreter call format
+    code_interpreter_result: Optional[str] = None  # Interpreter result format
+
+    function_token: Optional[str] = None
+    code_interpreter_token: Optional[str] = None
+    action_start_token: Optional[str] = None
+    action_end_token: Optional[str] = None
+
+    @property
+    def mm_token_maps(self) -> Dict[str, int]:
+        """Return a dictionary that maps multimodal tokens to corresponding
+        token indexes."""
+        return {self.image_token: self.image_token_index}
+
+    def decorate_system(self, text: str) -> str:
+        """Decorate text with the `system` template."""
+        return self.system.format(system=text)
+
+    def decorate_assistant(self, text: str) -> str:
+        """Decorate text with the `assistant` template."""
+        return self.assistant.format(assistant=text)
+
+    def decorate_user(self, text: str) -> str:
+        """Decorate text with the `user` template."""
+        return self.user.format(user=text)
+
+    def decorate_files(self, text: str) -> str:
+        """Decorate text with the `files` template."""
+        return self.files.format(files=text)
+
+    def decorate_functions(self, text: str) -> str:
"""Decorate text with the `functions` template.""" + return self.functions.format(functions=text) + + def decorate_function_call(self, text: str, func: str) -> str: + """Decorate text with the `function_call` template.""" + return self.function_call.format(assistant=text, function_call=func) + + def decorate_function_result(self, text: str) -> str: + """Decorate text with the `function_result` template.""" + return self.function_result.format(function_result=text) + + def decorate_code_interpreter(self, text: str) -> str: + """Decorate text with the `code_interpreter` template.""" + return self.code_interpreter.format(code_interpreter=text) + + def decorate_code_interpreter_call(self, text: str, func: str) -> str: + """Decorate text with the `code_interpreter_call` template.""" + return self.code_interpreter_call.format( + assistant=text, code_interpreter_call=func) + + def decorate_code_interpreter_result(self, text: str) -> str: + """Decorate text with the `code_interpreter_result` template.""" + return self.code_interpreter_result.format( + code_interpreter_result=text) + + @field_validator('system') + def check_system(cls, v: str) -> str: + """Validate that `system` contains '{system}'. + + If not, raises a ValueError. + """ + if v is not None and '{system}' not in v: + raise ValueError("system must contain the keyword '{system}'") + return v + + @field_validator('user') + def check_user(cls, v: str) -> str: + """Validate that `user` contains '{user}'. + + If not, raises a ValueError. + """ + if v is not None and '{user}' not in v: + raise ValueError("user must contain the keyword '{user}'") + return v + + @field_validator('assistant') + def check_assistant(cls, v: str) -> str: + """Validate that `assistant` contains '{assistant}'. + + If not, raises a ValueError. + """ + if v is not None and '{assistant}' not in v: + raise ValueError( + "assistant must contain the keyword '{assistant}'") + return v + + @field_validator('function_call') + def check_function_call(cls, v: str) -> str: + """Validate that `function_call` contains '{function_call}'. + + If not, raises a ValueError. + """ + if (v is not None and '{function_call}' not in v + and '{assistant}' not in v): + raise ValueError( + "function_call must contain the keywords '{function_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError( + "function_call must contain the keyword '{assistant}' and " + "'{function_call}'") + return v + + @field_validator('function_result') + def check_function_result(cls, v: str) -> str: + """Validate that `function_result` contains '{function_result}'. + + If not, raises a ValueError. + """ + if v is not None and '{function_result}' not in v: + raise ValueError( + "function_result must contain the keyword '{function_result}'") + return v + + @field_validator('functions') + def check_functions(cls, v: str) -> str: + """Validate that `functions` contains '{functions}'. + + If not, raises a ValueError. + """ + if v is not None and '{functions}' not in v: + raise ValueError( + "functions must contain the keyword '{functions}'") + return v + + @field_validator('code_interpreter') + def check_code_interpreter(cls, v: str) -> str: + """Validate that `code_interpreter` contains '{code_interpreter}'. + + If not, raises a ValueError. 
+ """ + if v is not None and '{code_interpreter}' not in v: + raise ValueError('code_interpreter must contain the keyword ' + "'{code_interpreter}'") + return v + + @field_validator('code_interpreter_call') + def check_code_interpreter_call(cls, v: str) -> str: + """Validate that `code_interpreter_call` contains + '{code_interpreter_call}'. + + If not, raises a ValueError. + """ + if (v is not None and '{code_interpreter_call}' not in v + and '{assistant}' not in v): + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") + return v + + @field_validator('code_interpreter_result') + def check_code_interpreter_result(cls, v: str) -> str: + """Validate that `code_interpreter_result` contains + '{code_interpreter_result}'. + + If not, raises a ValueError. + """ + if v is not None and '{code_interpreter_result}' not in v: + raise ValueError( + 'code_interpreter_result must contain the keyword ' + "'{code_interpreter_result}'") + return v diff --git a/xtuner/types/train.py b/xtuner/types/train.py new file mode 100644 index 000000000..9971308cd --- /dev/null +++ b/xtuner/types/train.py @@ -0,0 +1,364 @@ +import copy +import re +from typing import Dict, List, Optional, Union + +import torch +from pydantic import BaseModel +from transformers.tokenization_utils import PreTrainedTokenizer + +from xtuner.utils import IGNORE_INDEX +from xtuner.utils.tokenizer import get_bos_token_ids +from .chat import (ChatMsg, CodeInterpreterCallMsg, CodeInterpreterResultMsg, + FileContentItem, FunctionCallMsg, FunctionResultMsg, + Functions, ImageContentItem, TextContentItem) +from .chat_template import HybridChatTemplate + + +class TrainingChatMsg(ChatMsg): + loss: Optional[bool] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.loss is None: + if self.role == 'system': + self.loss = False + elif self.role == 'user': + self.loss = False + elif self.role == 'assistant': + self.loss = True + else: + raise NotImplementedError + + def _encode_mm_content(self, text: str, mm_token_maps: Dict[str, int], + tokenizer: PreTrainedTokenizer): + + mm_tokens = mm_token_maps.keys() + + pattern = r'(' + '|'.join(mm_tokens) + r')' + chunks = re.split(pattern, text) + + assert len(chunks) > 1 + + token_ids = [] + for c in chunks: + if c in mm_tokens: + token_ids.append(mm_token_maps[c]) + else: + token_ids.extend(tokenizer.encode(c, add_special_tokens=False)) + + return token_ids + + def _with_multi_modal_content(self): + flag = False + + if isinstance(self.content, list): + for item in self.content: + # TODO (pppppM) support video and audio + if isinstance(item, ImageContentItem): + flag = True + break + return flag + + def tokenize( + self, + tokenizer: PreTrainedTokenizer, + chat_template: HybridChatTemplate, + ): + + decorated = self.apply_chat_template(chat_template) + + if self._with_multi_modal_content(): + token_maps = chat_template.mm_token_maps + token_ids = self._encode_mm_content(decorated, token_maps, + tokenizer) + else: + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + image_urls = self.collect_img_urls() + + return { + 'input_ids': token_ids, + 'labels': label_ids, + 'image_urls': image_urls + } + + +class 
TrainingFunctionCallMsg(FunctionCallMsg): + loss: bool = True + + def tokenize( + self, + tokenizer: PreTrainedTokenizer, + chat_template: HybridChatTemplate, + ): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class TrainingFunctionResultMsg(FunctionResultMsg): + loss: bool = False + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class TrainingCodeInterpreterCallMsg(CodeInterpreterCallMsg): + loss: bool = True + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class TrainingCodeInterpreterResultMsg(CodeInterpreterResultMsg): + loss: bool = False + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class RawTrainingData(BaseModel): + + input_ids: List[int] + labels: List[int] + image_urls: List[str] = [] + + +class ProcessedTrainingData(BaseModel): + + input_ids: List[int] + labels: List[int] + length: int + cumulative_len: List[int] + position_ids: List[int] + image_urls: List[str] = [] + pixel_values: List[torch.Tensor] = [] + image_ranges: List[tuple] = [] + + class Config: + arbitrary_types_allowed = True + + +TraingHybridMessageType = Union[TrainingChatMsg, TrainingFunctionCallMsg, + TrainingFunctionResultMsg, + TrainingCodeInterpreterCallMsg, + TrainingCodeInterpreterResultMsg] + + +class TrainingHybridChatMessages(BaseModel): + messages: List[TraingHybridMessageType] + functions: Optional[List[Functions]] = None + code_interpreter: Optional[str] = None + + @classmethod + def from_dict(cls, item) -> 'TrainingHybridChatMessages': + ''' + item + { + 'messages':[ + {'role':'user', 'content':'hello'}, + {'role':'assistant', 'content':'hello!'}, + ], + 'funcitons': [], + } + + ''' + + assert 'messages' in item, item + + _messages = item['messages'] + messages = [] + functions = None + code_interpreter = None + + for _msg in _messages: + assert 'role' in _msg and 'content' in _msg + _role = _msg['role'] + _content = _msg['content'] + if _role == 'function': + msg_factory = TrainingFunctionResultMsg + assert 'name' in _msg + name = _msg['name'] + msg = msg_factory(role=_role, name=name, content=_content) + messages.append(msg) + continue + + if _role == 'code_interpreter': + msg_factory = TrainingCodeInterpreterResultMsg + msg = msg_factory(role=_role, content=_content) + messages.append(msg) + continue + + if isinstance(_content, list): + + content = [] + for c_item in _content: + assert 'type' in c_item + 
_type = c_item['type']
+                    if _type == 'text':
+                        assert 'text' in c_item
+                        _text = c_item['text']
+                        content.append(TextContentItem(type=_type, text=_text))
+                    elif _type == 'image_url':
+                        assert 'image_url' in c_item
+                        _url = c_item['image_url']
+                        content.append(
+                            ImageContentItem(type=_type, image_url=_url))
+                    elif _type == 'file_url':
+                        assert 'file_url' in c_item
+                        _url = c_item['file_url']
+                        content.append(
+                            FileContentItem(type=_type, file_url=_url))
+                    else:
+                        raise NotImplementedError
+
+                msg = TrainingChatMsg(role=_role, content=content)
+                messages.append(msg)
+                continue
+
+            if isinstance(_content, str) and 'function_call' in _msg:
+                _call = _msg['function_call']
+                msg = TrainingFunctionCallMsg(
+                    role=_role, content=_content, function_call=_call)
+                messages.append(msg)
+                continue
+            elif isinstance(_content, str) and 'code_interpreter_call' in _msg:
+                _call = _msg['code_interpreter_call']
+                msg = TrainingCodeInterpreterCallMsg(
+                    role=_role, content=_content, code_interpreter_call=_call)
+                messages.append(msg)
+                continue
+
+            if isinstance(_content, str):
+                msg = TrainingChatMsg(role=_role, content=_content)
+                messages.append(msg)
+
+        # TODO (pppppM) add format warning
+
+        if 'code_interpreter' in item:
+
+            assert isinstance(item['code_interpreter'], str)
+            code_interpreter = item['code_interpreter']
+
+        if 'functions' in item:
+            _functions = item['functions']
+            assert isinstance(_functions, list)
+            functions = []
+
+            for _func in _functions:
+                assert 'name' in _func
+                assert 'description' in _func
+                assert 'parameters' in _func
+
+                _name = _func['name']
+                _desc = _func['description']
+                _params = _func['parameters']
+
+                func = Functions(
+                    name=_name, description=_desc, parameters=_params)
+                functions.append(func)
+
+        return cls(
+            messages=messages,
+            functions=functions,
+            code_interpreter=code_interpreter)
+
+    def collect_img_urls(self) -> List[str]:
+        img_urls = []
+        for msg in self.messages:
+            img_urls.extend(msg.collect_img_urls())
+        return img_urls
+
+    def pop_latest_msg(self):
+        return self.messages.pop()
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        prompt = ''
+
+        if isinstance(self.functions, list) and len(self.functions) > 0:
+
+            functions = [func.model_dump() for func in self.functions]
+
+            prompt += chat_template.decorate_functions(functions)
+
+        for msg in self.messages:
+            # The system message is always moved to the front of the prompt.
+            if msg.role == 'system':
+                prompt = msg.apply_chat_template(chat_template) + prompt
+            else:
+                prompt += msg.apply_chat_template(chat_template)
+
+        return prompt
+
+    def tokenize(self, tokenizer: PreTrainedTokenizer,
+                 chat_template: HybridChatTemplate) -> RawTrainingData:
+
+        input_ids = []
+        labels = []
+        image_urls = []
+
+        bos_token_ids = get_bos_token_ids(tokenizer)
+        input_ids.extend(bos_token_ids)
+        labels.extend([IGNORE_INDEX] * len(bos_token_ids))
+
+        for msg in self.messages:
+            res = msg.tokenize(tokenizer, chat_template)
+            token_ids, label_ids = res['input_ids'], res['labels']
+
+            input_ids.extend(token_ids)
+            labels.extend(label_ids)
+
+            if 'image_urls' in res:
+                image_urls.extend(res['image_urls'])
+
+        # TODO (pppppM) Verify whether sep and suffix_as_eos are necessary
+
+        training_data = {
+            'input_ids': input_ids,
+            'labels': labels,
+            'image_urls': image_urls
+        }
+        return training_data
diff --git a/xtuner/utils/__init__.py b/xtuner/utils/__init__.py
index 6bc9a1173..75bcad2bf 100644
--- a/xtuner/utils/__init__.py
+++ b/xtuner/utils/__init__.py
@@ -3,9 +3,10 @@
                         IGNORE_INDEX, IMAGE_TOKEN_INDEX)
 from .stop_criteria import StopWordStoppingCriteria
 from .templates
import PROMPT_TEMPLATE, SYSTEM_TEMPLATE +from .tokenizer import build_tokenizer __all__ = [ 'IGNORE_INDEX', 'DEFAULT_PAD_TOKEN_INDEX', 'PROMPT_TEMPLATE', 'DEFAULT_IMAGE_TOKEN', 'SYSTEM_TEMPLATE', 'StopWordStoppingCriteria', - 'IMAGE_TOKEN_INDEX' + 'IMAGE_TOKEN_INDEX', 'build_tokenizer' ] diff --git a/xtuner/utils/config.py b/xtuner/utils/config.py new file mode 100644 index 000000000..0514dd8bf --- /dev/null +++ b/xtuner/utils/config.py @@ -0,0 +1,131 @@ +import dataclasses +from typing import TypeVar, Union + +import torch +from mmengine.config import Config +from mmengine.logging import print_log +from mmengine.utils import get_object_from_string + +from xtuner.registry import BUILDER + + +def convert_dtype_cfg_to_obj(config: Union[dict, list[dict]]) -> None: + """Convert dtype related config to python object. + + When MMEngine Runner is training, it will save the config file for + resuming training. + But in the saved config file, python objects of type torch.dtype are + converted to strings like 'torch.float16'. In order to accommodate this, + after loading the config, all dtype strings need to be converted into + python objects. + + Args: + config: A dict or list that potentially contains dtypes as strings. + + Returns: + None. The input 'config' is modified in-place. + """ + # If the config is a dictionary + if isinstance(config, dict): + for key, value in config.items(): + # Recursively call the function if the value is also a dict + if isinstance(value, dict): + convert_dtype_cfg_to_obj(value) + + # Replace the string with the corresponding dtype object + # if it's a recognized dtype string + elif value in ['torch.float16', 'torch.float32', 'torch.bfloat16']: + config[key] = getattr(torch, value.split('.')[-1]) + + # If the config is a list + elif isinstance(config, list): + for item in config: + convert_dtype_cfg_to_obj(item) + + +def convert_dataclass_cfg_to_obj(config: Union[dict, list[dict]]) -> None: + """Convert dataclass related config to python object. + + Huggingface's code uses dataclasses extensively. + In order to use Huggingface's interfaces in the MMEngine config, + we need to specifically handle these configurations. + + Note: + Before executing this function, you must first run + `convert_dtype_cfg_to_obj`, otherwise the dataclass config containing + dtype cannot be properly converted ! + + Args: + config: A dictionary or list that potentially contains configurations + as dataclasses. + + Returns: + None. The input 'config' is modified in-place. 
+ """ + # If the config is a dictionary + if isinstance(config, dict): + for key, value in config.items(): + # Recursively call the function if the value is also a dict + if isinstance(value, dict): + convert_dataclass_cfg_to_obj(value) + + # Check if the type of value is a dataclass + if 'type' in value and dataclasses.is_dataclass(value['type']): + builder = value.pop( + 'type') # remove 'type' from value and get its content + + # Convert the builder to an object if it is a string + if isinstance(builder, str): + builder = get_object_from_string(builder) + + # Build a new_value using the remaining items in value + new_value = builder(**value) + # replace the original value with new_value + config[key] = new_value + print_log(f'{key} convert to {builder}') + + # If the config is a list + elif isinstance(config, list): + for item in config: + convert_dataclass_cfg_to_obj(item) + + +OBJ_T = TypeVar('OBJ_T') + + +def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T], + accept: OBJ_T) -> OBJ_T: + """Build a python object from a config or return an existed object. + + Args: + cfg_or_obj (dict, OBJ_T]): an object of a type specified in + `accept_obj_types`, or a dict. + accept_obj (OBJ_T): the type of object that return without any + modification. + + Returns: + If 'cfg_or_obj' is an object of `accept_obj` , it is returned directly. + If 'cfg_or_obj' is a dict, it is built into an object using + `build_from_cfg()`. + + Raises: + TypeError: If `cfg_or_obj` is not dict or `accept_obj`. + """ + + if isinstance(cfg_or_obj, accept): + return cfg_or_obj + + elif isinstance(cfg_or_obj, (dict, Config)): + convert_dtype_cfg_to_obj(cfg_or_obj) + convert_dataclass_cfg_to_obj(cfg_or_obj) + obj = BUILDER.build(cfg_or_obj) + + if not isinstance(obj, accept): + raise TypeError( + f'Expect an object of {accept}, but there is an object of ' + f'{type(obj)}.') + return obj + + else: + raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got ' + f'{type(cfg_or_obj)}') diff --git a/xtuner/utils/tokenizer.py b/xtuner/utils/tokenizer.py new file mode 100644 index 000000000..7c79b9fba --- /dev/null +++ b/xtuner/utils/tokenizer.py @@ -0,0 +1,46 @@ +from typing import List, Union + +from transformers import AutoTokenizer + +from xtuner.registry import BUILDER + + +def build_tokenizer(tokenizer: Union[str, dict]): + + if isinstance(tokenizer, str): + return AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + elif isinstance(tokenizer, dict): + return BUILDER.build(tokenizer) + else: + raise TypeError + + +def get_bos_token_ids(tokenizer) -> List[int]: + + if tokenizer.__class__.__name__ == 'QWenTokenizer': + bos_token_ids = [] + elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': + bos_token_ids = [64790, 64792] + else: + bos_token_ids = tokenizer.bos_token_id + + if isinstance(bos_token_ids, int): + bos_token_ids = [bos_token_ids] + + return bos_token_ids + + +def get_eos_token_ids(tokenizer) -> List[int]: + if tokenizer.__class__.__name__ == 'QWenTokenizer': + eos_token_ids = tokenizer.eos_token_id + assert eos_token_ids is not None, \ + 'Please set eos_token for Qwen tokenizer!' + elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': + eos_token_ids = tokenizer.eos_token_id + else: + eos_token_ids = tokenizer.eos_token_id + + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] + + return eos_token_ids