Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions align_anything/configs/format_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,31 @@ def format_supervised_sample(
], {}


@register_template('Qwen_Omni_TI2T')
class Qwen_Omni_TI2T(BaseFormatter):
    """Chat formatter for Qwen-Omni text+image-to-text supervised samples."""

    system_prompt: str = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'

    def format_supervised_sample(
        self, raw_sample: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Return (conversation turns, meta_info) for one raw sample.

        Expects ``raw_sample`` to carry 'prompt', 'response' and a PIL 'image'.
        """
        question = raw_sample['prompt']
        reply = raw_sample['response']
        # NOTE(review): RGBA conversion — presumably the downstream processor
        # accepts 4-channel images; confirm against the Qwen-Omni processor.
        rgba_image = raw_sample['image'].convert('RGBA')

        system_turn = {'role': 'system', 'content': self.system_prompt}
        user_turn = {
            'role': 'user',
            'content': [
                {'type': 'image', 'image': rgba_image},
                {'type': 'text', 'text': question},
            ],
        }
        assistant_turn = {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': reply}],
        }
        return [system_turn, user_turn, assistant_turn], {'image': rgba_image}



@register_template('AA_TI2T')
class AA_TI2T(BaseFormatter):
system_prompt: str = ''
Expand Down
1 change: 1 addition & 0 deletions align_anything/configs/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def format_supervised_sample(self, raw_sample: dict[str, Any]) -> tuple[str, str
raw_sample
)
raw_prompt = raw_conversation[:-1]
multi_modal_info['raw_conversation'] = raw_conversation
return (
self.model_formatter(raw_prompt),
self.model_formatter(raw_conversation),
Expand Down
2 changes: 1 addition & 1 deletion align_anything/configs/train/text_image_to_text/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ train_cfgs:
# Batch size per device for evaluation
per_device_eval_batch_size: 1
# The number of gradient accumulation steps
gradient_accumulation_steps: 16
gradient_accumulation_steps: 1
# Whether to use gradient checkpointing
gradient_checkpointing: True
# Initial learning rate
Expand Down
208 changes: 208 additions & 0 deletions align_anything/datasets/qwen_omni/supervised.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import os
from typing import Any, Callable
from typing_extensions import TypedDict # Python 3.10+

import torch
import transformers
from torch.utils.data import Dataset
from torchvision import transforms
from transformers.tokenization_utils import PaddingStrategy, TruncationStrategy

from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import convert_to_rgb, ends_with_any
from datasets import load_dataset
from qwen_omni_utils import process_mm_info


# Label value ignored by torch.nn.CrossEntropyLoss; used to mask prompt tokens
# so the loss is only computed over the assistant's response.
IGNORE_INDEX = -100

# Public API of this module.
__all__ = [
    'SupervisedDataset',
    'SupervisedCollator',
    'SupervisedSample',
    'SupervisedBatch',
]


class SupervisedSample(TypedDict, total=True):
    """Type of a single preprocessed sample.

    NOTE(review): ``SupervisedDataset.preprocess`` actually returns the keys
    'prompt', 'conversation', 'image', 'raw_conversation' and 'prompt_lens';
    this declaration looks stale — confirm and reconcile with ``preprocess``.
    """

    input_ids: torch.LongTensor  # size = (L,)
    labels: torch.LongTensor  # size = (L,)
    pixel_values: torch.LongTensor | None  # size = (B, C, H, W) — presumably batched by the processor; verify


class SupervisedBatch(TypedDict, total=True):
    """Type of a collated batch produced by ``SupervisedCollator``."""

    input_ids: torch.LongTensor  # size = (B, L)
    labels: torch.LongTensor  # size = (B, L); prompt positions set to IGNORE_INDEX
    attention_mask: torch.BoolTensor  # size = (B, L)
    pixel_values: torch.LongTensor | None  # size = (B, C, H, W)


class SupervisedDataset(Dataset):
    """Supervised fine-tuning dataset for Qwen-Omni style multimodal samples.

    Loads a HuggingFace dataset from ``path`` and formats each raw sample
    through the configured chat ``template``. Heavy tokenization and
    multimodal feature extraction are deferred to the collator.
    """

    def __init__(
        self,
        path: str,
        template: str,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
        name: str | None = None,
        size: int | None = None,
        split: str | None = None,
        data_files: str | None = None,
        optional_args: list | str | None = None,
    ):
        """Initialize the dataset.

        Args:
            path: Dataset path or hub id passed to ``load_dataset``.
            template: Formatter object providing ``format_supervised_sample``.
            tokenizer: Tokenizer used for prompt-length computation.
            processor: Optional multimodal processor forwarded to the collator.
            padding_side: 'left' or 'right' padding for the collator.
            name / split / data_files: Optional ``load_dataset`` selectors;
                the literal string 'None' is treated as unset.
            size: Optional cap on the number of samples.
            optional_args: Extra positional args for ``load_dataset``.
        """
        super().__init__()
        assert path, f'You must set the valid datasets path! Here is {path}'
        assert template, f'You must set the valid template path! Here is {template}'
        self.path = path
        self.tokenizer = tokenizer
        self.processor = processor
        self.padding_side = padding_side
        # Fix: avoid the mutable-default-argument pitfall (previously `= []`).
        # The public default behavior ("no extra args") is unchanged.
        optional_args = optional_args or []
        self.raw_data = load_dataset(
            path,
            *optional_args,
            name=name if name and name != 'None' else None,
            split=split if split and split != 'None' else None,
            data_files=data_files if data_files and data_files != 'None' else None,
            trust_remote_code=True,
            num_proc=16,
        )
        if size:
            self.raw_data = self.raw_data.select(range(int(size)))
        self.template = template

    def preprocess(self, raw_sample: dict[str, Any]) -> SupervisedSample:
        """Format one raw sample into text plus metadata for the collator."""
        return_dict = {}
        prompt, conversation, meta_info = self.template.format_supervised_sample(raw_sample)
        # The model formatter returns a batch of size one; unwrap it.
        conversation = conversation[0]
        if not ends_with_any(conversation, self.tokenizer.eos_token):
            conversation += self.tokenizer.eos_token

        # Return necessary information.
        return_dict['prompt'] = prompt
        return_dict['conversation'] = conversation
        # NOTE(review): assumes the template's meta_info always contains
        # 'image' and 'raw_conversation' (true for Qwen_Omni_TI2T) — a
        # template without them would raise KeyError here.
        return_dict['image'] = meta_info['image']
        return_dict['raw_conversation'] = meta_info['raw_conversation']

        # Prompt length in tokens, used by the collator to mask the labels.
        return_dict['prompt_lens'] = len(
            self.tokenize(prompt, add_special_tokens=False)['input_ids'][0]
        )

        return return_dict

    def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
        """Build the collator matching this dataset's tokenizer/processor."""
        return SupervisedCollator(self.tokenizer.pad_token_id, self.processor, self.padding_side)

    def tokenize(
        self,
        conversation: str,
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation: bool | str | TruncationStrategy = TruncationStrategy.LONGEST_FIRST,
        max_length: int | None = None,
    ) -> torch.LongTensor:  # size = (L,)
        """Tokenize a text string into a tensor representation."""
        if max_length is None:
            max_length = self.tokenizer.model_max_length

        return self.tokenizer(
            text=conversation,
            add_special_tokens=add_special_tokens,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_tensors='pt',
        )

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Get a tokenized data sample by index."""
        raw_sample = self.raw_data[index]
        data = self.preprocess(raw_sample.copy())
        return data

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return len(self.raw_data)


class SupervisedCollator:
    """Collate preprocessed samples into a padded multimodal training batch."""

    def __init__(
        self,
        pad_token_id: int,
        processor: transformers.ProcessorMixin | transforms.Compose | None = None,
        padding_side: str = 'right',
    ) -> None:
        """Initialize a collator.

        Args:
            pad_token_id: Tokenizer pad token id (kept for interface parity).
            processor: Multimodal processor used to pad/encode the batch.
            padding_side: 'left' or 'right' padding.
        """
        self.pad_token_id = pad_token_id
        self.processor = processor
        self.padding_side = padding_side

    def __call__(self, samples: list[SupervisedSample]) -> SupervisedBatch:
        """Pad and encode a list of samples; move tensors to the current device."""
        return_dict = {'meta_info': {}}
        current_device = get_current_device()

        concated_text = [sample['conversation'] for sample in samples]
        concated_raw_conversation = [sample['raw_conversation'] for sample in samples]

        # Fix: a previous revision also built `images` from the samples here
        # (with a MULTI_IMAGES_INFERENCE_MODELS env switch), but that value was
        # unconditionally overwritten by process_mm_info below, so the dead
        # computation has been removed.
        audios, images, videos = process_mm_info(concated_raw_conversation, use_audio_in_video=True)
        return_dict['meta_info']['images'] = images

        multi_modal_padding = self.processor(
            images=images,
            text=concated_text,
            audios=audios,
            videos=videos,
            return_tensors='pt',
            padding=True,
            padding_side=self.padding_side,
            return_attention_mask=True,
        )

        input_ids = multi_modal_padding['input_ids']
        labels = input_ids.clone()

        # Mask the prompt portion so the loss only covers the response.
        # NOTE(review): prompt_lens comes from text-only tokenization; if the
        # processor expands image/audio placeholders into multiple tokens this
        # mask may be misaligned — confirm against the processor's behavior.
        for i in range(len(samples)):
            prompt_lens = samples[i]['prompt_lens']
            labels[i, :prompt_lens] = IGNORE_INDEX

        return_dict.update(multi_modal_padding)
        return_dict['labels'] = labels
        for key, value in return_dict.items():
            if isinstance(value, torch.Tensor):
                return_dict[key] = value.to(current_device)
            elif key == 'pixel_values':
                # pixel_values may be an arbitrarily nested list of tensors.

                def move_to_device(item):
                    if isinstance(item, list):
                        return [move_to_device(sub_item) for sub_item in item]
                    elif isinstance(item, torch.Tensor):
                        return item.to(current_device)
                    return item

                return_dict[key] = move_to_device(value)

        return return_dict
33 changes: 18 additions & 15 deletions align_anything/models/baichuan_m1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,26 @@
from transformers import AutoConfig, AutoTokenizer
from transformers.dynamic_module_utils import get_class_from_dynamic_module

try:

    MODEL_NAME_OR_PATH = os.environ.get('MODEL_NAME_OR_PATH', 'baichuan-inc/Baichuan-M1-14B-Instruct')
    CONFIG = AutoConfig.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
    CLASS_REF = CONFIG.auto_map['AutoModelForCausalLM']
    # Resolve the model class from the hub repo's dynamic module.
    BaichuanM1 = get_class_from_dynamic_module(CLASS_REF, MODEL_NAME_OR_PATH)

    class AccustomedBaichuanM1(BaichuanM1):
        """Baichuan-M1 causal LM with a bundled tokenizer and chat-template helper."""

        def __init__(self, config: AutoConfig):
            super().__init__(config)
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
            self.system_prompt = ''

        def apply_chat_template(
            self, messages: list[dict[str, Any]], add_generation_prompt: bool = False
        ) -> dict[str, Any]:
            """Render chat messages to a prompt string via the tokenizer's template."""
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=add_generation_prompt
            )

