## 环境准备
``` shell
# 避免未来出现与文档的不兼容情况
- pip install "ms-swift>=3.10.2,<3.11"
+ pip install "ms-swift>=4.0"

pip install "transformers==4.57.*" "qwen_omni_utils==0.0.8"
```
@@ -17,15 +17,14 @@ pip install "transformers==4.57.*" "qwen_omni_utils==0.0.8"
1717第一步,我们需要注册模型,来获取模型和processor。
1818
1919``` python
20- from swift.llm import (
21- register_model, ModelMeta, ModelGroup, Model, register_model_arch, MultiModelKeys,
22- get_model_tokenizer_with_flash_attn, get_model_tokenizer
23- )
24- from swift.llm.model.model.qwen import patch_qwen_vl_utils
25- from swift.llm.model.utils import use_submodel_func
26- from swift.llm.model.patcher import patch_get_input_embeddings
27- from swift.utils import get_env_args
20+ from transformers import PretrainedConfig, PreTrainedModel
2821
22+ from swift.model import (Model, ModelGroup, ModelMeta, MultiModelKeys, get_model_processor, register_model,
23+ register_model_arch, ModelLoader)
24+ from swift.model.models.qwen import patch_qwen_vl_utils
25+ from swift.model.patcher import patch_get_input_embeddings
26+ from swift.model.utils import use_submodel_func
27+ from swift.utils import get_env_args, Processor
2928
3029register_model_arch(
3130 MultiModelKeys(
@@ -41,33 +40,44 @@ register_model_arch(
4140 generator = [' talker' , ' token2wav' ],
4241 ))
4342
44- def get_model_tokenizer_qwen2_5_omni (model_dir , * args , ** kwargs ):
45- from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniConfig
46- from qwen_omni_utils import vision_process
47- print (' Run my_qwen2_5_omni...' )
48- kwargs[' automodel_class' ] = kwargs[' automodel_class' ] or Qwen2_5OmniForConditionalGeneration
49- # 自定义`get_model_tokenizer_with_flash_attn`中获取tokenizer和config的方式
50- processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code = True )
51- kwargs[' tokenizer' ] = processor.tokenizer
52- kwargs[' model_config' ] = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code = True )
53- enable_audio_output = get_env_args(' ENABLE_AUDIO_OUTPUT' , bool , None )
54- if enable_audio_output is not None :
55- kwargs[' model_config' ].enable_audio_output = enable_audio_output
56- # 可以通过环境变量来控制qwen_omni_utils库中的常量,例如:`MAX_PIXELS`等
57- patch_qwen_vl_utils(vision_process)
58- # 请尽量使用该函数来获取model和tokenizer。而避免直接使用AutoModelForCausalLM(会产生不兼容问题)。
59- model, _ = get_model_tokenizer_with_flash_attn(model_dir, * args, ** kwargs)
60- if model:
61- # 为了多模态模型的统一性,我们将模型的forward/generate函数替换为其language_model的forward/generate函数。
62- # 自己处理额外的部分。
43+ class Qwen2_5OmniLoader (ModelLoader ):
44+
45+
46+ def get_config (self , model_dir : str ) -> PretrainedConfig:
47+ from transformers import Qwen2_5OmniConfig
48+ config = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code = True )
49+ enable_audio_output = get_env_args(' ENABLE_AUDIO_OUTPUT' , bool , None )
50+ if enable_audio_output is not None :
51+ config.enable_audio_output = enable_audio_output
52+ return config
53+
54+ def get_processor (self , model_dir : str , config : PretrainedConfig) -> Processor:
55+ from transformers import Qwen2_5OmniProcessor
56+ from qwen_omni_utils import vision_process
57+ processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code = True )
58+ # Control constants in qwen_omni_utils library via environment variables,
59+ # e.g., `MAX_PIXELS`, etc.
60+ patch_qwen_vl_utils(vision_process)
61+ return processor
62+
63+ def get_model (self , model_dir : str , config : PretrainedConfig, processor : Processor,
64+ model_kwargs ) -> PreTrainedModel:
65+ from transformers import Qwen2_5OmniForConditionalGeneration
66+ print (' Run my_qwen2_5_omni...' )
67+ self .auto_model_cls = self .auto_model_cls or Qwen2_5OmniForConditionalGeneration
68+ model = super ().get_model(model_dir, config, processor, model_kwargs)
69+ # For multimodal model consistency, we replace the model's forward/generate functions
70+ # with those of its language_model.
71+ # Handle additional parts separately.
6372 use_submodel_func(model, ' thinker' )
64- # 一些对model/config的自定义(通常不需要设置,若训练/推理中出现报错,则根据特定模型进行配置)
73+ # Avoid inplace operations on leaf_variable during training
74+ # (replacing parts of input_embeds with images_embeds)
75+ patch_get_input_embeddings(model.thinker.visual, ' patch_embed' )
76+ # Some custom settings for model/config (usually not needed; configure based on
77+ # specific model if errors occur during training/inference)
6578 model.config.keys_to_ignore_at_inference += [' hidden_states' , ' attention_mask' ]
6679 model.config.talker_config.pad_token_id = None
67- # 避免在训练时对leaf_variable进行inplace操作导致报错(将input_embeds中的部分内容替换为images_embeds的行为)
68- patch_get_input_embeddings(model.thinker.visual, ' patch_embed' )
69- # 最终需要返回model和 processor(多模态)/tokenizer(纯文本)
70- return model, processor
80+ return model
7181
7282
7383register_model(
@@ -79,9 +89,9 @@ register_model(
7989 Model(' Qwen/Qwen2.5-Omni-7B' , ' Qwen/Qwen2.5-Omni-7B' ),
8090 ]),
8191 ],
82- ' my_qwen2_5_omni' ,
8392 # 用来获取model和processor的函数。
84- get_model_tokenizer_qwen2_5_omni,
93+ Qwen2_5OmniLoader,
94+ template = ' my_qwen2_5_omni' ,
8595 is_multimodal = True , # 是否是多模态模型
8696 model_arch = ' my_qwen2_5_omni' , # 通常只为多模态模型设置
8797 # 用于model_type的自动匹配
@@ -96,7 +106,7 @@ register_model(
96106
if __name__ == '__main__':
    # 测试与debug
-     model, processor = get_model_tokenizer('Qwen/Qwen2.5-Omni-7B', model_type='my_qwen2_5_omni')
+     model, processor = get_model_processor('Qwen/Qwen2.5-Omni-7B', model_type='my_qwen2_5_omni')
```
101111
102112## 注册模板
@@ -110,18 +120,16 @@ template的功能如下:
110120
111121
112122``` python
113- from swift.llm import (
114- register_template, Template, get_packed_seq_params, to_float_dtype, TemplateMeta,
115- get_template, get_model_tokenizer
116- )
117- from transformers.integrations import is_deepspeed_zero3_enabled
118- from swift.llm.template.template_inputs import StdTemplateInputs
119- from swift.llm.template.utils import Context, findall
120- from swift.llm.template.vision_utils import load_audio
121- from swift.utils import get_env_args, get_logger, is_deepspeed_enabled
122123from functools import partial
123- from typing import Dict, List, Any, Literal, Optional
124+ from typing import Any, Dict, List, Literal, Optional
125+
124126import torch
127+ from transformers.integrations import is_deepspeed_zero3_enabled
128+ from swift import get_model_processor
129+ from swift.template import StdTemplateInputs, Template, TemplateMeta, get_template, register_template
130+ from swift.template.utils import Context, findall
131+ from swift.template.vision_utils import load_audio
132+ from swift.utils import Processor, get_env_args, get_logger, get_packed_seq_params, is_deepspeed_enabled, to_float_dtype
125133
126134logger = get_logger()
127135
@@ -135,7 +143,7 @@ class Qwen2_5OmniTemplate(Template):
135143 # 并会使用简略方式打印(调用`template.safe_decode`)
136144 placeholder_tokens = [' <|IMAGE|>' , ' <|AUDIO|>' , ' <|VIDEO|>' ]
137145
138- def init_processor (self , processor ) -> None :
146+ def init_processor (self , processor : Processor ) -> None :
139147 """ 在初始化processor时,额外初始化所需的一些常量"""
140148 if processor is None :
141149 return
@@ -416,7 +424,7 @@ class Qwen2_5OmniTemplate(Template):
416424 return res
417425
418426 def generate (self , model , * args , ** kwargs ):
419- """ `PtEngine `会调用template.generate方法进行文本生成,这里继承进行自定义。"""
427+ """ `TransformersEngine `会调用template.generate方法进行文本生成,这里继承进行自定义。"""
420428 if kwargs.get(' video_grid_thw' ) is not None :
421429 kwargs[' use_audio_in_video' ] = self .use_audio_in_video
422430 return super ().generate(model, * args, ** kwargs)
@@ -432,8 +440,8 @@ register_template(
432440
if __name__ == '__main__':
    # 测试与debug
-     model, processor = get_model_tokenizer('Qwen/Qwen2.5-Omni-7B', model_type='my_qwen2_5_omni')
-     template = get_template('my_qwen2_5_omni', processor)
+     model, processor = get_model_processor('Qwen/Qwen2.5-Omni-7B', model_type='my_qwen2_5_omni')
+     template = get_template(processor, template_type='my_qwen2_5_omni')
437445 data = {
438446 ' messages' : [
439447 {' role' : ' user' , ' content' : ' 描述视频<video>与图片<image>内容。' },
@@ -451,14 +459,14 @@ if __name__ == '__main__':
451459
452460
## 推理对齐
- 接下来,你需要进行PtEngine与transformers的推理对齐。通常你需要对齐`input_ids`以及输出内容。你可以书写以下测试函数:
+ 接下来,你需要进行TransformersEngine与transformers的推理对齐。通常你需要对齐`input_ids`以及输出内容。你可以书写以下测试函数:
455463
456464``` python
import os
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from modelscope import snapshot_download
- from swift.llm import PtEngine, InferRequest, RequestConfig
+ from swift.infer_engine import TransformersEngine, InferRequest, RequestConfig
import requests
463471
464472def infer_hf ():
@@ -494,7 +502,7 @@ def infer_hf():
494502 return inputs[' input_ids' ][0 ].tolist(), text[0 ]
495503
496504def test_my_qwen2_5_omni ():
497- engine = PtEngine (' Qwen/Qwen2.5-Omni-7B' , model_type = ' my_qwen2_5_omni' , attn_impl = ' flash_attention_2' )
505+ engine = TransformersEngine (' Qwen/Qwen2.5-Omni-7B' , model_type = ' my_qwen2_5_omni' , attn_impl = ' flash_attention_2' )
498506 infer_request = InferRequest(messages = [{
499507 " role" : " user" ,
500508 " content" : " <video><image>描述视频和图像。" ,
@@ -503,14 +511,14 @@ def test_my_qwen2_5_omni():
503511 images = [" http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png" ],
504512 )
505513 request_config = RequestConfig(temperature = 0 , max_tokens = 512 )
506- input_ids = engine.default_template .encode(infer_request)[' input_ids' ]
514+ input_ids = engine.template .encode(infer_request)[' input_ids' ]
507515 resp_list = engine.infer([infer_request], request_config)
508516 resp = resp_list[0 ].choices[0 ].message.content
509517 return input_ids, resp
510518
511519
512520if __name__ == ' __main__' :
513- # 开启debug模式,会打印`PtEngine .infer`的input_ids和generate_ids
521+ # 开启debug模式,会打印`TransformersEngine .infer`的input_ids和generate_ids
514522 os.environ[' SWIFT_DEBUG' ] = ' 1'
515523 input_ids_hf, response_hf = infer_hf()
516524 input_ids_swift, response_swift = test_my_qwen2_5_omni()
@@ -524,13 +532,13 @@ if __name__ == '__main__':
524532
525533使用python代码训练,这通常更容易debug:
526534``` python
- from swift.llm import sft_main, TrainArguments
+ from swift import sft_main, SftArguments
528536import os
529537if __name__ == ' __main__' :
530538 os.environ[' MAX_PIXELS' ] = ' 1003520'
531- sft_main(TrainArguments (
539+ sft_main(SftArguments (
532540 model = ' Qwen/Qwen2.5-Omni-7B' ,
533- dataset = ' AI-ModelScope/LaTeX_OCR#5000' ,
541+ dataset = [ ' AI-ModelScope/LaTeX_OCR#5000' ] ,
534542 model_type = ' my_qwen2_5_omni' ,
535543 template = ' my_qwen2_5_omni' ,
536544 load_from_cache_file = True ,
@@ -545,7 +553,7 @@ if __name__ == '__main__':
545553 learning_rate = 1e-4 ,
546554 lora_rank = 8 ,
547555 lora_alpha = 32 ,
548- target_modules = ' all-linear' ,
556+ target_modules = [ ' all-linear' ] ,
549557 freeze_vit = True ,
550558 freeze_aligner = True ,
551559 gradient_accumulation_steps = 1 ,
@@ -574,7 +582,7 @@ swift sft \
574582 --model Qwen/Qwen2.5-Omni-7B \
575583 --model_type my_qwen2_5_omni \
576584 --template my_qwen2_5_omni \
577- --custom_register_path ' examples/custom/my_qwen2_5_omni/my_register.py' \
585+ --external_plugins ' examples/custom/my_qwen2_5_omni/my_register.py' \
578586 --dataset ' AI-ModelScope/alpaca-gpt4-data-zh#2000' \
579587 ' AI-ModelScope/LaTeX_OCR:human_handwrite#2000' \
580588 ' speech_asr/speech_asr_aishell1_trainsets:validation#2000' \
@@ -586,7 +594,7 @@ swift sft \
586594 --attn_impl flash_attn \
587595 --padding_free true \
588596 --packing true \
589- --num_train_epochs 1 \
597+ --num_train_epochs 3 \
590598 --per_device_train_batch_size 1 \
591599 --per_device_eval_batch_size 1 \
592600 --learning_rate 1e-4 \
@@ -618,7 +626,7 @@ MAX_PIXELS=1003520 \
618626swift infer \
619627 --adapters output/vx-xxx/checkpoint-xxx \
620628 --stream true \
621- --max_new_tokens 2048 \
629+ --max_new_tokens 512 \
622630 --load_data_args true
623631```
624632