Use FI tinygemm for faster BF16 GEMMs #39921
base: main
Changes from all commits
f6c2425
adf0e1a
a37b29d
1b41f11
8d95c7f
e44cc91
e9c1019
@@ -89,6 +89,87 @@ def apply_penalties(
    return logits


def _tinygemm_bf16_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Real implementation: calls FlashInfer tinygemm via lazy wrapper."""
    from vllm.utils.flashinfer import flashinfer_tinygemm_bf16
    out = torch.empty(
        input.shape[0], weight.shape[0],
        dtype=torch.bfloat16, device=input.device,
    )
    flashinfer_tinygemm_bf16(input, weight, out, bias=bias)
    return out


def _tinygemm_bf16_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Fake implementation for torch.compile graph tracing."""
    return torch.empty(
        input.shape[0], weight.shape[0],
        dtype=torch.bfloat16, device=input.device,
    )


_TINYGEMM_AVAILABLE = False


def _init_tinygemm():
    """Register tinygemm custom op if FlashInfer is available on SM90+."""
    global _TINYGEMM_AVAILABLE
    try:
        from vllm.utils.flashinfer import has_flashinfer
        if not has_flashinfer():
            return
        capability = current_platform.get_device_capability()
        if capability is None or capability[0] < 9:
            return
        direct_register_custom_op(
            "tinygemm_bf16",
            _tinygemm_bf16_impl,
            fake_impl=_tinygemm_bf16_fake,
        )
        _TINYGEMM_AVAILABLE = True
    except Exception:
        pass


_init_tinygemm()
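For readers unfamiliar with the real/fake pair registered above: `torch.compile` traces custom ops through a fake implementation that only describes the output (shape, dtype, device) and never launches the kernel. Below is a minimal standalone sketch of the same pattern using plain `torch.library` rather than vLLM's `direct_register_custom_op` wrapper; the op name and the `linear` stand-in are illustrative only, not part of this PR.

```python
import torch


@torch.library.custom_op("demo::tiny_linear", mutates_args=())
def tiny_linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real implementation: actually computes the result
    # (a plain matmul stands in for the FlashInfer kernel here).
    return torch.nn.functional.linear(x, w, b)


@tiny_linear.register_fake
def _(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only reports output metadata so the graph can be traced.
    return x.new_empty(x.shape[0], w.shape[0])


@torch.compile(fullgraph=True)
def fwd(x, w, b):
    return torch.ops.demo.tiny_linear(x, w, b)
```

Without the registered fake, tracing through the op with `fullgraph=True` would generally fail, since FakeTensor propagation has no way to infer the output.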
Member
We should not import tinygemm on import of this file. It should be delayed to the first call of the unquantized gemm function itself.
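One way to do that, sketched under the assumption that `_init_tinygemm()` otherwise stays as written (the `_tinygemm_available` helper below is hypothetical): drop the module-level `_init_tinygemm()` call and probe/register on first use instead.

```python
_TINYGEMM_AVAILABLE = None  # None = not probed yet (replaces the False default)


def _tinygemm_available() -> bool:
    """Register the custom op lazily on first use instead of at module import."""
    global _TINYGEMM_AVAILABLE
    if _TINYGEMM_AVAILABLE is None:
        _TINYGEMM_AVAILABLE = False  # assume unavailable unless registration succeeds
        _init_tinygemm()             # flips the flag to True on success
    return _TINYGEMM_AVAILABLE
```

The bf16 fast-path check in `_tinygemm_unquantized_gemm` could then start with `_tinygemm_available()`, and the dispatcher could return `_tinygemm_unquantized_gemm` on supported CUDA platforms and let the function fall back to `torch.nn.functional.linear` when registration fails.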
def _tinygemm_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    num_tokens = x.numel() // x.shape[-1]
    if (
        num_tokens <= 8
        and x.dtype == torch.bfloat16
        and weight.dtype == torch.bfloat16
        and weight.shape[0] % 16 == 0

Member
I would like to see a microbenchmark script to make sure this gemm is good across a range of weight shapes; currently this will apply to most weights, small or large.

        and x.is_contiguous()
        and weight.is_contiguous()
        and (bias is None or bias.dtype == torch.bfloat16)
    ):
        if bias is None:
            bias = torch.zeros(
                weight.shape[0], dtype=torch.bfloat16, device=x.device,
            )

Comment on lines +161 to +164

Member
It seems very suboptimal to be always allocating and filling bias on a low-latency path. I would avoid this.

Collaborator
Seems non-optional. Not sure what a better solution looks like, but agree that filling the bias every time is a pain.

Collaborator
Maybe we can use a pattern like the workspace buffer for FlashInfer attention?

Contributor (Author)
Yes, indeed. I will talk to the FlashInfer kernel author to see if that is something that can be changed.
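If the kernel keeps requiring a bias tensor, one workspace-style option is to cache a zero bias per (device, out_features) and reuse it across calls, so the low-latency path allocates at most once per distinct shape. A rough sketch; the cache and helper name are hypothetical, not part of this PR.

```python
# Hypothetical cache of read-only zero biases, keyed by (device, out_features).
_ZERO_BIAS_CACHE: dict[tuple[torch.device, int], torch.Tensor] = {}


def _get_zero_bias(out_features: int, device: torch.device) -> torch.Tensor:
    """Return a cached bf16 zero bias, allocating only the first time a shape is seen."""
    key = (device, out_features)
    buf = _ZERO_BIAS_CACHE.get(key)
    if buf is None:
        buf = torch.zeros(out_features, dtype=torch.bfloat16, device=device)
        _ZERO_BIAS_CACHE[key] = buf
    return buf
```

The `bias is None` branch above would then become `bias = _get_zero_bias(weight.shape[0], x.device)`; assuming the kernel only reads the bias, sharing one buffer across layers with the same output width should be safe. Making the bias truly optional in the FlashInfer kernel remains the cleaner fix.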
        out_shape = (*x.shape[:-1], weight.shape[0])
        result = torch.ops.vllm.tinygemm_bf16(
            x.view(num_tokens, -1), weight, bias,
        )
        return result.view(out_shape)
    return torch.nn.functional.linear(x, weight, bias)


def default_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,

@@ -304,5 +385,7 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    elif _TINYGEMM_AVAILABLE:
        return _tinygemm_unquantized_gemm
    else:
        return default_unquantized_gemm

Can we do the capability check only for arches that we have validated? I am wary of enabling this on all NVIDIA GPUs >= Hopper without benchmarks on each platform.
I will do microbenchmarking across multiple platforms to make sure we are honest here.
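For reference, a minimal sweep along those lines could look like the sketch below; the shape grid and iteration counts are arbitrary, and it assumes `_init_tinygemm()` has already registered `torch.ops.vllm.tinygemm_bf16` on the machine being benchmarked.

```python
import itertools
import torch


def bench(fn, *args, iters: int = 200, warmup: int = 20) -> float:
    """Mean latency in microseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # ms -> us


token_counts = [1, 2, 4, 8]
# (in_features, out_features) pairs spanning small to very wide projection weights.
shapes = [(1024, 1024), (4096, 4096), (4096, 14336), (8192, 8192), (8192, 28672)]

for m, (k, n) in itertools.product(token_counts, shapes):
    x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    b = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    t_ref = bench(torch.nn.functional.linear, x, w, b)
    t_tiny = bench(torch.ops.vllm.tinygemm_bf16, x, w, b)
    print(f"M={m:>2} K={k:>5} N={n:>5}  linear {t_ref:7.2f} us  tinygemm {t_tiny:7.2f} us")
```

Sweeping token counts 1-8 together with both small and very wide projection shapes should show where the `num_tokens <= 8` heuristic pays off and whether an additional weight-size cutoff is needed.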