
int4_weight_only Slows Down torch.nn.Linear for Llama2 7B Shapes #1606

Open
@mostafaelhoushi

Description

I have created a small script to benchmark int4 quantization on A100 GPUs, with inputs that have batch size 1 and seqlen 1.
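Roughly, the benchmark looks like the sketch below (simplified, not my exact script; it assumes torchao's `quantize_` / `int4_weight_only` API, times each linear with CUDA events, and runs eagerly without torch.compile):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

@torch.no_grad()
def bench(linear, x, n_warmup=10, n_iters=100):
    # Warm up, then time with CUDA events; returns average ms per forward call.
    for _ in range(n_warmup):
        linear(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        linear(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters

for input_dim, output_dim in [(4096, 4096), (4096, 11008), (11008, 4096), (11008, 11008)]:
    print(f"# input_dim, output_dim = {input_dim}, {output_dim}")
    # Batch size 1, seqlen 1 input, as in decode-time inference.
    x = torch.randn(1, 1, input_dim, dtype=torch.bfloat16, device="cuda")

    baseline = torch.nn.Linear(input_dim, output_dim, bias=False,
                               dtype=torch.bfloat16, device="cuda")
    print(f"Baseline:\t{bench(baseline, x)} ms")

    quantized = torch.nn.Linear(input_dim, output_dim, bias=False,
                                dtype=torch.bfloat16, device="cuda")
    # Apply torchao int4 weight-only quantization in place.
    quantize_(quantized, int4_weight_only())
    print(f"Quantized:\t{bench(quantized, x)} ms")
```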

When I test weight shapes that exist in Llama2 7B, I actually get a slowdown:

# input_dim, output_dim = 4096, 4096
Baseline:       0.023313920497894287 ms
Quantized:      0.08300095558166504 ms
# input_dim, output_dim = 4096, 11008
Baseline:       0.06082496166229248 ms
Quantized:      0.08460960388183594 ms
# input_dim, output_dim = 11008, 4096
Baseline:       0.059748477935791015 ms
Quantized:      0.09495231628417969 ms

When I use a really large shape that doesn't exist in Llama2 7B, I do get some speedup:

# input_dim, output_dim = 11008, 11008
Baseline:       0.14746272087097168 ms
Quantized:      0.09298111915588379 ms

This is strange because gpt-fast uses a similar int4 weight-only quantization and gets a 2x speedup on Llama2 7B.
