Hi, I tested a modified version of the sample code from the tutorial to check the performance gain and accuracy of SemiSparseLinear. I found that SemiSparseLinear produces wrong results and is much slower than torch.nn.Linear on an H100 GPU. The test code is attached below. Is there anything I did incorrectly here?
```python
import torch
# from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer

# SparseSemiStructuredTensor._FORCE_CUTLASS = False
# Modification: use SemiSparseLinear from torchao.
from torchao.sparsity.training import (
    SemiSparseLinear,
)

# Problem scale
in_f = 10240
out_f = 3072

# Mask the Linear weight to be 2:4 sparse.
# Modification: torchao's SemiSparseLinear jointly sparsifies the weight and its
# transpose, so we construct the mask from 4x4 blocks that are 2:4 in both directions.
mask = torch.Tensor([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0]]).tile((out_f // 4, in_f // 4)).cuda().bool()
linear = torch.nn.Linear(in_f, out_f).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)
x = torch.rand(out_f, in_f).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
dense_t = Timer(stmt="linear(x)",
                globals={"linear": linear,
                         "x": x}).blocked_autorange().median * 1e3

# Error when accelerating via SparseSemiStructuredTensor:
# RuntimeError: sparse_semi_structured_mad_op : Supported only on GPUs with compute capability 8.x
# linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
# Modification: use SemiSparseLinear from torchao instead.
linear_sparse = SemiSparseLinear.from_dense(linear)
# The sparsification is dynamic in SemiSparseLinear's forward, so the weight is shared with linear.
assert id(linear_sparse.weight) == id(linear.weight)
sparse_output = linear_sparse(x)
sparse_t = Timer(stmt="linear_sparse(x)",
                 globals={"linear_sparse": linear_sparse,
                          "x": x}).blocked_autorange().median * 1e3

print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

abs_diff = torch.abs(sparse_output - dense_output)
max_error = torch.max(abs_diff)
max_error_index = torch.argmax(abs_diff)
max_error_coords = torch.unravel_index(max_error_index, sparse_output.shape)
print(f"Max error: {max_error.item()} at index {max_error_coords}")

# Sparse and dense matmul should be numerically equivalent.
assert torch.allclose(sparse_output, dense_output, atol=1e-3)
```
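
For completeness, here is a quick sanity check (my own addition, not from the tutorial) that the 4x4 block mask above really is 2:4 sparse along both rows and columns, so the joint sparsification of the weight and its transpose in SemiSparseLinear should be able to represent it exactly:

```python
import torch

in_f, out_f = 10240, 3072
mask = torch.Tensor([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0]]).tile((out_f // 4, in_f // 4)).bool()

# Each aligned group of 4 consecutive elements must contain exactly 2 nonzeros,
# both along rows (for W) and along columns (for W^T).
row_groups = mask.reshape(out_f, in_f // 4, 4).sum(dim=-1)
col_groups = mask.t().reshape(in_f, out_f // 4, 4).sum(dim=-1)
assert (row_groups == 2).all() and (col_groups == 2).all()
print("mask is 2:4 sparse in both directions")
```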
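
For reference, the one-shot inference path from the tutorial that I originally tried (and that raised the compute-capability error quoted in the comments above) looks like the sketch below. It uses the stock torch.sparse API rather than torchao; setting _FORCE_CUTLASS = False requests the cuSPARSELt backend instead of CUTLASS, and whether that backend is actually available in my build on H100 is part of what I am unsure about:

```python
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Request the cuSPARSELt backend; the CUTLASS kernels raised the
# "compute capability 8.x" error above on H100 (sm_90).
SparseSemiStructuredTensor._FORCE_CUTLASS = False

in_f, out_f = 10240, 3072
mask = torch.Tensor([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0]]).tile((out_f // 4, in_f // 4)).cuda().bool()
linear = torch.nn.Linear(in_f, out_f).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)
x = torch.rand(out_f, in_f).half().cuda()

# One-time compression of the pruned weight into the 2:4 format,
# instead of re-sparsifying on every forward as SemiSparseLinear does.
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
with torch.inference_mode():
    sparse_output = linear(x)
```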