Commit ee6f7f2

🐞 fix(transcribe): fix a bunch of issues with funasr (#22)

* 🐞 fix(transcribe): fix a bunch of issues with funasr
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 4c2acf1 · commit ee6f7f2

File tree: 4 files changed, +60 / -15 lines


README.md (+2 -2)

@@ -13,6 +13,7 @@ This repo contains some scripts for audio processing. Main features include:
 - [x] Audio data statistics (supports determining audio length)
 - [x] Audio resampling
 - [x] Audio transcribe (.lab)
+- [x] Audio transcribe via FunASR (use `--model-type funasr` to enable, detailed usage can be found at code)
 - [ ] Audio transcribe via WhisperX
 
 ([ ] indicates not completed, [x] indicates completed)
@@ -22,11 +23,10 @@ This repo contains some scripts for audio processing. Main features include:
 ## Getting Started:
 
 ```
-pip install -e .
+pip install -e .
 fap --help
 ```
 
 ## Reference
 
 - [Batch Whisper](https://github.com/Blair-Johnson/batch-whisper)
-
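
The new README entry points at `--model-type funasr`. For reference, here is one way to exercise the flag in-process with Click's test runner. This is a sketch, not part of the commit: it assumes `transcribe` is exposed as a Click command that takes the input directory as a positional argument (only `--model-size` and `--model-type` are confirmed by this diff) and that a local `dataset/` directory with audio files exists.

```python
# Minimal sketch (not part of this commit): drive the updated CLI in-process
# with Click's test runner instead of invoking `fap transcribe` from a shell.
# Assumption: the input directory is a positional argument; only --model-size
# and --model-type are confirmed by this diff.
from click.testing import CliRunner

from fish_audio_preprocess.cli.transcribe import transcribe

runner = CliRunner()

# FunASR path: --model-size is omitted on purpose, so the command should fall
# back to "paraformer-zh" (see the ParameterSource check in cli/transcribe.py).
result = runner.invoke(transcribe, ["dataset", "--model-type", "funasr"])
print(result.exit_code)
print(result.output)
```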

README.zh.md (+3 -2)

@@ -13,6 +13,8 @@
 - [x] 音频数据统计(支持判断音频长度)
 - [x] 音频重采样
 - [x] 音频打标 (.lab)
+- [x] 音频打标 FunASR(使用 `--model-type funasr` 开启, 详细使用方法可查看代码)
+- [ ] 音频打标 WhisperX
 
 ([ ] 表示未完成, [x] 表示已完成)
 
@@ -21,11 +23,10 @@
 ## 上手指南:
 
 ```
-pip install -e .
+pip install -e .
 fap --help
 ```
 
 ## 引用
 
 - [Batch Whisper](https://github.com/Blair-Johnson/batch-whisper)
-

fish_audio_preprocess/cli/transcribe.py (+25 -4)

@@ -7,7 +7,7 @@
 from tqdm import tqdm
 
 from fish_audio_preprocess.utils.file import AUDIO_EXTENSIONS, list_files, split_list
-from fish_audio_preprocess.utils.transcribe import batch_transcribe
+from fish_audio_preprocess.utils.transcribe import ASRModelType, batch_transcribe
 
 
 def replace_lastest(string, old, new):
@@ -32,8 +32,9 @@ def replace_lastest(string, old, new):
 )
 @click.option(
     "--model-size",
-    help="whisper model size or funasr",
-    default="tiny",
+    # whisper defaults to medium, funasr defaults to paraformer-zh
+    help="asr model size(default medium for whisper, paraformer-zh for funasr)",
+    default="medium",
     show_default=True,
     type=str,
 )
@@ -42,22 +43,41 @@ def replace_lastest(string, old, new):
     default=False,
     help="Search recursively",
 )
+@click.option(
+    "--model-type",
+    help="ASR model type (funasr or whisper)",
+    default="whisper",
+    show_default=True,
+)
 def transcribe(
     input_dir: str,
     num_workers: int,
     lang: str,
     model_size: str,
     recursive: bool,
+    model_type: ASRModelType,
 ):
     """
     Transcribe audio files in a directory.
     """
+    ctx = click.get_current_context()
+    provided_options = {
+        key: value
+        for key, value in ctx.params.items()
+        if ctx.get_parameter_source(key) == click.core.ParameterSource.COMMANDLINE
+    }
+
+    # If model_type is funasr and no model_size was provided, default to paraformer-zh
+    if model_type == "funasr" and "model_size" not in provided_options:
+        logger.info("Using paraformer-zh model for funasr as default")
+        model_size = "paraformer-zh"
+
     if not torch.cuda.is_available():
         logger.warning(
             "CUDA is not available, using CPU. This will be slow and even this script can not work. "
             "To speed up, use a GPU enabled machine or install torch with cuda builtin."
         )
-
+    logger.info(f"Using {num_workers} workers for processing")
     logger.info(f"Transcribing audio files in {input_dir}")
     # Scan all audio files
     audio_files = list_files(input_dir, recursive=recursive)
@@ -78,6 +98,7 @@ def transcribe(
                 batch_transcribe,
                 files=chunk,
                 model_size=model_size,
+                model_type=model_type,
                 lang=lang,
                 pos=len(tasks),
             )
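
The funasr default above works by asking Click where each parameter's value came from: `ctx.get_parameter_source()` returns `ParameterSource.COMMANDLINE` only for options the user actually typed, so a declared `default=` no longer hides "not provided". Below is a self-contained sketch of the same pattern with a hypothetical `greet` command (not from this repo), just to illustrate the mechanism.

```python
# Standalone illustration of the ParameterSource pattern used above
# (hypothetical command, not part of fish-audio-preprocess).
import click


@click.command()
@click.option("--name", default="world", show_default=True)
@click.option("--greeting", default="hello", show_default=True)
def greet(name: str, greeting: str):
    ctx = click.get_current_context()
    # Keep only the options the user explicitly passed on the command line;
    # values that merely fell back to `default=` are filtered out.
    provided = {
        key: value
        for key, value in ctx.params.items()
        if ctx.get_parameter_source(key) == click.core.ParameterSource.COMMANDLINE
    }
    # Pick a different effective default only when --greeting was not overridden.
    if "greeting" not in provided and name != "world":
        greeting = "hi"
    click.echo(f"{greeting}, {name}")


if __name__ == "__main__":
    greet()
```
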
fish_audio_preprocess/utils/transcribe.py (+30 -7)

@@ -1,5 +1,7 @@
 from pathlib import Path
+from typing import Literal
 
+from loguru import logger
 from tqdm import tqdm
 
 PROMPT = {
@@ -8,12 +10,21 @@
     "jp": "先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読から自動運転車の操縦まで、AIの応用は広範囲に及びます。",
 }
 
+ASRModelType = Literal["funasr", "whisper"]
 
-def batch_transcribe(files: list[Path], model_size: str, lang: str, pos: int):
+
+def batch_transcribe(
+    files: list[Path],
+    model_size: str,
+    model_type: ASRModelType,
+    lang: str,
+    pos: int,
+):
     results = {}
-    if "funasr" not in model_size:
+    if model_type == "whisper":
         import whisper
 
+        logger.info(f"Loading {model_size} model for {lang} transcription")
         model = whisper.load_model(model_size)
         for file in tqdm(files, position=pos):
             if lang in PROMPT:
@@ -23,17 +34,29 @@ def batch_transcribe(files: list[Path], model_size: str, lang: str, pos: int):
             else:
                 result = model.transcribe(file, language=lang)
             results[str(file)] = result["text"]
-    else:
+    elif model_type == "funasr":
         from funasr import AutoModel
 
+        logger.info(f"Loading {model_size} model for {lang} transcription")
         model = AutoModel(
-            model="paraformer-zh",
+            model=model_size,
+            vad_model="fsmn-vad",
             punc_model="ct-punc",
             log_level="ERROR",
            disable_pbar=True,
        )
         for file in tqdm(files, position=pos):
-            result = model.generate(input=file, batch_size_s=300)
-            results[str(file)] = result[0]["text"]
-
+            if lang in PROMPT:
+                result = model.generate(
+                    input=file, batch_size_s=300, hotword=PROMPT[lang]
+                )
+            else:
+                result = model.generate(input=file, batch_size_s=300)
+            # print(result)
+            if isinstance(result, list):
+                results[str(file)] = "".join([item["text"] for item in result])
+            else:
+                results[str(file)] = result["text"]
+    else:
+        raise ValueError(f"Unsupported model type: {model_type}")
     return results
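
The updated helper can also be called directly, without going through the CLI. A minimal sketch follows; it assumes a local `wavs/` directory with `.wav` files and that `funasr` (or `openai-whisper`) and `torch` are installed. The argument names match the new signature in this diff.

```python
# Minimal sketch (not part of this commit): call the updated helper directly
# instead of going through `fap transcribe`. Assumes a local `wavs/` folder
# with .wav files and that funasr plus torch are installed.
from pathlib import Path

from fish_audio_preprocess.utils.transcribe import batch_transcribe

files = sorted(Path("wavs").glob("*.wav"))

# FunASR branch: model_size is the FunASR model name, e.g. "paraformer-zh".
# If "zh" has an entry in PROMPT (not shown in this hunk), it is also
# forwarded to model.generate() as a hotword.
results = batch_transcribe(
    files=files,
    model_size="paraformer-zh",
    model_type="funasr",
    lang="zh",
    pos=0,
)

for path, text in results.items():
    # Downstream, the CLI turns these transcripts into .lab files; here we just print.
    print(path, text[:50])
```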
