
Commit 1f410bb

Merge pull request #225 from nuaalixu/for-merge-aispeech-asr
Add: an example, aispeech_asr, and a dataset, speech_dataset_large, with support for multi-machine multi-GPU decoding
2 parents aa9ac13 + 990709a commit 1f410bb

25 files changed, +1882 -57 lines changed

examples/aispeech_asr/README.md

Lines changed: 152 additions & 0 deletions
# AISPEECH_ASR

## Overview

This example is designed for large-scale industrial data training, suitable for datasets on the order of 100,000 hours. Its main features include:
- **Support for multi-task training**: Designed to support tasks such as ASR and ST through a unified data format.
- **Dynamic prompt selection**: Supports random selection from multiple prompts.
- **Iterative dataset**: Uses an iterative dataset format to reduce startup time for large datasets.
- **DeepSpeed training**: Supports DeepSpeed training to significantly reduce memory usage.
- **Multi-machine multi-GPU inference**: Supports distributed inference across multiple machines and GPUs to reduce evaluation time.
- **Dynamic frame batching**: Dynamically combines frames based on audio size rather than using a fixed batch size, significantly reducing training and evaluation time (reduces training time by 3/4 for 100,000 hours of data); a sketch follows at the end of this overview.

This example is modified from `mala_asr_slidespeech`.
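The dynamic frame batching idea can be sketched as follows. This is a simplified illustration under assumed names, not the actual implementation in `speech_dataset_large.py`: utterances are accumulated until a frame budget is reached, instead of counting a fixed number of utterances per batch.

```python
from typing import Iterable, Iterator, List, Tuple

def dynamic_frame_batches(
    utterances: Iterable[Tuple[str, int]],  # (key, num_frames) pairs
    max_frame_length: int = 1500,           # e.g. train_max_frame_length
) -> Iterator[List[str]]:
    """Group utterance keys so each batch stays under a total frame budget."""
    batch: List[str] = []
    batch_frames = 0
    for key, num_frames in utterances:
        # Emit the current batch if adding this utterance would exceed the budget.
        if batch and batch_frames + num_frames > max_frame_length:
            yield batch
            batch, batch_frames = [], 0
        batch.append(key)
        batch_frames += num_frames
    if batch:
        yield batch
```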
## Model Architecture

The model architecture can be selected dynamically from the options supported by SLAM-LLM. Below are some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B
## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`
### multitask.jsonl

The format of this file is as follows, where `path` supports both ark format and wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```
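For reference, here is a minimal sketch of how such an entry could be read and its audio loaded. It assumes `kaldiio` for ark offsets and `soundfile` for plain wav files; the actual loading logic lives in `speech_dataset_large.py` and may differ:

```python
import json

import kaldiio          # assumed dependency for reading "file.ark:offset" entries
import soundfile as sf  # assumed dependency for reading plain wav files

def load_entry(line: str):
    """Parse one multitask.jsonl line and load its audio (illustrative only)."""
    entry = json.loads(line)
    path = entry["path"]
    if ".ark:" in path:
        # kaldiio resolves the byte offset inside the archive; for wav archives
        # it may return a (sample_rate, samples) tuple.
        audio = kaldiio.load_mat(path)
    else:
        audio, sample_rate = sf.read(path)
    return entry["key"], entry["task"], entry.get("target"), audio
```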
### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```
### Notes

- If multiple prompts are provided for a task, one will be selected dynamically.
- For additional information (e.g., hotwords), include a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject it. Additionally, add the task to `append_info_tasks` in the `aispeech_config` file (a prompt-construction sketch follows this list):
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```
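To make the prompt mechanism concrete, here is a minimal sketch of how a prompt could be picked and the extra task information injected. The function and variable names are assumptions for illustration, not the actual code in `speech_dataset_large.py`:

```python
import json
import random
from typing import Dict, List, Sequence

def build_prompt(
    sample: Dict,
    prompts_by_task: Dict[str, List[str]],
    prompt_style: str = "{}",
    append_info_tasks: Sequence[str] = ("hotword",),
) -> str:
    """Pick one prompt for the sample's task and inject extra info if needed."""
    # One prompt is chosen at random when several are registered for the task.
    prompt = random.choice(prompts_by_task[sample["task"]])
    # For tasks listed in append_info_tasks, the field named after the task
    # (e.g. sample["hotword"]) is substituted into the "{}" slot of the prompt.
    if sample["task"] in append_info_tasks:
        prompt = prompt.format(sample[sample["task"]])
    # prompt_style wraps the prompt in the LLM's chat template, e.g.
    # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".
    return prompt_style.format(prompt)

# Usage with the example files above:
prompts_by_task: Dict[str, List[str]] = {}
with open("multiprompt.jsonl") as f:
    for line in f:
        item = json.loads(line)
        prompts_by_task.setdefault(item["task"], []).append(item["prompt"])

with open("multitask.jsonl") as f:
    sample = json.loads(f.readline())
    print(build_prompt(sample, prompts_by_task))
```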
## Training a New Model

### Script Preparation

Prepare and modify the following content in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir= # Directory to save the model
train_scp_file_path= # Path to training data
dev_scp_file_path= # Path to validation data
train_max_frame_length=1500 # Maximum frame length for training
eval_max_frame_length=1000 # Maximum frame length for evaluation
multitask_prompt_path= # Path to multitask.jsonl
prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear # Type of projector
encoder_name=whisper # Name of the encoder
llm_name=Qwen2.5-7B-Instruct # Name of the LLM
use_peft=false # Whether to use PEFT (for the LLM)
use_fp16=true # Whether to use FP16
freeze_encoder=true # Whether to freeze the encoder
pad_or_trim=true # Whether to use pad_or_trim (for Whisper)
deepspeed_config= # Path to the DeepSpeed configuration file
```
Typically, we first train the projector and then fine-tune with LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set (with `ckpt_path` pointing to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path= # For DDP training, provide the path to the saved pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```
### DeepSpeed

When training a 7B model with `bf16`/`fp16`, DeepSpeed saves about 20 GB of GPU memory compared to `torchrun`. For 7B models, `zero-0`/`1`/`2` is recommended; for extremely large models, `zero-3` can be used, though communication may become a bottleneck.

```json
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```
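For orientation, this JSON is the configuration handed to `deepspeed.initialize`, which wraps the model in the DeepSpeed engine. A minimal sketch of that call (the actual wiring lives in the SLAM-LLM training code):

```python
import torch
import deepspeed

# Illustrative only: wrap a model with the DeepSpeed engine using the JSON
# config above. A tiny linear layer stands in for the real projector/LLM stack.
model = torch.nn.Linear(1280, 4096)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json",  # path to the config shown above
)
```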
Note that when using `zero-0`/`1`/`2`, DeepSpeed saves the model in a format that requires conversion: turn `mp_rank_00_model_states.pt` into `model.pt` with a script, e.g. `python scripts/transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`.

```
global_step1000
global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
...
global_step1000/mp_rank_00_model_states.pt
latest
zero_to_fp32.py
```
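If you need to do this conversion by hand, the core of it is loading the DeepSpeed checkpoint and saving the model weights it contains. The following is a rough sketch of what such a conversion might look like; the `"module"` key is an assumption about the checkpoint layout and may differ from what `scripts/transcribe_deepspeed_to_pt.py` actually does:

```python
import sys

import torch

# Usage (illustrative): python convert_ds_ckpt.py mp_rank_00_model_states.pt output_dir
ckpt_path, output_dir = sys.argv[1], sys.argv[2]
state = torch.load(ckpt_path, map_location="cpu")
# DeepSpeed zero-0/1/2 checkpoints typically keep the model state dict under
# the "module" key; fall back to the whole dict if that key is absent.
model_state = state.get("module", state)
torch.save(model_state, f"{output_dir}/model.pt")
```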
If training with `zero-3`, the model is saved in a different format and can be converted using the bundled script, e.g. `python zero_to_fp32.py global_step50 output_dir`.

```
global_step50
global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
...
latest
zero_to_fp32.py
```
If you use bf16/fp16 training in DeepSpeed and encounter NaN in the train/eval loss, check the autocast call in `src/slam_llm/utils/deepspeed_utils.py` and pass the dtype explicitly:

```python
with autocast()                        # original code
with autocast(dtype=torch.bfloat16)    # explicit bf16 autocast; should work
with autocast(dtype=torch.float16)     # explicit fp16 autocast
```
## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`
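Conceptually, multi-GPU decoding amounts to sharding the test manifest across ranks so each GPU decodes a disjoint subset, with the per-rank outputs merged afterwards. A minimal illustration of the sharding step (an assumption for illustration, not the actual logic in `decode_deepspeed.sh`):

```python
import json
import os
from typing import Dict, List

def shard_manifest(jsonl_path: str) -> List[Dict]:
    """Return the subset of samples this rank should decode (illustrative only)."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    with open(jsonl_path) as f:
        samples = [json.loads(line) for line in f]
    # Each rank takes every world_size-th sample; results are written to a
    # per-rank file and concatenated once all ranks finish.
    return samples[rank::world_size]
```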
## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training can be supported with minor modifications to the `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh` scripts. Due to environment-specific requirements, this example does not include dedicated scripts for multi-machine multi-GPU setups.

examples/aispeech_asr/README_zh.md

Lines changed: 99 additions & 0 deletions
# AISPEECH_ASR

## Overview

This example is prepared for large-scale industrial data training and is suitable for data on the order of 100,000 hours. Its main features are:
- **Multi-task training support**: A unified data format supports multiple tasks, including ASR and ST.
- **Dynamic prompt selection**: One prompt is selected at random from multiple prompts.
- **Iterative dataset**: An iterative dataset is used to reduce startup time for large amounts of data.
- **DeepSpeed training**: DeepSpeed training is supported, significantly reducing memory usage.
- **Multi-machine multi-GPU inference**: Distributed inference across multiple machines and GPUs reduces evaluation time.
- **Dynamic frame batching**: Frames are combined dynamically according to the size of each audio clip instead of using a fixed batch_size, greatly reducing training and evaluation time (on 100,000 hours of data, training time is reduced by 3/4).

This example is modified from `mala_asr_slidespeech`.

## Model Architecture

The model architecture can be selected as needed from the options supported by SLAM-LLM. Below are some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B

## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`

### multitask.jsonl

The format of this file is as follows, where `path` supports both ark format and wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```

### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```

### Notes

- If multiple prompts are provided, one of them is selected dynamically.
- For additional information (such as hotwords), provide a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject it. Also update `append_info_tasks` in the `aispeech_config` file:
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```

## Training a New Model

### Script Preparation

Prepare and modify the following content in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir= # Directory to save the model
train_scp_file_path= # Path to training data
dev_scp_file_path= # Path to validation data
train_max_frame_length=1500 # Maximum frame length during training
eval_max_frame_length=1000 # Maximum frame length during evaluation
multitask_prompt_path= # Path to multitask.jsonl
prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear # Projector type
encoder_name=whisper # Encoder name
llm_name=Qwen2.5-7B-Instruct # LLM name
use_peft=false # Whether to use PEFT (for the LLM)
use_fp16=true # Whether to use FP16
freeze_encoder=true # Whether to freeze the encoder
pad_or_trim=true # Whether to use pad_or_trim (for Whisper)
deepspeed_config= # Path to the DeepSpeed configuration file
```

Typically, we first train the projector and then train the LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set the following (`ckpt_path` is the path to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path= # For DDP training, write the path to the saved pt file directly; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```

## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`

## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training can be supported with simple modifications to `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh`. Because the required changes differ between environments, this example does not provide dedicated multi-machine multi-GPU scripts.
Lines changed: 141 additions & 0 deletions
1+
from dataclasses import dataclass, field
2+
from typing import Optional, List
3+
from torch.distributed.fsdp import ShardingStrategy
4+
5+
6+
@dataclass
7+
class ModelConfig:
8+
file: str = "examples/aispeech_asr/model/aispeech_asr.py:model_factory"
9+
llm_name: str = "vicuna-7b-v1.5"
10+
llm_path: str = "PATH/to/LLAMA/7B"
11+
llm_type: str = "decoder_only"
12+
llm_dim: int = 4096
13+
whisper_decode : Optional[bool] = False
14+
encoder_name: Optional[str] = None
15+
encoder_ds_rate: int = 2
16+
encoder_path: Optional[str] = None
17+
encoder_path_hf: Optional[str] = None
18+
encoder_dim: int = 1280
19+
encoder_projector: str = "linear"
20+
qformer_layers : int = 8
21+
encoder_projector_ds_rate: int = 5
22+
modal: str = "audio"
23+
normalize: Optional[bool] = field(default=False, metadata={
24+
"help": "whether input is normalized, used for models such as wavlm"
25+
})
26+
encoder_type: str = field(default="finetune", metadata={
27+
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
28+
})
29+
30+
31+
@dataclass
32+
class PeftConfig:
33+
peft_method: str = "lora" # None , llama_adapter, prefix
34+
r: int = 64
35+
lora_alpha: int = 16
36+
target_modules: List = field(default_factory=lambda: [ "q_proj","k_proj", "v_proj", "o_proj", "up_proj","gate_proj","down_proj"])
37+
bias: str = "none"
38+
task_type: str = "CAUSAL_LM"
39+
lora_dropout: float = 0.05
40+
inference_mode: bool = False
41+
42+
@dataclass
43+
class TrainConfig:
44+
model_name:str = "PATH/to/LLAMA/7B"
45+
enable_ddp:bool = False
46+
enable_deepspeed:bool = False
47+
enable_fsdp:bool = False
48+
low_cpu_fsdp:bool = False
49+
run_validation:bool = True
50+
batch_size_training: Optional[int] = None
51+
batching_strategy:str = field(default="packing", metadata={
52+
"help":"alternative: padding"
53+
}) #
54+
context_length:int = 4096
55+
gradient_accumulation_steps:int = 1
56+
num_epochs:int = 3
57+
num_workers_dataloader:int = 1
58+
warmup_steps:int = 1000
59+
total_steps:int = 100000
60+
validation_interval:int = 1000
61+
lr:float = 1e-4
62+
weight_decay:float = 0.0
63+
gamma:float = 0.85
64+
seed:int = 42
65+
use_fp16:bool = False
66+
mixed_precision:bool = True
67+
val_batch_size:Optional[int] = None
68+
69+
use_peft:bool = False
70+
peft_config:PeftConfig = field(default_factory=PeftConfig)
71+
output_dir:str = "PATH/to/save/PEFT/model"
72+
freeze_layers:bool = False
73+
num_freeze_layers:int = 1
74+
quantization:bool = False
75+
one_gpu:bool = False
76+
save_model:bool = True
77+
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
78+
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
79+
save_optimizer:bool = False # will be used if using FSDP
80+
use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
81+
run_test_during_validation:bool = False
82+
run_test_during_validation_file:str = "test.wav"
83+
run_test_during_validation_prompt:str = "<|ASR|>"
84+
freeze_llm:bool = field(default=False, metadata={
85+
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
86+
})
87+
freeze_encoder:bool = False
88+
89+
@dataclass
90+
class DataConfig:
91+
dataset: str = "multitask_dataset"
92+
train_max_frame_length: int = 1500
93+
eval_max_frame_length: int = 1000
94+
multitask_prompt_path: str = "/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
95+
prompt_style: str = "\{\}" #
96+
append_info_tasks : List = field(default_factory=lambda: [ "hotword"])
97+
file: str = "examples/aispeech_asr/slam_llm/datasets/speech_dataset_large.py:get_speech_dataset"
98+
train_scp_file_path: str = ""
99+
dev_scp_file_path: str = ""
100+
test_scp_file_path: str = ""
101+
train_split: str = "train"
102+
dev_split: str = "dev"
103+
test_split:str = "test"
104+
pad_or_trim: bool = True
105+
prompt: Optional[str] = None
106+
use_ocr: bool = True
107+
inference_mode: bool = False
108+
lower: bool = False
109+
fix_length_audio: int = -1
110+
inference_mode:bool = False
111+
input_type: str = field(default="raw", metadata={
112+
"help":"Use raw when input is wav, mel when for whisper"
113+
})
114+
mel_size: int = field(default=80, metadata={
115+
"help": "80 for whisper large v1 and v2, 128 for v3"
116+
})
117+
normalize: Optional[bool] = field(default=False, metadata={
118+
"help": "whether input is normalized, used for models such as wavlm"
119+
})
120+
121+
@dataclass
122+
class FSDPConfig:
123+
mixed_precision: bool = True
124+
use_fp16: bool = False
125+
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
126+
sharding_strategy: ShardingStrategy = "SHARD_GRAD_OP" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
127+
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
128+
fsdp_activation_checkpointing: bool = True
129+
fsdp_cpu_offload: bool = False
130+
pure_bf16: bool = False
131+
optimizer: str = "AdamW"
132+
133+
@dataclass
134+
class LogConfig:
135+
use_wandb: bool = False
136+
wandb_dir: str = "tmp/test_wandb"
137+
wandb_entity_name: str = "project_name"
138+
wandb_project_name: str = "project_name"
139+
wandb_exp_name: str = "exp_name"
140+
log_file: str = "tmp/test.log"
141+
log_interval: int = 5
