Wrong result and no speedup with SemiSparseLinear from Torchao compared to torch.nn.Linear #1617

Open
@lin-ht

Description

Hi, I tested a modified version of the sample code from the tutorial to check the performance gain and accuracy of SemiSparseLinear. I found that SemiSparseLinear produces wrong results and is much slower than torch.nn.Linear on an H100 GPU. The test code is below. Is there anything I did incorrectly here?

```python
import torch
# from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
# SparseSemiStructuredTensor._FORCE_CUTLASS = False
# Modification: use SemiSparseLinear from torchao.
from torchao.sparsity.training import (
    SemiSparseLinear,
)

# Problem scale
in_f = 10240
out_f = 3072

# Mask the Linear weight to be 2:4 sparse.
# Modification: torchao's SemiSparseLinear jointly sparsifies A and A' (the weight and its
# transpose), so the mask is built from 4x4 tiles that are 2:4 sparse along both rows and columns.
mask = (
    torch.Tensor([[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]])
    .tile((out_f // 4, in_f // 4))
    .cuda()
    .bool()
)
linear = torch.nn.Linear(in_f, out_f).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(out_f, in_f).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # Accelerating via SparseSemiStructuredTensor fails on H100 with:
    # RuntimeError: sparse_semi_structured_mad_op : Supported only on GPUs with compute capability 8.x
    # linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
    # Modification: use SemiSparseLinear from torchao instead.
    linear_sparse = SemiSparseLinear.from_dense(linear)
    # SemiSparseLinear sparsifies the weight dynamically in forward(), so it shares
    # the same weight Parameter as linear.
    assert linear_sparse.weight is linear.weight

    sparse_output = linear_sparse(x)
    sparse_t = Timer(stmt="linear_sparse(x)",
                     globals={"linear_sparse": linear_sparse,
                              "x": x}).blocked_autorange().median * 1e3

    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

    abs_diff = torch.abs(sparse_output - dense_output)
    max_error = torch.max(abs_diff)
    max_error_index = torch.argmax(abs_diff)
    max_error_coords = torch.unravel_index(max_error_index, sparse_output.shape)
    print(f"Max error: {max_error.item()} at index {max_error_coords}")

    # Expect sparse and dense matmul to be numerically equivalent.
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
```
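The mask construction relies on each 4x4 tile being 2:4 sparse both row-wise and column-wise, so that both the weight and its transpose satisfy the pattern. A minimal pure-Python sketch (no GPU or torchao needed; `tile_mask`, `is_2_4` and `transpose` are hypothetical helpers) that checks this property for the same tile at a small size:

```python
# The 4x4 tile used in the mask above (1 = keep, 0 = prune).
TILE = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 0, 0],
]


def tile_mask(rows, cols):
    """Repeat TILE into a rows x cols mask (both multiples of 4),
    mirroring torch.Tensor(TILE).tile((rows // 4, cols // 4))."""
    assert rows % 4 == 0 and cols % 4 == 0
    return [[TILE[r % 4][c % 4] for c in range(cols)] for r in range(rows)]


def is_2_4(matrix):
    """True if every contiguous group of 4 entries in each row has at most
    2 nonzeros (the 2:4 semi-structured pattern along rows)."""
    return all(
        sum(1 for v in row[g:g + 4] if v != 0) <= 2
        for row in matrix
        for g in range(0, len(row), 4)
    )


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


mask = tile_mask(8, 12)
print("rows 2:4:", is_2_4(mask))             # pattern seen by W
print("cols 2:4:", is_2_4(transpose(mask)))  # pattern seen by W^T
```

Because the same tile repeats in both directions, every row and every column alternates pairs of kept and pruned entries, so the check passes for any multiple-of-4 shape, including 3072 x 10240.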
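Because the weight is already masked to 2:4 before `from_dense` is called, a magnitude-based 2:4 pruning step applied in `forward` would be expected to leave it unchanged, which is why the dense and sparse outputs should differ only by fp16 accumulation order. A minimal pure-Python sketch of that expectation, assuming a simple row-wise keep-the-two-largest rule (`prune_2_4_rowwise` and `dot` are hypothetical helpers, and torchao's tile-based sparsification may choose entries differently):

```python
import random


def prune_2_4_rowwise(row):
    """Keep the 2 largest-magnitude entries in each group of 4 and zero the rest.
    Simplified 1D magnitude rule for illustration only; not torchao's algorithm."""
    out = list(row)
    for g in range(0, len(row), 4):
        group = range(g, min(g + 4, len(row)))
        # indices of the two entries with the largest absolute value in this group
        keep = sorted(group, key=lambda i: abs(row[i]), reverse=True)[:2]
        for i in group:
            if i not in keep:
                out[i] = 0.0
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


random.seed(0)
in_f = 16
# One weight row already masked with the row pattern [0, 0, 1, 1] from the tile above.
pattern = [0, 0, 1, 1] * (in_f // 4)
w = [m * random.uniform(-1.0, 1.0) for m in pattern]
x = [random.random() for _ in range(in_f)]

w_pruned = prune_2_4_rowwise(w)
print("pruning changed weights:", w_pruned != w)
print("dense vs. re-pruned output diff:", abs(dot(w, x) - dot(w_pruned, x)))
```

The masked zeros always have the smallest magnitude in their group, so re-pruning selects exactly the surviving entries and the product is unchanged.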
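Per the comment in the script, SemiSparseLinear re-sparsifies its weight inside every forward call, so `sparse_t` times sparsification plus the sparse matmul while `dense_t` times only the dense matmul. A pure-Python sketch of that difference in what is being timed, using `timeit` and a hypothetical per-call pruning step `sparsify_2_4` (this does not model GPU kernels or cuSPARSELt; it only shows that any work done inside the timed callable is counted):

```python
import random
import timeit


def sparsify_2_4(row):
    """Hypothetical stand-in for a per-call sparsification step:
    keep the 2 largest-magnitude entries of every group of 4."""
    out = [0.0] * len(row)
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        top2 = sorted(range(len(group)), key=lambda i: abs(group[i]), reverse=True)[:2]
        for i in top2:
            out[g + i] = group[i]
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


random.seed(0)
w = [random.uniform(-1.0, 1.0) for _ in range(1024)]
x = [random.random() for _ in range(1024)]
w_pre = sparsify_2_4(w)  # sparsified once, outside the timed region

n = 200
# Timed callable includes the sparsification step (like a dynamic forward).
t_dynamic = timeit.timeit(lambda: dot(sparsify_2_4(w), x), number=n) / n
# Timed callable only does the product on pre-sparsified weights.
t_static = timeit.timeit(lambda: dot(w_pre, x), number=n) / n
print(f"sparsify every call: {t_dynamic * 1e6:.1f} us  sparsify once: {t_static * 1e6:.1f} us")
```

Both callables compute the same result; only the placement of the sparsification step differs, which is the distinction to keep in mind when comparing `dense_t` and `sparse_t`.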
