Commit f3fa465

2025-02-27 nightly release (4f0bb6f)

Author: pytorchbot
Parent: 4d003e2

14 files changed, +578 -767 lines

.github/workflows/build-test-linux.yml (+2)

@@ -140,6 +140,8 @@ jobs:
         python -m pip install -r requirements.txt
         cd dynamo
         python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
         popd

   tests-py-dynamo-fe:

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+2, -32)

@@ -2722,40 +2722,10 @@ def aten_ops_max_pool(
     )


-def attention_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    # Currently, `attn_mask` is not supported
-    return args_bounds_check(node.args, 3) is None
-
-
+@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(
-    torch.nn.functional.scaled_dot_product_attention,
-    capability_validator=attention_validator,
-    supports_dynamic_shapes=True,
+    torch.ops.aten._reshape_copy.default, supports_dynamic_shapes=True
 )
-def tensorrt_scaled_dot_product_attention(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.attention.scaled_dot_product_attention(
-        ctx,
-        target,
-        SourceIR.TORCHTRT_LOWERED,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args_bounds_check(args, 5, False),
-        kwargs.get("scale", None),
-    )
-
-
-@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
-@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
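With the aten.view converter removed above, view ops are expected to reach the converter registry as _reshape_copy through the new view_decomposition added to _decompositions.py below. A minimal sketch of that equivalence, assuming torch.ops.aten._reshape_copy is exposed by the installed PyTorch build:

import torch

x = torch.randn(2, 3, 4)

# view_decomposition (see _decompositions.py below) rewrites aten.view.default
# into aten._reshape_copy.default, so the two overloads should agree numerically.
viewed = torch.ops.aten.view.default(x, [6, 4])
copied = torch.ops.aten._reshape_copy.default(x, [6, 4])
assert torch.equal(viewed, copied)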

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (-1)

@@ -2,7 +2,6 @@
     activation,
     addmm,
     arange,
-    attention,
     cast,
     cat,
     condition,

py/torch_tensorrt/dynamo/conversion/impl/attention.py (-165)

This file was deleted.

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+132, -1)

@@ -1,6 +1,6 @@
 import logging
 from enum import Enum, auto
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch._decomp import register_decomposition
@@ -435,6 +435,137 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
     return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])


+@register_torch_trt_decomposition(aten.view.default, registry=TORCH_TRT_DECOMPOSITIONS)
+def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tensor:
+    return aten._reshape_copy.default(x, size)
+
+
+@register_torch_trt_decomposition(
+    aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    device = query.device
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal:
+        assert attn_mask is None, "attn_mask must be None when is_causal=True"
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias = attn_mask + attn_bias
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1)
+
+    if scale is None:
+        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
+        attn_weight = attn_weight / scale
+    else:
+        attn_weight = attn_weight * scale
+
+    attn_weight = attn_weight + attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    return attn_weight @ value
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_flash_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, None, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_efficient_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_cudnn_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
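
The decompositions above reproduce PyTorch's reference attention math with plain tensor ops, so the lowered graph no longer needs a dedicated TensorRT SDPA converter. A rough standalone sanity check of that core math (not part of this commit; mask, dropout, and GQA handling omitted, default scale assumed):

import math

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

# Core of scaled_dot_product_attention_decomposition with a zero attn_bias:
# softmax(Q @ K^T / sqrt(head_dim)) @ V
scale = 1.0 / math.sqrt(q.size(-1))
attn_weight = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
decomposed = attn_weight @ v

reference = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(decomposed, reference, atol=1e-5)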

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (-4)

@@ -8,14 +8,12 @@
 from .constant_folding import constant_fold
 from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
-from .view_to_reshape import view_to_reshape

 ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -25,8 +23,6 @@
         fuse_prims_broadcast,
         fuse_distributed_ops,
         replace_max_pool_with_indices,
-        lower_scaled_dot_product_attention,
-        view_to_reshape,
         remove_assert_nodes,
         accumulate_fp32_matmul,
     ]
