
[ARM] Support 8bit/4bit weights decompression for Matmul primitive  #2081

Open
@dmitry-gorokhov

Description


Problem statement

LLM workloads optimized for best latency are memory bound: inference speed is limited by access to the model weights through DDR. That is why the major optimization technique is weights compression (4-bit weight compression may bring up to 4x better latency compared to bf16/fp16 weights).

Preferred solution

oneDNN has already extended the x64 brgemm Matmul primitive (8-bit, 4-bit) to support the following decompression math:

  1. Decompress a block of weights into a temporary buffer (via brgemm_matmul_copy_b): w_fp = (w_compressed - zp) * scale (a reference sketch of this math is shown after this list).
  2. Call the regular fp Matmul on the decompressed weight block.
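
For clarity, below is a minimal scalar sketch of that two-step scheme for the int8, per-output-channel case. The function names (`decompress_block`, `matmul_f32`), the row-major layout, and the per-output-channel scale/zero-point indexing are assumptions for illustration only; the actual decompression is done block-by-block inside the brgemm copy routines.

```cpp
#include <cstdint>

// Step 1: decompress a K x N block of int8 weights into a temporary fp32
// buffer using w_fp = (w_compressed - zp) * scale, with one scale and one
// zero point per output channel n (illustrative layout, not oneDNN's).
static void decompress_block(const int8_t *w_compressed, float *w_fp,
                             const float *scale, const int8_t *zp,
                             int K, int N) {
    for (int k = 0; k < K; ++k)
        for (int n = 0; n < N; ++n)
            w_fp[k * N + n]
                    = (float(w_compressed[k * N + n]) - float(zp[n])) * scale[n];
}

// Step 2: run a regular fp32 matmul on the decompressed weight block.
static void matmul_f32(const float *src, const float *w_fp, float *dst,
                       int M, int K, int N) {
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
                acc += src[m * K + k] * w_fp[k * N + n];
            dst[m * N + n] = acc;
        }
}
```

In the real primitive the decompression happens per weight block inside brgemm_matmul_copy_b, so only a small temporary fp buffer is needed rather than a full dequantized copy of the weights.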

Since a floating-point brgemm Matmul is already implemented for aarch64 (at least with SVE), the proposal is to extend it to support compressed weights (in the same way it is done for x64).

The request is to support the following options:

  1. i4/u4/i8/u8 weight input + fp32/fp16/bf16 activations.
  2. An additional input for scales (per-output-channel values for int8, grouped for int4); data type: fp32/fp16. See the grouped-dequantization sketch after this list.
  3. An optional zero-point value (per-output-channel values for int8, grouped for int4). It can have the same element type as the weights, but it can also be converted to fp32/fp16 if the implementation prefers that.
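
To make the grouped int4 case concrete, here is an illustrative scalar sketch of u4 dequantization with group-wise scales and zero points along the K dimension. The nibble packing order (low nibble = even flat index), the scale/zero-point layout (one value per group and output channel), and the function name are assumptions for illustration only and do not describe oneDNN's actual packing.

```cpp
#include <cstddef>
#include <cstdint>

// Two 4-bit weights are packed per byte; each group of `group_size`
// consecutive K values shares one scale and one zero point per output
// channel n. The math is the same as the int8 case: (q - zp) * scale.
static void decompress_u4_grouped(const uint8_t *w_packed, float *w_fp,
                                  const float *scales, const uint8_t *zps,
                                  int K, int N, int group_size) {
    for (int k = 0; k < K; ++k) {
        const int g = k / group_size; // group index along K
        for (int n = 0; n < N; ++n) {
            const size_t idx = size_t(k) * N + n;
            const uint8_t byte = w_packed[idx / 2];
            const uint8_t q = (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
            const float scale = scales[size_t(g) * N + n];
            const float zp = float(zps[size_t(g) * N + n]);
            w_fp[idx] = (float(q) - zp) * scale;
        }
    }
}
```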
