Commit ef58fce

yanbing-j and Jack-Khuu authored and vmpuri committed
Add attention_backend as a configurable option (#1456)
Bump this into the constructor of BuilderArgs.

Co-authored-by: Jack-Khuu <[email protected]>

1 parent cd10377 commit ef58fce

File tree

3 files changed: +26 −1 lines changed


torchchat/cli/builder.py (+13)

@@ -69,6 +69,7 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"
 
     def __post_init__(self):
         if self.device is None:

@@ -183,6 +184,17 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
+        attention_backend = sdp_backend_dict[args.attention_backend]
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+                                     or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+            attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,

@@ -207,6 +219,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=attention_backend,
         )
 
     @classmethod
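Read on its own, the logic added above is a small translation step: map the CLI string onto a torch.nn.attention.SDPBackend value, and fall back to MATH with a warning when a backend the patch treats as CUDA-only is requested on CPU. The sketch below restates that behavior outside of BuilderArgs.from_args; resolve_attention_backend is a hypothetical helper name used only for illustration, not part of the patch.

from torch.nn.attention import SDPBackend

SDP_BACKEND_DICT = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

def resolve_attention_backend(name: str, device: str) -> SDPBackend:
    # Mirror of the patch's behavior: look up the enum value, but downgrade to
    # MATH with a warning when the requested backend is treated as CUDA-only
    # and the target device is CPU.
    backend = SDP_BACKEND_DICT[name]
    if device == "cpu" and name in ("efficient_attention", "cudnn_attention"):
        print(f"Warning: {name} is not supported on CPU. Using math instead.")
        return SDPBackend.MATH
    return backend

# resolve_attention_backend("cudnn_attention", "cpu") -> SDPBackend.MATH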

torchchat/cli/cli.py (+7)

@@ -179,6 +179,13 @@ def _add_model_config_args(parser, verb: str) -> None:
         choices=["fast", "cpu", "cuda", "mps", "xpu"],
         help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
     )
+    model_config_parser.add_argument(
+        "--attention-backend",
+        type=str,
+        default="math",
+        choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
+        help="SDPBackend to use. Options: MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION",
+    )
 
 
     # Add CLI Args representing output paths of exported model files
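As a quick illustration of what the new option accepts, the flag defaults to "math" and argparse rejects anything outside the four listed choices. This is a toy parser, not torchchat's actual CLI wiring:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    default="math",
    choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
)

print(parser.parse_args([]).attention_backend)  # math
print(parser.parse_args(["--attention-backend", "flash_attention"]).attention_backend)  # flash_attention
# parser.parse_args(["--attention-backend", "sdpa"]) would exit with a usage error.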

torchchat/generate.py (+6 −1)

@@ -26,6 +26,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend
 
 from PIL import Image
 

@@ -531,6 +532,7 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []

@@ -539,7 +541,7 @@
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(

@@ -683,6 +685,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:

@@ -799,6 +802,7 @@
                 if self.is_llama3_model
                 else None
             ),
+            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))

@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
             start_pos=start_pos,
             skip_cache_setup=not is_first_sample,
             max_seq_length=max_seq_length,
+            attention_backend=self.builder_args.attention_backend,
         )
         for token_tensor, metrics in generator_func:
             if token_tensor is not None:
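The change in decode_n_tokens hinges on torch.nn.attention.sdpa_kernel, which restricts scaled_dot_product_attention to the backend(s) listed in the context manager; the patch simply makes that list configurable instead of hard-coding MATH. A minimal sketch of the mechanism, with illustrative tensor shapes (not torchchat code):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim) -- shapes chosen only for illustration.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# MATH is the portable fallback; on CUDA a caller could pass FLASH_ATTENTION,
# EFFICIENT_ATTENTION, or CUDNN_ATTENTION instead.
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])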

0 commit comments