# Fix: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
except Exception:
    print("BaichuanM1 is not supported in this version of transformers")
1 change: 1 addition & 0 deletions align_anything/models/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def get_model_class_for_trust_remote_code(model_type, model_mapping_names):
('idefics2', 'AccustomedIdefics2Model'),
('gemma3', 'AccustomedGemma3Model'),
('opt', 'AccustomedOPTModel'),
('qwen2_5_omni', 'AccustomedQwen2_5_OmniModel'),
],
)

Expand Down
50 changes: 50 additions & 0 deletions align_anything/models/qwen2_5_omni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

try:

    from transformers import Qwen2_5OmniThinkerForConditionalGeneration

    class AccustomedQwen2_5_OmniModel(Qwen2_5OmniThinkerForConditionalGeneration):
        """Qwen2.5-Omni thinker model adapted to the align-anything interface."""

        @property
        def processor_available(self):
            # This model family ships a multimodal processor.
            return True

        @property
        def embed_tokens(self):
            # Expose the underlying embedding layer for embedding-level ops.
            return self.model.embed_tokens

        @classmethod
        def from_pretrained(
            cls,
            pretrained_model_name_or_path,
            *model_args,
            config=None,
            **kwargs,
        ):
            """Load the thinker sub-model, unwrapping a full Omni config if given.

            Fix: the previous version did `config = config.thinker_config`
            unconditionally, which raised AttributeError when called with the
            default ``config=None``; None is now passed through so transformers
            resolves the config itself.
            """
            if config is not None:
                config = config.thinker_config
            return super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                **kwargs,
            )

# Fix: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
except Exception:
    print("Qwen2_5OmniThinkerForConditionalGeneration is not supported in this version of transformers")
Loading