flux fp4 example (WIP) #3537

Draft. Wants to merge 3 commits into base: lluo/fp4_issue_debugging.
36 changes: 35 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -597,7 +597,9 @@ def aten_ops_neg(
     )
 else:
 
-    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
+    )
     def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
@@ -653,6 +655,38 @@ def aten_ops_dynamic_block_quantize_op(
     )
 
 
+def attention_validator(
+    node: Node, settings: Optional[CompilationSettings] = None
+) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
+@dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
+    supports_dynamic_shapes=True,
+)
+def tensorrt_scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.attention.scaled_dot_product_attention(
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args_bounds_check(args, 5, False),
+        kwargs.get("scale", None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
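Editor's note: the validator and converter above depend on the positional layout of `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of that layout (assuming a recent PyTorch; not part of this diff):

```python
# scaled_dot_product_attention(query, key, value, attn_mask=None,
#                              dropout_p=0.0, is_causal=False, *, scale=None)
# Hence node.args[3] is attn_mask (the validator rejects nodes where it is
# set), and args[5] is is_causal (read with args_bounds_check, default False).
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)
# a call the converter would accept: no attn_mask, causal masking enabled
out = F.scaled_dot_product_attention(q, k, v, None, 0.0, True)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```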
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@
     activation,
     addmm,
     arange,
+    attention,
     cast,
     cat,
     condition,
165 changes: 165 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -0,0 +1,165 @@
+import math
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # the lower triangle keeps the elements whose row index is greater than
+    # or equal to the column index
+    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
+    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
+    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
+    arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
+    )
+    # recover the row indices: row = arange // col
+    row_tensor = impl.elementwise.trunc_div(
+        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
+    )
+    # recover the column indices: col = arange % col
+    col_tensor = impl.elementwise.fmod(
+        ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
+    )
+    cond = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
+    )
+    return impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape", cond, [row, col]
+    )
+
+
+def scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    query: TRTTensor,
+    key: TRTTensor,
+    value: TRTTensor,
+    is_causal: bool,
+    scale: Optional[float],
+) -> TRTTensor:
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            sqrt_scaled,
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
+
+    if is_causal:
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # if either L or S is a dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, -2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)
+
+            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)
+
+            # generate a tensor of shape (L, S) with int32 type
+            arange_tensor = impl.arange.arange(
+                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
+            )
+            shape_tensor = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
+            )
+
+            # cast to float32, since attn_bias should be float32
+            shape_tensor = cast_trt_tensor(
+                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+            )
+
+            # initialize attn_bias as a zeros tensor
+            attn_bias = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
+            )
+
+            # generate the causal mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            inf_tensor = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
+            )
+            cond = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
+            )
+            # fill the masked-out part of attn_bias with -inf
+            attn_bias = impl.condition.select(
+                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
+            )
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", scaled, -1, False
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
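Editor's note: the converter mirrors the reference formula from the PyTorch docs linked above. A self-contained NumPy sketch (not part of the diff) of both the reference computation and the flat-index trick that tril() builds out of TRT layers:

```python
import numpy as np

L, S, d = 4, 4, 8
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, L, d))

# tril() flattens an arange of length L*S and recovers row = idx // S,
# col = idx % S; keeping row >= col yields the lower triangle.
idx = np.arange(L * S)
tril_mask = (idx // S >= idx % S).reshape(L, S)
assert (tril_mask == np.tril(np.ones((L, S), dtype=bool))).all()

# reference SDPA with causal bias: softmax(q @ k^T / sqrt(d) + bias) @ v
attn_bias = np.where(tril_mask, 0.0, -np.inf)
scores = q @ k.T / np.sqrt(d) + attn_bias
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = weights @ v
print(out.shape)  # (4, 8)
```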
52 changes: 32 additions & 20 deletions py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -3,6 +3,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -28,42 +29,53 @@ def quantize(
"""

with unset_fake_temporarily():
if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
trt.float32,
trt.float16,
):
raise ValueError(
f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
)
if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
input_tensor = get_trt_tensor(ctx, input_tensor, name)
if input_tensor.dtype not in (
trt.float32,
trt.float16,
):
raise ValueError(
f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
)
if num_bits != 8 or exponent_bits not in (0, 4):
raise ValueError(
f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
)
else:
raise ValueError(
f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
)

if num_bits == 8 and exponent_bits == 0:
max_bound = 127
elif num_bits == 8 and exponent_bits == 4:
max_bound = 448

amax = to_torch(amax, None)
scale = torch.divide(amax, max_bound)
scale = get_trt_tensor(ctx, scale, name + "_scale")
# Add Q node
quantize_layer = ctx.net.add_quantize(input_tensor, scale)
if not isinstance(amax, trt.ITensor):
amax = to_torch(amax, None)
scale = torch.divide(amax, max_bound)
scale = get_trt_tensor(ctx, amax, name + "_scale")
else:
scale = impl.elementwise_divide(
ctx, target, source_ir, name + "_scale", amax, max_bound
)

if num_bits == 8 and exponent_bits == 0:
quantize_layer.set_output_type(0, trt.DataType.INT8)
dtype = trt.DataType.INT8
elif num_bits == 8 and exponent_bits == 4:
quantize_layer.set_output_type(0, trt.DataType.FP8)
dtype = trt.DataType.FP8

# Add Q node
quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
q_output = quantize_layer.get_output(0)
# Add DQ node
dequantize_layer = ctx.net.add_dequantize(q_output, scale)
dequantize_layer = ctx.net.add_dequantize(
q_output, scale, output_type=input_tensor.dtype
)
dequantize_layer.to_type = input_tensor.dtype
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
if num_bits == 8 and exponent_bits == 0:
dequantize_layer.precision = trt.DataType.INT8
elif num_bits == 8 and exponent_bits == 4:
# Set DQ layer precision to FP8
dequantize_layer.precision = trt.DataType.FP8
dq_output = dequantize_layer.get_output(0)

return dq_output
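Editor's note: a small sketch (not part of the diff) of the scale arithmetic in the converter above. max_bound is the largest representable magnitude of the target type (127 for INT8, 448 for FP8 E4M3), so scale = amax / max_bound maps the observed absolute maximum onto the top of that range. The round/clamp here models the INT8 path; the FP8 path instead casts the scaled value to E4M3:

```python
import torch

def qdq_reference(x: torch.Tensor, amax: torch.Tensor, fp8: bool = False) -> torch.Tensor:
    # scale maps [-amax, amax] onto the representable range of the target dtype
    max_bound = 448.0 if fp8 else 127.0
    scale = amax / max_bound
    q = torch.clamp(torch.round(x / scale), -max_bound, max_bound)  # Q node
    return q * scale  # DQ node

x = torch.randn(8)
print(qdq_reference(x, x.abs().max()))  # close to x, up to INT8 rounding error
```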
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -11,6 +11,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 
+from packaging import version as pkg_version
+
 logger = logging.getLogger(__name__)
 
 
@@ -24,7 +26,7 @@ def unsqueeze(
 ) -> TRTTensor:
     from importlib.metadata import version
 
-    if version("tensorrt") < "10.7.0":
+    if pkg_version.parse(version("tensorrt")) < pkg_version.parse("10.7.0"):
         logger.warning(
             f"IUnsqueezeLayer is supported starting from TensorRT 10.7.0, using the old unsqueeze implementation in the current TensorRT version: {version('tensorrt')}"
         )
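Editor's note: the change above swaps lexicographic string comparison for proper version comparison. The failure mode it fixes:

```python
from packaging import version as pkg_version

# String comparison is lexicographic, so "10.10.0" sorts before "10.7.0"
# even though 10.10 is the newer release:
print("10.10.0" < "10.7.0")                                        # True (wrong)
print(pkg_version.parse("10.10.0") < pkg_version.parse("10.7.0"))  # False (correct)
```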
@@ -9,6 +9,7 @@
 from .constant_folding import constant_fold
 from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
@@ -23,6 +24,7 @@
     repair_input_as_output,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
+    lower_scaled_dot_product_attention,
     remove_assert_nodes,
     accumulate_fp32_matmul,
     remove_num_users_is_0_nodes,
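Editor's note: the body of lower_scaled_dot_product_attention is not shown in this diff. A hypothetical sketch of what a pass registered here plausibly does: retarget the decomposed aten SDPA variants back to torch.nn.functional.scaled_dot_product_attention so the converter registered earlier can match them. The op set and rewiring below are assumptions, not the PR's code:

```python
import operator

import torch
from torch.fx import GraphModule

# assumed set of tuple-returning SDPA variants produced by decomposition
REPLACEABLE_OPS = {
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
}

def lower_sdpa_sketch(gm: GraphModule) -> GraphModule:
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target not in REPLACEABLE_OPS:
            continue
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(
                torch.nn.functional.scaled_dot_product_attention,
                args=node.args[:3],  # (query, key, value)
            )
        # These variants return tuples; the attention output is element 0, so
        # rewire getitem(node, 0) users onto the replacement node.
        for user in list(node.users):
            if user.target is operator.getitem and user.args[1] == 0:
                user.replace_all_uses_with(new_node)
                gm.graph.erase_node(user)
        if not node.users:
            gm.graph.erase_node(node)
    gm.recompile()
    return gm
```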