import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend

from PIL import Image
@@ -531,6 +532,7 @@ def decode_n_tokens(
        callback=lambda _: _,
        eos_token_id: int = 2,
        eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
        **sampling_kwargs,
    ):
        new_tokens, new_probs = [], []
@@ -539,7 +541,7 @@ def decode_n_tokens(
            num_new_tokens - 1
        ):  # -1 to save space to run an EoS if dont generate it naturally
            # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):

                out_token = cur_token.clone()
                next_token, next_prob = self.decode_one_token(
@@ -683,6 +685,7 @@ def generate(
        sequential_prefill=True,
        callback=lambda x: x,
        max_seq_length: int,
+        attention_backend: str = "math",
        seed: Optional[int] = None,
        **sampling_kwargs,
    ) -> torch.Tensor:
@@ -799,6 +802,7 @@ def generate(
                if self.is_llama3_model
                else None
            ),
+            attention_backend=attention_backend,
            **sampling_kwargs,
        ):
            generated_tokens.append(generated_token.view(-1))
@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
            start_pos=start_pos,
            skip_cache_setup=not is_first_sample,
            max_seq_length=max_seq_length,
+            attention_backend=self.builder_args.attention_backend,
        )
        for token_tensor, metrics in generator_func:
            if token_tensor is not None:
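
The new `attention_backend` argument is a string on `generate()` (default "math", ultimately fed from `builder_args.attention_backend`) but an `SDPBackend` enum member on `decode_n_tokens()`, so the name has to be resolved to an enum value before it reaches `torch.nn.attention.sdpa_kernel`. Below is a minimal, hypothetical sketch of that resolution step; the `resolve_attention_backend` helper and its mapping table are illustrative assumptions, not code from this PR.

# Hypothetical helper, not part of the diff: map a CLI-style backend name
# onto the SDPBackend enum that sdpa_kernel() expects.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

_STR_TO_SDP_BACKEND = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
}

def resolve_attention_backend(name: str) -> SDPBackend:
    """Translate a backend name such as 'math' into an SDPBackend member."""
    try:
        return _STR_TO_SDP_BACKEND[name.lower()]
    except KeyError:
        raise ValueError(f"unknown attention backend: {name!r}") from None

# Usage mirroring the decode loop above: choose the backend once, then run
# scaled dot-product attention under that kernel selection.
backend = resolve_attention_backend("math")
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
with sdpa_kernel([backend]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

Parameterizing the `sdpa_kernel` context manager (second hunk) is what lets a single flag switch the decode loop between the math fallback and the fused flash/efficient kernels without touching the loop body itself.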