Skip to content

Commit 4c2acf1

Browse files
增加了asr时的递归搜索和对funasr的支持 (#21)
* add support for recursive search * add funasr support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 12c2405 commit 4c2acf1

File tree

2 files changed

+39
-12
lines changed

2 files changed

+39
-12
lines changed

fish_audio_preprocess/cli/transcribe.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,23 @@ def replace_lastest(string, old, new):
3232
)
3333
@click.option(
3434
"--model-size",
35-
help="whisper model size",
35+
help="whisper model size or funasr",
3636
default="tiny",
3737
show_default=True,
3838
type=str,
3939
)
40-
def transcribe(input_dir, num_workers, lang, model_size):
40+
@click.option(
41+
"--recursive/--no-recursive",
42+
default=False,
43+
help="Search recursively",
44+
)
45+
def transcribe(
46+
input_dir: str,
47+
num_workers: int,
48+
lang: str,
49+
model_size: str,
50+
recursive: bool,
51+
):
4152
"""
4253
Transcribe audio files in a directory.
4354
"""
@@ -49,7 +60,7 @@ def transcribe(input_dir, num_workers, lang, model_size):
4960

5061
logger.info(f"Transcribing audio files in {input_dir}")
5162
# 扫描出所有的音频文件
52-
audio_files = list_files(input_dir)
63+
audio_files = list_files(input_dir, recursive=recursive)
5364
audio_files = [str(file) for file in audio_files if file.suffix in AUDIO_EXTENSIONS]
5465

5566
if len(audio_files) == 0:

fish_audio_preprocess/utils/transcribe.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,30 @@
1010

1111

1212
def batch_transcribe(files: list[Path], model_size: str, lang: str, pos: int):
13-
import whisper
14-
15-
model = whisper.load_model(model_size)
1613
results = {}
17-
for file in tqdm(files, position=pos):
18-
if lang in PROMPT:
19-
result = model.transcribe(file, language=lang, initial_prompt=PROMPT[lang])
20-
else:
21-
result = model.transcribe(file, language=lang)
22-
results[str(file)] = result["text"]
14+
if "funasr" not in model_size:
15+
import whisper
16+
17+
model = whisper.load_model(model_size)
18+
for file in tqdm(files, position=pos):
19+
if lang in PROMPT:
20+
result = model.transcribe(
21+
file, language=lang, initial_prompt=PROMPT[lang]
22+
)
23+
else:
24+
result = model.transcribe(file, language=lang)
25+
results[str(file)] = result["text"]
26+
else:
27+
from funasr import AutoModel
28+
29+
model = AutoModel(
30+
model="paraformer-zh",
31+
punc_model="ct-punc",
32+
log_level="ERROR",
33+
disable_pbar=True,
34+
)
35+
for file in tqdm(files, position=pos):
36+
result = model.generate(input=file, batch_size_s=300)
37+
results[str(file)] = result[0]["text"]
38+
2339
return results

0 commit comments

Comments
 (0)