Initial prototype of differentiable grouped_scaled_mm function for torchao #1969

Open

wants to merge 61 commits into base: main

Changes from 29 commits

61 commits
134242b
grouped_mm forward pass
danielvegamyhre Mar 26, 2025
2113753
add unit test
danielvegamyhre Mar 26, 2025
0a90f0b
only support float8
danielvegamyhre Mar 26, 2025
a761549
rowwise scaling test passing
danielvegamyhre Mar 27, 2025
8d15a8a
add 3Dx3D test
danielvegamyhre Mar 27, 2025
cced381
numeric unit tests passing
danielvegamyhre Mar 27, 2025
46d7e42
lint
danielvegamyhre Mar 27, 2025
e32d528
update 3Dx3D case
danielvegamyhre Mar 27, 2025
c42af73
lint
danielvegamyhre Mar 27, 2025
e61c71d
lint
danielvegamyhre Mar 27, 2025
94a0cba
change func name
danielvegamyhre Mar 27, 2025
fce469b
lint
danielvegamyhre Mar 27, 2025
3899bb2
B must be 3D
danielvegamyhre Mar 27, 2025
5099838
add docstring
danielvegamyhre Mar 27, 2025
4e04022
allow other axiswise dims so we can pass in 3D B tensor tranposed
danielvegamyhre Mar 27, 2025
4117a9e
clean up
danielvegamyhre Mar 27, 2025
61f0ee4
add todo
danielvegamyhre Mar 27, 2025
dc40622
lint
danielvegamyhre Mar 27, 2025
80b7630
rename var
danielvegamyhre Mar 27, 2025
dc013a3
check input dims are compatible
danielvegamyhre Mar 27, 2025
4f385e5
add detailed comments
danielvegamyhre Mar 27, 2025
72a9b9f
update comments
danielvegamyhre Mar 28, 2025
c4c6c99
update comments
danielvegamyhre Mar 28, 2025
cf42af1
add backward pass
danielvegamyhre Mar 28, 2025
dc6bcf3
add detailed comments
danielvegamyhre Mar 28, 2025
4c5e9db
2d-2d working
danielvegamyhre Mar 28, 2025
c9d30b6
backward working for everything except 2d-3d
danielvegamyhre Mar 28, 2025
c19bc88
all test cases working
danielvegamyhre Mar 28, 2025
90b99ba
docstring
danielvegamyhre Mar 28, 2025
526d88c
update test
danielvegamyhre Mar 28, 2025
25fa1c8
handle jagged 2d tensors
danielvegamyhre Mar 28, 2025
281950c
lint
danielvegamyhre Mar 29, 2025
9f15ac4
work on test for gradients
danielvegamyhre Apr 1, 2025
10a9823
grad is none
danielvegamyhre Apr 1, 2025
f20ddf3
add assert
danielvegamyhre Apr 1, 2025
922b842
grads not none
danielvegamyhre Apr 1, 2025
4b3ca69
outputs mismatch
danielvegamyhre Apr 1, 2025
5d367df
forward matches but _scaled_mm has no backward
danielvegamyhre Apr 1, 2025
7d21bbb
all outputs match
danielvegamyhre Apr 1, 2025
7dc7c73
gradients match
danielvegamyhre Apr 1, 2025
5703cfd
cleanup
danielvegamyhre Apr 1, 2025
6f65dae
use quant primitives manually in forward
danielvegamyhre Apr 1, 2025
93c2692
clean up
danielvegamyhre Apr 1, 2025
d7949c4
lint
danielvegamyhre Apr 1, 2025
212b47f
don't change float8 tensor
danielvegamyhre Apr 1, 2025
4b42be3
lint
danielvegamyhre Apr 1, 2025
fa708fd
lint
danielvegamyhre Apr 1, 2025
fad9d36
fix test
danielvegamyhre Apr 1, 2025
c54b528
improve test readability
danielvegamyhre Apr 1, 2025
b571442
remove old comment
danielvegamyhre Apr 1, 2025
302b554
reorganize
danielvegamyhre Apr 1, 2025
2864068
only support rowwise scaling
danielvegamyhre Apr 1, 2025
fb48868
lint
danielvegamyhre Apr 1, 2025
1cd3658
explicit re-export
danielvegamyhre Apr 1, 2025
c154222
lint
danielvegamyhre Apr 1, 2025
e9f2174
reorganize
danielvegamyhre Apr 1, 2025
4ba8453
add tests for invalid dims
danielvegamyhre Apr 1, 2025
c2e5d42
validate group sizes are multiples of 16
danielvegamyhre Apr 1, 2025
7466ce4
use save_for_backward for offs
danielvegamyhre Apr 1, 2025
527525b
remove group size assert to avoid device-host sync
danielvegamyhre Apr 1, 2025
3ea7455
precompute B_non_transposed_fp8_col_major for backward to save memory
danielvegamyhre Apr 1, 2025
7 changes: 6 additions & 1 deletion torchao/float8/float8_tensor.py
@@ -313,7 +313,12 @@ def __new__(
linear_mm_config if linear_mm_config is not None else LinearMMConfig()
)
self._gemm_input_role = gemm_input_role
-    assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}"
+    assert axiswise_dim in (
+        None,
+        0,
+        -1,
+        -2,
+    ), f"unsupported axiswise_dim {axiswise_dim}"
self._axiswise_dim = axiswise_dim

return self
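The change above extends Float8Tensor to accept axiswise_dim = -2, which the grouped GEMM prototype below needs in order to compute rowwise scales along the second-to-last dimension of the column-major B operand. As a rough illustration, here is a minimal sketch using plain torch with an assumed amax-based scaling convention (not torchao's actual quantization path): reducing a (G, K, N) operand over dim -2 yields one scale per group and column.

import torch

# Minimal sketch with assumed shapes: a stacked operand B of shape (G, K, N).
# "Rowwise" scaling for the column-major side of the GEMM reduces over dim -2,
# producing one scale per (group, column), i.e. shape (G, 1, N).
G, K, N = 4, 32, 64
B = torch.randn(G, K, N, dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
amax = B.abs().amax(dim=-2, keepdim=True).float()   # shape (G, 1, N)
scale = fp8_max / amax                              # shape (G, 1, N)
B_fp8 = (B.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

print(scale.squeeze().shape)  # torch.Size([4, 64]), matching the squeezed B_scale used below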
280 changes: 280 additions & 0 deletions torchao/prototype/grouped_mm/__init__.py
@@ -0,0 +1,280 @@
from typing import Optional

import torch

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_scaling_utils import (
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig


def _grouped_scaled_mm(
Contributor:
nit: can we move this out of __init__.py into its own file?

Contributor Author (@danielvegamyhre, Apr 1, 2025):
Done.

A: torch.Tensor,
B: torch.Tensor,
float8_recipe: Float8LinearRecipeName,
Contributor:
float8_recipe is more for the user-facing UX level (creating a config). For a lower-level function, can we be consistent with torch._scaled_mm and communicate the information we need in the individual arguments?

If we need to communicate "rowwise scaling", I would just mention in the docblock that we are doing rowwise scaling and call it a day for now.

Contributor Author (@danielvegamyhre):
Makes sense, done.

offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
"""
This function performs dynamic float8 quantization on the input tensors A and B using the given recipe,
then performs a scaled grouped GEMM and returns the results.

Args:
A (torch.Tensor): The first input tensor, which can be 2D or 3D.
B (torch.Tensor): The second input tensor, which can be 2D or 3D. Dim -2 of B must match the final dim of A.
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
offs (Optional[torch.Tensor]): The offsets to use to mark the starting index of each group. This
is required when either A or B is 2D; otherwise it should be None.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
use_fast_accum (bool): Whether to use fast accumulation or not. Default is False.
"""
return _Float8GroupedMM.apply(
A,
B,
float8_recipe,
offs,
out_dtype,
use_fast_accum,
)


class _Float8GroupedMM(torch.autograd.Function):
"""Differentiable implementation of grouped GEMM with dynamic float8 quantization."""

@staticmethod
def forward(
ctx,
A: torch.Tensor,
B: torch.Tensor,
float8_recipe_name: Float8LinearRecipeName,
offs: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
# torch._scaled_grouped_mm only supports rowwise scaling currently.
assert (
float8_recipe_name == Float8LinearRecipeName.ROWWISE
), "Only rowwise scaling is supported by torch._scaled_grouped_mm."

assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
assert 2 <= B.ndim <= 3, "B must be 2D or 3D"

# Dim -2 of B must match the final dim of A.
assert A.size(-1) == B.size(
-2
), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"

# offsets are required when either A or B is 2D; otherwise they should be None.
if A.ndim == 2 or B.ndim == 2:
assert offs is not None, "offs must be specified for 2D tensor"

# TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.

# Fetch float8 config from specified recipe name.
float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

# Store what we need for backward.
ctx.save_for_backward(A, B)
ctx.float8_config = float8_config
ctx.offs = offs

# Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
# A shape: (M, K) or (B, M, K)
# A_scale shape: (M,1) or (B, M, 1)
# torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
# A_scale shape: (M,) or (B, M)
A_fp8_row_major = hp_tensor_to_float8_dynamic(
A,
float8_config.cast_config_input.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=float8_config.cast_config_input.scaling_granularity,
axiswise_dim=-1,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
A_scale = A_fp8_row_major._scale.squeeze()

# Convert B to float8, column-major for right operand of grouped GEMM.
# B shape: (K,N) or (B, K, N)
# B scales must be computed rowwise keeping the outer/final dim, so:
# B_scale shape: (1,N) or (B, 1, N)
# torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
# B scale shape: (N,) or (B, N)
B_fp8_col_major = hp_tensor_to_float8_dynamic(
B,
float8_config.cast_config_input.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
axiswise_dim=-2,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
B_scale = B_fp8_col_major._scale.squeeze()

# Special case: for a 2D-2D grouped GEMM, the scales must be repeated along dim 0 once per group,
# where the number of groups is given by the size of the `offs` tensor.
if A.ndim == 2 and B.ndim == 2:
A_scale = A_scale.repeat(offs.numel())
B_scale = B_scale.repeat(offs.numel())

# Perform scaled grouped GEMM and return result.
# output shape: (M, N) or (B, M, N)
return torch._scaled_grouped_mm(
A_fp8_row_major._data,
B_fp8_col_major._data,
A_scale,
B_scale,
offs,
out_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
A, B = ctx.saved_tensors
offs = ctx.offs
float8_config = ctx.float8_config

# Convert grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_A: grad_output @ B
#
# grad_output shape: (M, N) or (B, M, N)
# grad_output_scale shape: (M, 1) or (B, M, 1)
# squeeze grad_output_scale to remove empty dim, as required by torch._scaled_grouped_mm.
# grad_output_scale shape: (M,) or (B, M)
grad_output_fp8_row_major = hp_tensor_to_float8_dynamic(
grad_output,
float8_config.cast_config_grad_output.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
axiswise_dim=-1,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
grad_output_scale = grad_output_fp8_row_major._scale.squeeze()

# Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
# needed for grad_A: grad_output @ B.
# Since B was transposed before entry to forward, we need to transpose it back here for this.
B_non_transposed_col_major = B.contiguous().transpose(-2, -1)

# - B shape: (K,N) or (B, K, N)
# - B scales must be computed rowwise keeping the outer/final dim, so:
# - B_scale shape: (1,N) or (B, 1, N)
# - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
# - B scale shape: (N,) or (B, N)
B_non_transposed_fp8_col_major = hp_tensor_to_float8_dynamic(
B_non_transposed_col_major,
float8_config.cast_config_input.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
axiswise_dim=-2,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
B_scale = B_non_transposed_fp8_col_major._scale.squeeze()

# Compute grad_A.
#
# Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N), output=(B,M,N)
# grad_A = grad_output @ B
# grad_A = (B,M,N) @ (B,N,K) = (B,M,K)
#
# Case 2: A=3D, B=2D with A=(B,M,K), B^T=(K,N) case, output=(B,M,N)
# grad_A = grad_output @ B
# grad_A = (B,M,N) @ (N,K) = (B,M,K)
#
# Case 3: A=3D, B=3D with A=(B,M,K), B^T=(B,K,N) case, output=(B,M,N)
# grad_A = grad_output @ B
# grad_A = (B,M,N) @ (B,N,K) = (B,M,K)
#
# Case 4: A=2D, B=2D with A=(M,K), B^T=(K,N) case, output=(M,N)
# grad_A = grad_output @ B
# grad_A = (M,N) @ (N,K) = (M,K)
grad_A = torch._scaled_grouped_mm(
grad_output_fp8_row_major._data,
B_non_transposed_fp8_col_major._data,
grad_output_scale,
B_scale,
offs,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)

# Convert the transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
grad_output_t = grad_output.transpose(-2, -1)

# - grad_output_t shape: (N, M) or (B, N, M)
# - grad_output_t_scale shape: (N, 1) or (B, N, 1)
# - squeeze grad_output_t_scale to remove empty dim, as required by torch._scaled_grouped_mm.
# - grad_output_t_scale shape: (N,) or (B, N)
grad_output_t_fp8_row_major = hp_tensor_to_float8_dynamic(
grad_output_t,
float8_config.cast_config_grad_output.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
axiswise_dim=-1,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
grad_output_t_scale = grad_output_t_fp8_row_major._scale.squeeze()

# Convert A to float8, column-major for right operand of grouped GEMM:
# needed for grad_B: grad_output_t @ A
#
# - A shape: (M, K) or (B, M, K)
# - A scales must be computed rowwise keeping the outer/final dim, for right operand in grouped GEMM, so:
# - A_scale shape: (1,K) or (B, 1, K)
# - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
# - A scale shape: (K,) or (B, K)
A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
A_fp8_col_major = hp_tensor_to_float8_dynamic(
A_col_major,
float8_config.cast_config_input.target_dtype,
linear_mm_config=LinearMMConfig(),
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=float8_config.cast_config_input.scaling_granularity,
axiswise_dim=-2,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
A_scale = A_fp8_col_major._scale.squeeze()

# Special case: for a 2D-2D grouped GEMM, the scales must be repeated along dim 0 once per group,
# where the number of groups is given by the size of the `offs` tensor.
if grad_output_t_fp8_row_major.ndim == 2 and A_fp8_col_major.ndim == 2:
grad_output_t_scale = grad_output_t_scale.repeat(offs.numel())
A_scale = A_scale.repeat(offs.numel())

# Compute grad_B = grad_output_t @ A.
#
# Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(M,N) <-- special case, B reduced?
# grad_B = grad_output_t @ A
# grad_B = (N,M) @ (M,K) = (N,K) <-- do we need to repeat along dim0 so it's (B,N,K)?
#
# Case 2: A=3D, B=2D with A=(B,M,K), B^T=(K,N) case, output=(B,M,N)
# grad_B = grad_output_t @ A
# grad_B = (B,N,M) @ (B,M,K) = (B,N,K) ----> do we need to reduce along dim0 so it's (N,K)?
#
# Case 3: A=3D, B=3D with A=(B,M,K), B^T=(B,K,N) case, output=(B,M,N)
# grad_B = grad_output_t @ A
# grad_B = (B,N,M) @ (B,M,K) = (B,N,K)
#
# Case 4: A=2D, B=2D with A=(M,K), B^T=(K,N) case, output=(M,N)
# grad_B = grad_output_t @ A
# grad_B = (N,M) @ (M,K) = (N,K)
grad_B = torch._scaled_grouped_mm(
grad_output_t_fp8_row_major._data,
A_fp8_col_major._data,
grad_output_t_scale,
A_scale,
offs,
out_dtype=grad_output.dtype,
use_fast_accum=False,
)
# Since B was transposed before entry to forward, we need to transpose the gradient to match.
grad_B = grad_B.transpose(-2, -1)

return grad_A, grad_B, None, None, None, None
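For orientation, here is a hypothetical end-to-end usage sketch of the new function for the common 2D-A / 3D-B case, targeting the 29-commit state shown above (where the recipe argument is still part of the signature). The shapes, the int32 dtype of `offs`, and the interpretation of `offs` as cumulative group-end boundaries along M follow torch._scaled_grouped_mm conventions and are assumptions; the PR's unit tests are the authoritative reference.

import torch

from torchao.float8.config import Float8LinearRecipeName
from torchao.prototype.grouped_mm import _grouped_scaled_mm

device = "cuda"  # requires a GPU with float8 and grouped GEMM support

# Assumed shapes: M=32 rows routed into 2 groups of 16 (group sizes kept as
# multiples of 16), K=32, N=64. B is passed pre-transposed so that
# A.size(-1) == B.size(-2), as the forward pass asserts.
A = torch.randn(32, 32, dtype=torch.bfloat16, device=device, requires_grad=True)
B = torch.randn(2, 32, 64, dtype=torch.bfloat16, device=device, requires_grad=True)

# Assumption: offs holds the cumulative end index of each group along M.
offs = torch.tensor([16, 32], dtype=torch.int32, device=device)

out = _grouped_scaled_mm(
    A,
    B,
    Float8LinearRecipeName.ROWWISE,
    offs=offs,
    out_dtype=torch.bfloat16,
)
out.sum().backward()
print(out.shape, A.grad.shape, B.grad.shape)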