Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions tensorrt_llm/visual_gen/visual_gen/ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# @article{
# li2024svdquant,
# title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
# author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
# author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and
# Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
# journal={arXiv preprint arXiv:2411.05007},
# year={2024}
# }
Expand Down Expand Up @@ -76,6 +77,7 @@

try:
from deep_gemm import fp8_gemm_nt

from visual_gen.ops.deep_gemm_quant import quant_and_transform_ue8m0
except ImportError:
logger.warning("Deep_gemm is not installed")
Expand All @@ -86,7 +88,11 @@ def __init__(self):
pass

def register_tunable_params(
self, value: Any, param_range: Optional[List[Any]], name: Optional[str], description: Optional[str]
self,
value: Any,
param_range: Optional[List[Any]],
name: Optional[str],
description: Optional[str],
):
return TunableParam(value, param_range, name, description)

Expand Down Expand Up @@ -115,7 +121,6 @@ def __call__(
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:

if tensorrt_llm is None:
logger.error("TensorRT-LLM is not installed")

Expand All @@ -128,7 +133,9 @@ def __call__(
input = input.reshape(-1, input.shape[-1])

act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight, input_scale, weight_scale)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, weight, input_scale, weight_scale
)

if bias is not None:
if bias.dtype != output.dtype:
Expand Down Expand Up @@ -167,9 +174,16 @@ def __call__(
qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(input)
cur_input_scale = cur_input_scale.to(torch.float32)
# This op does not support bias now.
if qinput.dim() == 3:
if qinput.dim() >= 3:
qinput = qinput.reshape(-1, qinput.shape[-1])

K = qinput.shape[-1]
kp = (K + 15) // 16 * 16
qinput = F.pad(qinput, (0, kp - K))
weight = F.pad(weight, (0, 0, 0, kp - K))
weight = weight.t().contiguous().t()
cur_input_scale = cur_input_scale.reshape(-1, cur_input_scale.shape[-1])

output = torch.ops.trtllm.cublas_scaled_mm(
qinput,
weight,
Expand Down Expand Up @@ -204,7 +218,6 @@ def __call__(
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:

# input
origin_shape = input.shape
origin_dtype = input.dtype
Expand All @@ -218,7 +231,9 @@ def __call__(
input = input.reshape(-1, input.shape[-1])
if input.shape[0] % 8 != 0:
act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight, input_scale, weight_scale)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, weight, input_scale, weight_scale
)
if bias is not None:
output = output + bias
else:
Expand Down Expand Up @@ -251,7 +266,9 @@ def __call__(
is_2D_scaled=True,
quantizer=w_quantizer_cur,
)
output, *_ = ext.general_gemm(A=w_fp8_te, B=x_fp8_te, out_dtype=torch.bfloat16, bias=bias)
output, *_ = ext.general_gemm(
A=w_fp8_te, B=x_fp8_te, out_dtype=torch.bfloat16, bias=bias
)

if output.dim() != len(origin_shape):
output_shape = list(origin_shape)
Expand All @@ -265,7 +282,6 @@ def __call__(

@LinearOpManager.register_linear("te-MXFP8-blockwise-32")
class TeMXFP8Blockwise32Linear(BaseLinear):

@torch.compiler.disable()
def run_te_gemm(self, input, weight, bias):
input_quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
Expand All @@ -292,7 +308,6 @@ def __call__(
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:

# input
origin_shape = input.shape
origin_dtype = input.dtype
Expand Down Expand Up @@ -330,7 +345,9 @@ def __call__(
@LinearOpManager.register_linear("te-fp8-per-tensor")
class TeFp8PerTensorLinear(BaseLinear):
def __init__(self):
self.quantizer_cur = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
self.quantizer_cur = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"
)

def __call__(
self,
Expand Down Expand Up @@ -359,12 +376,21 @@ def __call__(
fp8_scale_inv=weight_scale.flatten().to(dtype=torch.float32),
)
if hasgelu:
gelu_in = torch.randn((input.shape[0], weight.shape[1]), dtype=torch.bfloat16, device="cuda")
gelu_in = torch.randn(
(input.shape[0], weight.shape[1]), dtype=torch.bfloat16, device="cuda"
)
output, *_ = ext.general_gemm(
A=w_fp8_te, B=input_fp8_te, out_dtype=torch.bfloat16, bias=bias, gelu=True, gelu_in=gelu_in
A=w_fp8_te,
B=input_fp8_te,
out_dtype=torch.bfloat16,
bias=bias,
gelu=True,
gelu_in=gelu_in,
)
else:
output, *_ = ext.general_gemm(A=w_fp8_te, B=input_fp8_te, out_dtype=torch.bfloat16, bias=bias)
output, *_ = ext.general_gemm(
A=w_fp8_te, B=input_fp8_te, out_dtype=torch.bfloat16, bias=bias
)

if output.dim() != len(origin_shape):
output_shape = list(origin_shape)
Expand All @@ -380,7 +406,6 @@ def __call__(
@torch.compiler.disable()
@LinearOpManager.register_linear("svd-nvfp4")
class SvdNvfp4Linear(BaseLinear):

def __call__(
self,
input: torch.Tensor,
Expand All @@ -397,7 +422,6 @@ def __call__(
svd_wcscales=None,
svd_bias=None,
) -> torch.Tensor:

B, M, K = input.shape
N = svd_qweight.shape[0]
R = svd_lora_up.shape[1]
Expand All @@ -407,7 +431,9 @@ def __call__(
assert N % 128 == 0, "N must be divisible by 128"

act = torch.empty(B * M, int(K / 2), dtype=torch.int8, device=input.device).contiguous()
ascales = torch.empty(int(K / 16), B * M, dtype=torch.float8_e4m3fn, device=input.device).contiguous()
ascales = torch.empty(
int(K / 16), B * M, dtype=torch.float8_e4m3fn, device=input.device
).contiguous()
lora_act = torch.empty(B * M, R, dtype=torch.float32, device=input.device).contiguous()
out = torch.empty(B, M, N, dtype=torch.bfloat16, device=input.device).contiguous()
lora_scales = [1.0] * (R // 16)
Expand Down Expand Up @@ -469,7 +495,6 @@ def __call__(
scaling_vector_size: int = 16,
input_scale: torch.Tensor = None,
) -> torch.Tensor:

if tensorrt_llm is None:
logger.error("TensorRT-LLM is not installed")

Expand All @@ -486,7 +511,9 @@ def __call__(

alpha = 1 / (input_scale_2 * weight_scale_2)
if input.dim() == 3:
act_fp4, act_sf = torch.ops.trtllm.fp4_batched_quantize(input, input_scale_2, scaling_vector_size, False)
act_fp4, act_sf = torch.ops.trtllm.fp4_batched_quantize(
input, input_scale_2, scaling_vector_size, False
)
if not self.trtllm_tuned:
with torch.inference_mode(), autotune():
output = torch.ops.trtllm.fp4_bmm(
Expand All @@ -510,10 +537,17 @@ def __call__(
out_dtype=origin_dtype,
)
else:
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(input, input_scale_2, scaling_vector_size, False)
input_shape = input.shape
if input.dim() > 2:
input = input.view(-1, input_shape[-1])
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, input_scale_2, scaling_vector_size, False
)
output = torch.ops.trtllm.nvfp4_gemm(
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype=origin_dtype
)
if len(input_shape) > 2:
output = output.view(*input_shape[:-1], output.shape[-1])

if bias is not None:
if bias.dtype != output.dtype:
Expand All @@ -538,7 +572,6 @@ def __call__(
scaling_vector_size: int = 16,
input_scale: torch.Tensor = None,
) -> torch.Tensor:

origin_dtype = input.dtype
if origin_dtype != torch.bfloat16:
input = input.to(torch.bfloat16)
Expand All @@ -560,14 +593,18 @@ def __call__(
batch_size = input.shape[0]
input = input.reshape(-1, input.shape[-1])
# currently still reuse trtllm fp4 quantize kernel for better performance
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(input, input_scale_2, scaling_vector_size, False)
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, input_scale_2, scaling_vector_size, False
)

output = scaled_mm_nvfp4(
act_fp4,
weight,
1 / input_scale_2,
1 / weight_scale_2,
act_sf.reshape(-1, input.shape[1] // scaling_vector_size).view(torch.float8_e4m3fn).contiguous(),
act_sf.reshape(-1, input.shape[1] // scaling_vector_size)
.view(torch.float8_e4m3fn)
.contiguous(),
weight_scale.reshape(-1, weight.shape[1] * 2 // scaling_vector_size)
.view(torch.float8_e4m3fn)
.contiguous(),
Expand All @@ -576,14 +613,18 @@ def __call__(
)
output = output.reshape(batch_size, -1, output.shape[-1])
else:
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(input, input_scale_2, scaling_vector_size, False)
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, input_scale_2, scaling_vector_size, False
)
output = scaled_mm_nvfp4(
act_fp4,
weight,
1 / input_scale_2,
1 / weight_scale_2,
act_sf.reshape(-1, input.shape[1] // scaling_vector_size).view(torch.float8_e4m3fn),
weight_scale.reshape(-1, weight.shape[1] * 2 // scaling_vector_size).view(torch.float8_e4m3fn),
weight_scale.reshape(-1, weight.shape[1] * 2 // scaling_vector_size).view(
torch.float8_e4m3fn
),
out_dtype=origin_dtype,
bias=bias.to(origin_dtype) if bias is not None else None,
)
Expand Down Expand Up @@ -613,7 +654,9 @@ def __call__(

class FlashInferNVFP4Linear:
def __init__(self):
self.input_scale_2 = torch.tensor([8.0], dtype=torch.float32, device=torch.cuda.current_device())
self.input_scale_2 = torch.tensor(
[8.0], dtype=torch.float32, device=torch.cuda.current_device()
)

def _nvfp4_gemm(
self,
Expand Down Expand Up @@ -644,7 +687,9 @@ def _nvfp4_gemm(
else:
# TODO: magic number for static quantization scale, it should be passed by a quantized ckpt
input_global_sf = self.input_scale_2
input_fp4, input_sf = nvfp4_quantize(input, input_global_sf, sfLayout=sfLayout, do_shuffle=do_shuffle)
input_fp4, input_sf = nvfp4_quantize(
input, input_global_sf, sfLayout=sfLayout, do_shuffle=do_shuffle
)

output = mm_fp4(
input_fp4,
Expand Down Expand Up @@ -765,7 +810,6 @@ def __call__(

@LinearOpManager.register_linear("deepgemm-MXFP8")
class DeepgemmFp8Linear(BaseLinear):

def __call__(
self,
input: torch.Tensor,
Expand All @@ -774,7 +818,6 @@ def __call__(
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:

origin_dtype = input.dtype
if origin_dtype != torch.bfloat16:
input = input.to(torch.bfloat16)
Expand All @@ -790,7 +833,7 @@ def __call__(
(input_fp8, input_scale),
(weight, weight_scale),
output,
None, # bias
None, # bias
disable_ue8m0_cast=False,
)

Expand Down
Loading