Add activation sparsity (24 + fp8 dynamic quant) subclass #2213

Open · wants to merge 28 commits into base: main
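This PR replaces the prototype SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig with a Float8DynamicSemiSparseActivationFloat8WeightConfig exposed from torchao.sparsity.sparse_api: weights are quantized to fp8, and activations are dynamically sparsified (2:4) and quantized to fp8 by the CUTLASS sparse24 kernels. A minimal usage sketch assembled from the diff below; the layer size is illustrative, and per the tests an SM90 (H100-class) GPU is required.

import torch
import torch.nn as nn

from torchao.quantization import Float8MMConfig, PerRow, quantize_
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)

# illustrative layer size; the config arguments mirror the benchmark/test changes below
linear = nn.Linear(8192, 1024, bias=False).cuda().to(torch.bfloat16)
quantize_(
    linear,
    Float8DynamicSemiSparseActivationFloat8WeightConfig(
        granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
    ),
)
linear.forward = torch.compile(linear.forward, fullgraph=True)
out = linear(torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda"))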
34 changes: 10 additions & 24 deletions benchmarks/benchmark_e2e_fp8_sparse_linear.py
@@ -9,24 +9,23 @@
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.sparsity.activation.srelu_linear import (
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
)
from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
PerRow,
quantize_,
)
from torchao.sparsity.sparse_api import (
Float8DynamicSemiSparseActivationFloat8WeightConfig,
)


def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
def benchmark(num_tokens, hidden_size=4096, intermediate_size=16384):
ffn_ref = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
@@ -72,25 +71,12 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor)

# fp8 sparse
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)
quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig())
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)

# activation fp8 sparse
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
# no Squared RELU since it will be fused into the second linear
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
@@ -103,9 +89,10 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
),
)
quantize_(
ffn_clone,
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
filter_fn=lambda mod, fqn: "1" in fqn,
ffn_clone[2],
Float8DynamicSemiSparseActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
),
)
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)
@@ -115,7 +102,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
"bf16_latency (us)": fp16_time,
"bf16_c_latency (us)": fp16_c_time,
"fp8_c_time (us)": fp8_c_time,
"fp8_c_sparse_time (us)": fp8_c_sparse_time,
"fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
"speedup": fp8_c_time / fp8_c_activation_sparse_time,
}
@@ -124,7 +110,7 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
if __name__ == "__main__":
with torch.no_grad():
results = []
for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]):
for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]):
results.append(benchmark(num_tokens))
torch.compiler.reset()

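For context on the benchmark change above: the two linears of the squared-ReLU FFN are now quantized differently, with the second linear (which consumes the sparse squared-ReLU output) getting the new semi-sparse-activation config. The call that configures the first linear is cut off in this capture, so the sketch below assumes it uses the dense Float8DynamicActivationFloat8WeightConfig with the same PerRow/fast-accum arguments; module sizes are the benchmark's defaults.

import torch
import torch.nn as nn

from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8MMConfig,
    PerRow,
    quantize_,
)
from torchao.sparsity.sparse_api import (
    Float8DynamicSemiSparseActivationFloat8WeightConfig,
)

hidden_size, intermediate_size = 4096, 16384
ffn = (
    nn.Sequential(
        nn.Linear(hidden_size, intermediate_size, bias=False),
        SquaredReLU(),
        nn.Linear(intermediate_size, hidden_size, bias=False),
    )
    .to(torch.bfloat16)
    .cuda()
)
# first linear: dense fp8 dynamic activation quant (assumed, see note above)
quantize_(
    ffn[0],
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
    ),
)
# second linear: fp8 dynamic quant + 2:4 sparsification of the squared-ReLU output
quantize_(
    ffn[2],
    Float8DynamicSemiSparseActivationFloat8WeightConfig(
        granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
    ),
)
ffn.forward = torch.compile(ffn.forward, fullgraph=True)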
23 changes: 23 additions & 0 deletions benchmarks/benchmark_splitk_sparse_gemv.py
@@ -0,0 +1,23 @@
import torch
import torch.nn.functional as F
from triton.testing import do_bench

from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv
from torchao.sparsity.utils import create_binary_tensor

dtype = torch.bfloat16


for sparsity_level in [0.01, 0.05, 0.1, 0.25, 0.5, 0.8, 0.9, 0.95]:
a = create_binary_tensor((1, 4096), sparsity_level).cuda().to(dtype)
b = torch.randn(16384, 4096).cuda().to(dtype).T.contiguous().T

sparse_time = do_bench(lambda: splitk_sparse_gemv(a, b)) * 1e6

dense_time = (
do_bench(lambda: F.linear(a.to(torch.float16), b.to(torch.float16))) * 1e6
)
speedup = dense_time / sparse_time
print(
f"sparsity_level: {sparsity_level:.2f} | sparse time: {sparse_time:.2f} | dense_time: {dense_time:.2f} | speedup: {speedup:.2f}"
)
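Both this benchmark and the test below build the weight as b.T.contiguous().T, i.e. the same shape and values but column-major strides, since splitk_sparse_gemv expects a column-major weight. A small sketch of what that idiom does (CPU-only, illustrative sizes):

import torch

out_features, in_features = 16384, 4096
w = torch.randn(out_features, in_features)      # row-major: strides (4096, 1)
w_col_major = w.T.contiguous().T                # same shape/values, column-major strides
print(w.shape, w.stride())                      # torch.Size([16384, 4096]) (4096, 1)
print(w_col_major.shape, w_col_major.stride())  # torch.Size([16384, 4096]) (1, 16384)
assert torch.equal(w, w_col_major)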
@@ -1,6 +1,11 @@
import copy
import unittest

import torch
import torch.nn.functional as F
from parameterized import parameterized

from torchao.kernel.splitk_sparse_gemv import splitk_sparse_gemv
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
@@ -9,17 +14,10 @@
quantize_,
)
from torchao.quantization.quant_api import _float8_cutlass_quant

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True

import copy
import unittest

from torchao.prototype.sparsity.activation.srelu_linear import (
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
from torchao.sparsity.sparse_api import (
Float8DynamicSemiSparseActivationFloat8WeightConfig,
)
from torchao.sparsity import sparsify_
from torchao.sparsity.utils import create_semi_structured_tensor
from torchao.sparsity.utils import create_binary_tensor, create_semi_structured_tensor
from torchao.utils import is_sm_at_least_90


@@ -102,8 +100,18 @@ def test_sparse24_sm90_sparsify_srelu(M=512, K=1024, fp8=torch.float8_e4m3fn) ->
assert (A_packed != A_packed_ref).float().mean().item() < 0.1


@parameterized.expand(
[
(1, 8192, 1024, True),
(64, 8192, 1024, True),
(1024, 8192, 1024, True),
(1, 8192, 1024, False),
(64, 8192, 1024, False),
(1024, 8192, 1024, False),
]
)
@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024):
def test_fp8_semi_sparse_activation_linear(M, K, N, do_compile=False):
with torch.no_grad():
torch.manual_seed(0)
input_tensor = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
@@ -116,34 +124,51 @@ def test_srelu_fp8_semi_sparse_activation_linear(M=512, K=2048, N=1024):
quantize_(
reference_linear,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=False)
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
),
)

# define reference implementation
def srelu_linear(x):
x = F.relu(x) ** 2
return reference_linear(x)
if do_compile:
reference_linear.forward = torch.compile(
reference_linear.forward,
fullgraph=True,
)

reference_srelu = torch.compile(srelu_linear, fullgraph=True)

# this only works with fullgraph=True, errors in eager
# TODO figure out exactly why this happens
sparsify_(
quantize_(
reference_linear_copy,
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
)
# (reference_linear_copy)
reference_linear_copy.forward = torch.compile(
reference_linear_copy.forward, fullgraph=True
Float8DynamicSemiSparseActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
),
)

reference_output = reference_srelu(input_tensor)
if do_compile:
reference_linear_copy.forward = torch.compile(
reference_linear_copy.forward, fullgraph=True
)

reference_output = reference_linear(input_tensor)
custom_output = reference_linear_copy(input_tensor)

torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)


@unittest.skipIf(not torch.cuda.is_available(), "Needs cuda to run")
def test_splitk_sparse_gemv():
torch.manual_seed(0)

activation = create_binary_tensor((1, 4096), 0.2).cuda().to(torch.float16)
weight = torch.randn(16384, 4096, dtype=torch.float16).cuda()

# weight must be column major
weight_transposed = weight.T.contiguous().T

sparse_res = splitk_sparse_gemv(activation, weight_transposed)
dense_res = F.linear(activation, weight_transposed)

# This rtol is ridiculously high because the split-K gemv output accumulates slightly differently than the dense output.
torch.testing.assert_close(sparse_res, dense_res, rtol=10, atol=0.1)


@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
def test_sparse24_fp8_sm90_cutlass_gemm_eye(
M=512, K=256, dtype=torch.float8_e4m3fn
@@ -171,7 +196,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye(
# Check MM with scale
b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(
A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
)
assert torch.allclose(
5 changes: 5 additions & 0 deletions torchao/csrc/cuda/activation24/sparse_gemm.cu
@@ -272,6 +272,11 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm(
{cute::get<0>(args.problem_shape), cute::get<1>(args.problem_shape)},
at::TensorOptions().dtype(K::kElementOutAt));

// meta registration
if (kIsMeta) {
return out;
}

args.mainloop.ptr_A =
reinterpret_cast<K::ElementA const*>(tensor_a.data_ptr());
args.mainloop.ptr_B = static_cast<K::ElementB const*>(tensor_b.data_ptr());
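The kIsMeta early return above lets the op run under meta/fake tensors (e.g. during torch.compile tracing) by allocating the output and skipping the kernel launch. A hypothetical, standalone Python analogue of the same idea using torch.library.custom_op; the demo::row_sum op is made up for illustration and is not part of this PR.

import torch

@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    # real kernel: actually computes the result
    return x.sum(dim=-1)

@row_sum.register_fake
def _(x):
    # meta/fake kernel: only reports output shape and dtype, mirroring the
    # kIsMeta early return in the CUDA code above
    return x.new_empty(x.shape[:-1])

# shape propagation works without touching real data
out = row_sum(torch.empty(8, 16, device="meta"))
print(out.shape)  # torch.Size([8])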
12 changes: 11 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -465,7 +465,17 @@ def from_hp_to_floatx(
scale = choose_qparams_affine_float8(
input_float, float8_dtype=target_dtype, block_size=block_size
)
data = quantize_affine_float8(input_float, scale, target_dtype)

# need to import here to avoid circular import
from torchao.dtypes.floatx.cutlass_semi_sparse_layout import (
CutlassSemiSparseLayout,
)

if isinstance(_layout, CutlassSemiSparseLayout):
# handle sparse activation specially, since the sparsification kernel also does the quantization
data = input_float
else:
data = quantize_affine_float8(input_float, scale, target_dtype)
data, scale, zero_point = _layout.post_process(
data, scale, None, block_size
)
6 changes: 6 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -14,6 +14,8 @@
from torchao.dtypes.floatx.cutlass_semi_sparse_layout import (
_linear_fp8_act_fp8_weight_sparse_cutlass_check,
_linear_fp8_act_fp8_weight_sparse_cutlass_impl,
_linear_fp8_act_sparse_fp8_weight_cutlass_check,
_linear_fp8_act_sparse_fp8_weight_cutlass_impl,
)
from torchao.dtypes.floatx.float8_layout import (
_linear_fp8_act_fp8_weight_check,
@@ -191,6 +193,10 @@ def _register_aqt_quantized_linear_dispatches():
_linear_int8_act_int8_weight_semi_structured_sparse_check,
_linear_int8_act_int8_weight_semi_structured_sparse_impl,
),
(
_linear_fp8_act_sparse_fp8_weight_cutlass_check,
_linear_fp8_act_sparse_fp8_weight_cutlass_impl,
),
(
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
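The new (check, impl) pair registered above plugs into the dispatch table that routes quantized linear calls to specialized kernels. A simplified illustration of that pattern (names invented for exposition, not the actual torchao internals): each kernel registers a predicate plus an implementation, and the first pair whose check passes handles the call.

# Simplified sketch of the (check, impl) dispatch pattern; not torchao internals.
_LINEAR_DISPATCH = []

def register_linear_dispatch(check_fn, impl_fn):
    _LINEAR_DISPATCH.append((check_fn, impl_fn))

def dispatch_linear(input_tensor, weight_tensor, bias):
    # first registered kernel whose check passes handles the call
    for check_fn, impl_fn in _LINEAR_DISPATCH:
        if check_fn(input_tensor, weight_tensor, bias):
            return impl_fn(input_tensor, weight_tensor, bias)
    raise NotImplementedError(
        "no specialized kernel matched; fall back to dequantize + F.linear"
    )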