@register_template('Qwen_Omni_TI2T')
class Qwen_Omni_TI2T(BaseFormatter):
    """Conversation formatter for Qwen2.5-Omni supervised text+image -> text data."""

    # Default system prompt recommended for Qwen2.5-Omni.
    system_prompt: str = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'

    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Build a (conversation, meta_info) pair from a raw sample.

        Args:
            raw_sample: mapping with 'prompt' (str), 'response' (str) and
                'image' (PIL.Image-like, anything exposing ``convert``).

        Returns:
            A system/user/assistant message list plus a meta dict carrying
            the converted image.
        """
        pil_image = raw_sample['image'].convert('RGBA')
        user_content = [
            {'type': 'image', 'image': pil_image},
            {'type': 'text', 'text': raw_sample['prompt']},
        ]
        conversation = [
            {'role': 'system', 'content': self.system_prompt},
            {'role': 'user', 'content': user_content},
            {
                'role': 'assistant',
                'content': [{'type': 'text', 'text': raw_sample['response']}],
            },
        ]
        return conversation, {'image': pil_image}
per_device_eval_batch_size: 1 # The number of gradient accumulation steps - gradient_accumulation_steps: 16 + gradient_accumulation_steps: 1 # Whether to use gradient checkpointing gradient_checkpointing: True # Initial learning rate diff --git a/align_anything/datasets/qwen_omni/supervised.py b/align_anything/datasets/qwen_omni/supervised.py new file mode 100755 index 00000000..005590c7 --- /dev/null +++ b/align_anything/datasets/qwen_omni/supervised.py @@ -0,0 +1,208 @@ +# Copyright 2024 PKU-Alignment Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
from typing import Any, Callable

from typing_extensions import TypedDict  # Python 3.10+

import torch
import transformers
from torch.utils.data import Dataset
from torchvision import transforms
from transformers.tokenization_utils import PaddingStrategy, TruncationStrategy

from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import convert_to_rgb, ends_with_any
from datasets import load_dataset
from qwen_omni_utils import process_mm_info


# Label value ignored by the cross-entropy loss (masks prompt tokens).
IGNORE_INDEX = -100

__all__ = [
    'SupervisedDataset',
    'SupervisedCollator',
    'SupervisedSample',
    'SupervisedBatch',
]


class SupervisedSample(TypedDict, total=True):
    # One pre-processed sample, before collation.
    input_ids: torch.LongTensor  # size = (L,)
    labels: torch.LongTensor  # size = (L,)
    pixel_values: torch.LongTensor | None  # size = (B, C, H, W)


class SupervisedBatch(TypedDict, total=True):
    # One collated batch, ready for the model forward pass.
    input_ids: torch.LongTensor  # size = (B, L)
    labels: torch.LongTensor  # size = (B, L)
    attention_mask: torch.BoolTensor  # size = (B, L)
    pixel_values: torch.LongTensor | None  # size = (B, C, H, W)


class SupervisedDataset(Dataset):
    """Supervised (SFT) dataset for Qwen2.5-Omni text+image -> text training.

    Samples are kept as raw strings/objects here; the heavy multi-modal
    tokenization is deferred to :class:`SupervisedCollator`.
    """

    def __init__(
        self,
        path: str,
        template: str,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
        name: str | None = None,
        size: int | None = None,
        split: str | None = None,
        data_files: str | None = None,
        optional_args: list | str = [],
    ):
        """Load the raw dataset and keep tokenization helpers.

        Args:
            path: Hugging Face dataset path or local directory (required).
            template: formatter providing ``format_supervised_sample`` (required).
            tokenizer: tokenizer used only to measure prompt lengths / eos.
            processor: multi-modal processor passed on to the collator.
            padding_side: padding side forwarded to the collator.
            name / split / data_files: ``load_dataset`` selectors; the literal
                string 'None' (from CLI overrides) is treated as missing.
            size: optional cap on the number of samples.
            optional_args: extra positional args for ``load_dataset``.
        """
        super().__init__()
        assert path, f'You must set the valid datasets path! Here is {path}'
        assert template, f'You must set the valid template path! Here is {template}'
        self.path = path
        self.tokenizer = tokenizer
        self.processor = processor
        self.padding_side = padding_side
        self.raw_data = load_dataset(
            path,
            name=name if name and name != 'None' else None,
            split=split if split and split != 'None' else None,
            data_files=data_files if data_files and data_files != 'None' else None,
            *optional_args,
            trust_remote_code=True,
            num_proc=16,
        )
        if size:
            self.raw_data = self.raw_data.select(range(int(size)))
        self.template = template

    def preprocess(self, raw_sample: dict[str, Any]) -> SupervisedSample:
        """Format one raw sample into prompt/conversation text plus metadata."""
        return_dict = {}
        prompt, conversation, meta_info = self.template.format_supervised_sample(raw_sample)
        conversation = conversation[0]
        if not ends_with_any(conversation, self.tokenizer.eos_token):
            conversation += self.tokenizer.eos_token

        # Keep the raw pieces; the collator runs the multi-modal processor.
        return_dict['prompt'] = prompt
        return_dict['conversation'] = conversation
        return_dict['image'] = meta_info['image']
        return_dict['raw_conversation'] = meta_info['raw_conversation']

        # Prompt token count, used by the collator to mask prompt labels.
        return_dict['prompt_lens'] = len(
            self.tokenize(prompt, add_special_tokens=False)['input_ids'][0]
        )

        return return_dict

    def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
        """Return the batch collator bound to this dataset's tokenizer/processor."""
        return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)

    def tokenize(
        self,
        conversation: str,
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation: bool | str | TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
        max_length: int | None = None,
    ) -> torch.LongTensor:  # size = (L,)
        """Tokenize a text string into a tensor representation."""
        if max_length is None:
            max_length = self.tokenizer.model_max_length

        return self.tokenizer(
            text=conversation,
            add_special_tokens=add_special_tokens,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors='pt',
        )

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Get a tokenized data sample by index."""
        raw_sample = self.raw_data[index]
        return self.preprocess(raw_sample.copy())

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return len(self.raw_data)


class SupervisedCollator:
    """Collates raw samples into padded, device-placed multi-modal batches."""

    def __init__(
        self,
        pad_token_id: int,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
    ) -> None:
        """Initialize a collator."""
        self.pad_token_id = pad_token_id
        self.processor = processor
        self.padding_side = padding_side

    def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch:
        """Tokenize, pad and batch a list of pre-processed samples."""
        return_dict = {'meta_info': {}}
        current_device = get_current_device()

        concated_text = [sample['conversation'] for sample in samples]
        concated_raw_conversation = [sample['raw_conversation'] for sample in samples]

        # process_mm_info is the single source of truth for multi-modal inputs.
        # NOTE(review): a previous revision also built `images` from
        # sample['image'] behind a MULTI_IMAGES_INFERENCE_MODELS env switch,
        # but that value was unconditionally overwritten here, so the dead
        # branch has been removed.
        audios, images, videos = process_mm_info(concated_raw_conversation, use_audio_in_video=True)
        return_dict['meta_info']['images'] = images

        multi_modal_padding = self.processor(
            images=images,
            text=concated_text,
            audios=audios,
            videos=videos,
            return_tensors='pt',
            padding=True,
            padding_side=self.padding_side,
            return_attention_mask=True,
        )

        input_ids = multi_modal_padding['input_ids']
        labels = input_ids.clone()

        # Mask the prompt part so the loss is only computed on the response.
        # NOTE(review): prompt_lens was measured on the text-only tokenization;
        # confirm it still lines up after the processor expands image tokens.
        for i in range(len(samples)):
            labels[i, : samples[i]['prompt_lens']] = IGNORE_INDEX

        return_dict.update(multi_modal_padding)
        return_dict['labels'] = labels

        # Move tensors — including nested lists of tensors under
        # 'pixel_values' — onto the current device.
        for key, value in return_dict.items():
            if isinstance(value, torch.Tensor):
                return_dict[key] = value.to(current_device)
            elif key == 'pixel_values':

                def move_to_device(item):
                    if isinstance(item, list):
                        return [move_to_device(sub_item) for sub_item in item]
                    if isinstance(item, torch.Tensor):
                        return item.to(current_device)
                    return item

                return_dict[key] = move_to_device(value)

        return return_dict
add_generation_prompt: bool = False + ) -> dict[str, Any]: + return self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=add_generation_prompt + ) +except: + print("BaichuanM1 is not supported in this version of transformers") \ No newline at end of file diff --git a/align_anything/models/model_registry.py b/align_anything/models/model_registry.py index fe5fe5ea..098d3ad5 100644 --- a/align_anything/models/model_registry.py +++ b/align_anything/models/model_registry.py @@ -94,6 +94,7 @@ def get_model_class_for_trust_remote_code(model_type, model_mapping_names): ('idefics2', 'AccustomedIdefics2Model'), ('gemma3', 'AccustomedGemma3Model'), ('opt', 'AccustomedOPTModel'), + ('qwen2_5_omni', 'AccustomedQwen2_5_OmniModel'), ], ) diff --git a/align_anything/models/qwen2_5_omni.py b/align_anything/models/qwen2_5_omni.py new file mode 100644 index 00000000..6dd63f1a --- /dev/null +++ b/align_anything/models/qwen2_5_omni.py @@ -0,0 +1,50 @@ +# Copyright 2025 PKU-Alignment Team. All Rights Reserved. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
try:

    from transformers import Qwen2_5OmniThinkerForConditionalGeneration

    class AccustomedQwen2_5_OmniModel(Qwen2_5OmniThinkerForConditionalGeneration):
        """Thinker-only wrapper of Qwen2.5-Omni used by the training stack."""

        @property
        def processor_available(self):
            # Signals to the model-loading utilities that an AutoProcessor
            # exists for this architecture.
            return True

        @property
        def embed_tokens(self):
            # Expose the language model's token embedding table for trainers.
            return self.model.embed_tokens

        @classmethod
        def from_pretrained(
            cls,
            pretrained_model_name_or_path,
            *model_args,
            config=None,
            **kwargs,
        ):
            """Load only the thinker sub-model.

            Callers may pass the full Qwen2.5-Omni config, which nests the
            thinker's config under ``thinker_config``; narrow it when present.
            The previous code did ``config.thinker_config`` unconditionally,
            which raised AttributeError whenever ``config`` was None (the
            default) or already a thinker config.
            """
            if config is not None:
                config = getattr(config, 'thinker_config', config)
            return super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                **kwargs,
            )

except ImportError:
    # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt,
    # SystemExit and genuine bugs inside the class body.
    print(
        'Qwen2_5OmniThinkerForConditionalGeneration is not supported in this version of transformers'
    )
"""Trainer for supervised training."""


import argparse
import os
import sys

import deepspeed
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from align_anything.datasets.qwen_omni.supervised import SupervisedDataset
from align_anything.models.pretrained_model import load_pretrained_models
from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedtextTrainer
from align_anything.utils.device_utils import torch_set_device
from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import (
    custom_cfgs_to_dict,
    dict_to_namedtuple,
    read_cfgs,
    seed_everything,
    update_dict,
)


class SuperviseTrainer(SupervisedtextTrainer):
    """SFT trainer for Qwen2.5-Omni on text+image -> text data."""

    def init_datasets(self) -> None:
        """Initialize training and evaluation dataloaders."""
        loader_factory = (
            self.get_multi_dataloaders
            if self.cfgs.data_cfgs.load_multi_datasets
            else self.get_dataloaders
        )
        self.train_dataloader, self.eval_dataloader = loader_factory(
            SupervisedDataset, SupervisedDataset
        )

    def init_models(self) -> None:
        """Initialize model, tokenizer and processor."""
        # Under ZeRO stage 3 the HfDeepSpeedConfig must exist before the model
        # is instantiated so that weights are sharded at load time.
        if self.ds_train_cfgs is not None and self.ds_train_cfgs['zero_optimization']['stage'] == 3:
            self.dstchf = HfDeepSpeedConfig(self.ds_train_cfgs)
        self.model, self.tokenizer, self.processor = load_pretrained_models(
            self.cfgs.model_cfgs.model_name_or_path,
            model_max_length=self.cfgs.model_cfgs.model_max_length,
            padding_side='right',
            trust_remote_code=True,
            freeze_mm_proj=self.cfgs.train_cfgs.freeze_mm_proj,
            freeze_vision_tower=self.cfgs.train_cfgs.freeze_vision_tower,
            freeze_language_model=self.cfgs.train_cfgs.freeze_language_model,
            processor_kwargs=self.cfgs.train_cfgs.processor_kwargs,
            modality=['image'],
        )
        self.tokenizer.model_max_length = self.cfgs.model_cfgs.model_max_length


def main():
    """Entry point: set up distributed training, merge configs, run SFT."""
    deepspeed.init_distributed()
    device = get_current_device()
    torch_set_device(device)

    # Defaults come from the text_image_to_text/sft YAML; CLI flags override.
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=os.path.join('text_image_to_text', 'sft'))

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed_args = parser.parse_known_args()
    # unparsed_args[0] is the launcher-injected argument; the remainder
    # alternates '--key value' pairs.
    # NOTE(review): assumes every override is a two-token pair — verify
    # against the deepspeed launcher's argument convention.
    overrides = dict(zip((flag[2:] for flag in unparsed_args[1::2]), unparsed_args[2::2]))
    for key, value in overrides.items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(key, value))

    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    trainer = SuperviseTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()


if __name__ == '__main__':
    sys.exit(main())
#!/usr/bin/env bash
#
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Supervised fine-tuning of Qwen2.5-Omni on text+image -> text data.

# Model to fine-tune.
MODEL_NAME_OR_PATH="Qwen/Qwen2.5-Omni-7B"

# Training data: dataset path, conversation template, and split.
TRAIN_DATASETS="PKU-Alignment/Align-Anything-TI2T-Instruction-100K"
TRAIN_TEMPLATE="Qwen_Omni_TI2T"
TRAIN_SPLIT="train"

# Where checkpoints and logs are written.
OUTPUT_DIR="../output/qwen_omni_sft"

# For wandb online logging (leave empty for anonymous/offline use).
export WANDB_API_KEY=""

# Source the setup script (expected to define MASTER_PORT, among others).
source ./setup.sh

# Launch distributed training via deepspeed.
deepspeed \
  --master_port ${MASTER_PORT} \
  --module align_anything.trainers.qwen_omni.ti2t_sft \
  --model_name_or_path ${MODEL_NAME_OR_PATH} \
  --train_datasets ${TRAIN_DATASETS} \
  --train_template ${TRAIN_TEMPLATE} \
  --train_split ${TRAIN_SPLIT} \
  --output_dir ${OUTPUT_DIR} \
  --save_total_limit 2 \
  --per_device_train_batch_size 1 \
  --train_size 10 \
  --epochs 3