
Commit 3eb2f32

Fix inference speed. (#928)
* Fix inference speed.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cddfd4d · commit 3eb2f32

File tree: 1 file changed (+12, -13 lines)


fish_speech/models/text2semantic/inference.py (+12, -13)
@@ -13,7 +13,6 @@
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
-from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -24,12 +23,7 @@
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import (
-    BaseModelArgs,
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
@@ -42,6 +36,15 @@
 torch._inductor.config.fx_graph_cache = True
 
 
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
+
+
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
@@ -369,12 +372,8 @@ def decode_n_tokens(
             window = previous_tokens[:, i - win_size : i]
 
         with (
-            sdpa_kernel(
-                [
-                    SDPBackend.FLASH_ATTENTION,
-                    SDPBackend.EFFICIENT_ATTENTION,
-                    SDPBackend.MATH,
-                ]
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
            )
             if torch.cuda.is_available()
             else nullcontext()
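
For context: besides deferring the heavier llama imports until after the `torch._dynamo`/`torch._inductor` config flags are set, the functional change in `decode_n_tokens` swaps the newer `torch.nn.attention.sdpa_kernel` context manager (which takes a list of allowed backends) for the legacy boolean-flag `torch.backends.cuda.sdp_kernel`, pinning scaled-dot-product attention to the math kernel. Below is a minimal, self-contained sketch of the two APIs as they appear in the diff; the tensor shapes are illustrative only and are not taken from fish-speech.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes only: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 4, 16, 8)

# Newer API (what the diff removes from decode_n_tokens): pass a list of
# allowed backends and let PyTorch pick among them at dispatch time.
with sdpa_kernel(
    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
):
    out = F.scaled_dot_product_attention(q, k, v)

# Legacy API (what the diff switches to), guarded the same way as the diff's
# `if torch.cuda.is_available() else nullcontext()` expression: per-backend
# boolean flags, here disabling everything except the math kernel.
if torch.cuda.is_available():
    q, k, v = q.cuda(), k.cuda(), v.cuda()
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        out = F.scaled_dot_product_attention(q, k, v)
```

Note that `torch.backends.cuda.sdp_kernel` is deprecated in recent PyTorch releases in favor of `torch.nn.attention.sdpa_kernel`; the two context managers express the same backend choice in different vocabularies.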
