|
| 1 | +import logging |
| 2 | +import math |
| 3 | +from typing import Any, Dict, Optional, Tuple, Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import tensorrt as trt |
| 7 | +import torch |
| 8 | +import torch_tensorrt |
| 9 | +from torch.fx.node import Target |
| 10 | +from torch_tensorrt._enums import dtype |
| 11 | +from torch_tensorrt.dynamo.conversion import impl |
| 12 | +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
| 13 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 14 | + SourceIR, |
| 15 | + cast_trt_tensor, |
| 16 | + get_trt_tensor, |
| 17 | +) |
| 18 | +from torch_tensorrt.fx.types import TRTTensor |
| 19 | + |
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +def tril( |
| 24 | + ctx: ConversionContext, |
| 25 | + target: Union[Target, str], |
| 26 | + source_ir: Optional[SourceIR], |
| 27 | + name: str, |
| 28 | + row: TRTTensor, |
| 29 | + col: TRTTensor, |
| 30 | +) -> TRTTensor: |
| 31 | + row_arange_tensor = impl.arange.arange( |
| 32 | + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 |
| 33 | + ) |
| 34 | + row_reshape_tensor = impl.shuffle.reshape( |
| 35 | + ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] |
| 36 | + ) |
| 37 | + |
| 38 | + col_arange_tensor = impl.arange.arange( |
| 39 | + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 |
| 40 | + ) |
| 41 | + col_reshape_tensor = impl.shuffle.reshape( |
| 42 | + ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] |
| 43 | + ) |
| 44 | + |
| 45 | + mask = impl.elementwise.ge( |
| 46 | + ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor |
| 47 | + ) |
| 48 | + return mask |
| 49 | + |
| 50 | + |
| 51 | +@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( |
| 52 | + torch.nn.functional.scaled_dot_product_attention, |
| 53 | + enabled=True, |
| 54 | + supports_dynamic_shapes=True, |
| 55 | +) |
| 56 | +def scaled_dot_product_attention( |
| 57 | + ctx: torch_tensorrt.dynamo.conversion.ConversionContext, |
| 58 | + target: Target, |
| 59 | + args: Tuple[Any, ...], |
| 60 | + kwargs: Dict[str, Any], |
| 61 | + name: str, |
| 62 | +) -> TRTTensor: |
| 63 | + # TODO: Handle attn_mask and is_causal arguments in the future |
| 64 | + query, key, value, attn_mask, dropout_p, is_causal = args |
| 65 | + logger.info( |
| 66 | + "Ignoring attn_mask and is_causal arguments provided by the original graph. " |
| 67 | + "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True " |
| 68 | + "and for generate phase, is_causal=False since we pass only 1 input token at a time" |
| 69 | + ) |
| 70 | + |
| 71 | + # TODO: remove this once we have a better way to handle the causal mask |
| 72 | + scale = kwargs.get("scale", None) |
| 73 | + source_ir = SourceIR.ATEN |
| 74 | + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html |
| 75 | + mm = impl.matmul.matrix_multiply( |
| 76 | + ctx, |
| 77 | + target, |
| 78 | + source_ir, |
| 79 | + name + "_mm", |
| 80 | + query, |
| 81 | + key, |
| 82 | + other_matrix_op=trt.MatrixOperation.TRANSPOSE, |
| 83 | + ) |
| 84 | + if scale is None: |
| 85 | + scale = query.shape[-1] |
| 86 | + if scale < 0: |
| 87 | + # dynamic shape |
| 88 | + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) |
| 89 | + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) |
| 90 | + else: |
| 91 | + # static shape |
| 92 | + sqrt_scaled = math.sqrt(scale) |
| 93 | + scaled = impl.elementwise.div( |
| 94 | + ctx, |
| 95 | + target, |
| 96 | + source_ir, |
| 97 | + name + "_scale", |
| 98 | + mm, |
| 99 | + sqrt_scaled, |
| 100 | + ) |
| 101 | + else: |
| 102 | + scaled = impl.elementwise.mul( |
| 103 | + ctx, |
| 104 | + target, |
| 105 | + source_ir, |
| 106 | + name + "_scale", |
| 107 | + mm, |
| 108 | + scale, |
| 109 | + ) |
| 110 | + |
| 111 | + # If is_causal is True, we need to generate a causal mask |
| 112 | + if is_causal: |
| 113 | + L, S = query.shape[-2], key.shape[-2] |
| 114 | + if L >= 0 and S >= 0: |
| 115 | + # static shape |
| 116 | + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) |
| 117 | + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) |
| 118 | + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) |
| 119 | + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") |
| 120 | + else: |
| 121 | + # if any of the L or S is dynamic shape |
| 122 | + if L < 0: |
| 123 | + L = impl.shape.shape( |
| 124 | + ctx, target, source_ir, name + "_shape_0", query, 2 |
| 125 | + ) |
| 126 | + if S < 0: |
| 127 | + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) |
| 128 | + |
| 129 | + # generate the mask tensor |
| 130 | + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) |
| 131 | + |
| 132 | + temp_mask = impl.unary.logical_not( |
| 133 | + ctx, target, source_ir, name + "_logical_not", tril_tensor |
| 134 | + ) |
| 135 | + temp_mask_casted = cast_trt_tensor( |
| 136 | + ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir |
| 137 | + ) |
| 138 | + one_minus_temp_mask = impl.elementwise.sub( |
| 139 | + ctx, |
| 140 | + target, |
| 141 | + source_ir, |
| 142 | + name + "_one_minus_temp_mask", |
| 143 | + 1.0, |
| 144 | + temp_mask_casted, |
| 145 | + ) |
| 146 | + attn_bias = impl.unary.log( |
| 147 | + ctx, target, source_ir, name + "_log", one_minus_temp_mask |
| 148 | + ) |
| 149 | + |
| 150 | + scaled_add_attn_bias = impl.elementwise.add( |
| 151 | + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias |
| 152 | + ) |
| 153 | + else: |
| 154 | + scaled_add_attn_bias = scaled |
| 155 | + |
| 156 | + # Create a if condition to check if is_causal is True |
| 157 | + if isinstance(is_causal, TRTTensor): |
| 158 | + if_layer = ctx.net.add_if_conditional() |
| 159 | + condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled |
| 160 | + if_layer.set_condition(condition) |
| 161 | + output_layer = if_layer.add_output(true_branch, false_branch) |
| 162 | + scaled_add_attn_bias = output_layer.get_output(0) |
| 163 | + |
| 164 | + softmax = impl.normalization.softmax( |
| 165 | + ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False |
| 166 | + ) |
| 167 | + out = impl.matmul.matrix_multiply( |
| 168 | + ctx, |
| 169 | + target, |
| 170 | + source_ir, |
| 171 | + name + "_out", |
| 172 | + softmax, |
| 173 | + value, |
| 174 | + ) |
| 175 | + |
| 176 | + return out |
0 commit comments