13 | 13 | import torch._dynamo.config
14 | 14 | import torch._inductor.config
15 | 15 | from loguru import logger
16 |    | -from torch.nn.attention import SDPBackend, sdpa_kernel
17 | 16 | from tqdm import tqdm
18 | 17 | from transformers import AutoTokenizer
19 | 18 |

24 | 23 |     TextPart,
25 | 24 |     VQPart,
26 | 25 | )
27 |    | -from fish_speech.models.text2semantic.llama import (
28 |    | -    BaseModelArgs,
29 |    | -    BaseTransformer,
30 |    | -    DualARTransformer,
31 |    | -    NaiveTransformer,
32 |    | -)
   | 26 | +from fish_speech.models.text2semantic.llama import BaseModelArgs
33 | 27 | from fish_speech.text import clean_text, split_text
34 | 28 | from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
35 | 29 |

42 | 36 | torch._inductor.config.fx_graph_cache = True
43 | 37 |
44 | 38 |
   | 39 | +from torch.nn.attention import SDPBackend, sdpa_kernel
   | 40 | +
   | 41 | +from fish_speech.models.text2semantic.llama import (
   | 42 | +    BaseTransformer,
   | 43 | +    DualARTransformer,
   | 44 | +    NaiveTransformer,
   | 45 | +)
   | 46 | +
   | 47 | +
45 | 48 | def multinomial_sample_one_no_sync(
46 | 49 |     probs_sort,
47 | 50 | ):  # Does multinomial sampling without a cuda synchronization
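For context on the helper whose signature closes the hunk above: avoiding a CUDA synchronization in multinomial sampling is conventionally done with the exponential-race trick popularized by gpt-fast, replacing torch.multinomial (which can force a device-to-host sync) with an argmax over exponentially perturbed probabilities. A minimal sketch of that technique, assuming the standard formulation rather than the exact body of this file:

    import torch

    def multinomial_sample_one_no_sync(probs_sort: torch.Tensor) -> torch.Tensor:
        # argmax_i(p_i / E_i) with E_i ~ Exponential(1) selects index i with
        # probability proportional to p_i; everything stays on-device, no sync.
        q = torch.empty_like(probs_sort).exponential_(1)
        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)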
|
@@ -369,12 +372,8 @@ def decode_n_tokens(
369 | 372 |         window = previous_tokens[:, i - win_size : i]
370 | 373 |
371 | 374 |         with (
372 |     | -            sdpa_kernel(
373 |     | -                [
374 |     | -                    SDPBackend.FLASH_ATTENTION,
375 |     | -                    SDPBackend.EFFICIENT_ATTENTION,
376 |     | -                    SDPBackend.MATH,
377 |     | -                ]
    | 375 | +            torch.backends.cuda.sdp_kernel(
    | 376 | +                enable_flash=False, enable_mem_efficient=False, enable_math=True
378 | 377 |             )
379 | 378 |             if torch.cuda.is_available()
380 | 379 |             else nullcontext()
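Note on the hunk above: it swaps the torch.nn.attention.sdpa_kernel context manager (added in PyTorch 2.3, taking a list of allowed SDPBackend values) for the older flag-based torch.backends.cuda.sdp_kernel form, and at the same time narrows the allowed backends from flash/memory-efficient/math down to math only. Both spellings constrain the same scaled_dot_product_attention dispatcher; a sketch for comparison, assuming a PyTorch build where both APIs exist:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = k = v = torch.randn(1, 8, 16, 64, device=device)

    # Newer API: allow only the math (reference) backend.
    with sdpa_kernel([SDPBackend.MATH]):
        out_new = F.scaled_dot_product_attention(q, k, v)

    # Legacy API (deprecated since PyTorch 2.3): per-backend enable flags.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        out_legacy = F.scaled_dot_product_attention(q, k, v)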