
Modifying RNN Quantization for bitwidth lower than 8-bit #1041

Open
@JiaMingLin


Hi,

I've tested QuantLSTM on the EMG dataset with the following bit-width settings:

weight_quant: 8
io_quant: 2
sigmoid: 8
tanh_quant: 8
cell_state_quant: 8
accumulation_quant: 16

However, the quantized model failed to learn: classification accuracy on the 8-category task is only 12.5%, which is chance level (1/8).
To investigate, I reviewed the code and visualized the quantization process in the figure below:
[Figure: quantization process in the current QuantLSTM (screenshot, 2024-10-03 11:46:34 PM)]
I noticed that, with io_quant set to 2 bits, the output of the final LSTM layer, which feeds the fully connected (FC) classifier, is also quantized to 2 bits. In my experience, models with extremely low bit widths generally need a higher bit width at the final classifier layer (the FC layer), so this bottleneck could explain the poor performance.
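To make the 2-bit bottleneck concrete, here is a minimal pure-Python sketch of a signed, zero-point-0, round-to-nearest fake quantizer with a per-tensor scale, similar in spirit to the activation quantizers defined at the end of this post (narrow_range = False). The helper names, the sample hidden-state values and the fixed scales are hypothetical; in Brevitas the scale would come from percentile statistics, which this sketch skips.

```python
# Sketch (hypothetical helpers, not Brevitas code): signed fake quantization
# with a fixed per-tensor scale, zero point 0, and round-half-to-even
# (Python's round, which matches torch.round).

def int_range(bit_width, narrow_range=False):
    """Signed integer grid bounds, e.g. 8-bit -> (-128, 127) or (-127, 127)."""
    qmax = 2 ** (bit_width - 1) - 1
    qmin = -qmax if narrow_range else -(2 ** (bit_width - 1))
    return qmin, qmax


def fake_quantize(xs, scale, bit_width, narrow_range=False):
    """Quantize each value to the integer grid, then dequantize back to float."""
    qmin, qmax = int_range(bit_width, narrow_range)
    out = []
    for x in xs:
        q = round(x / scale)          # round to nearest integer code
        q = min(max(q, qmin), qmax)   # clamp to the representable range
        out.append(q * scale)         # dequantize
    return out


# Hypothetical last-layer hidden states (h = o * tanh(c) lies in (-1, 1)).
h_last = [-0.9, -0.45, -0.1, 0.05, 0.3, 0.62, 0.95]

h_2bit = fake_quantize(h_last, scale=0.5, bit_width=2)
h_8bit = fake_quantize(h_last, scale=1 / 127, bit_width=8)

print(int_range(2))      # (-2, 1): only 4 integer codes
print(h_2bit)            # [-1.0, -0.5, 0.0, 0.0, 0.5, 0.5, 0.5]
print(len(set(h_2bit)))  # 4
print(max(abs(a - b) for a, b in zip(h_last, h_2bit)))  # worst-case 2-bit error
print(max(abs(a - b) for a, b in zip(h_last, h_8bit)))  # worst-case 8-bit error
```

Whatever scale is learned, a signed 2-bit quantizer can pass at most 4 distinct values per feature to the FC layer, versus up to 256 at 8 bits.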

I propose the following modification, shown in the figure below:
[Figure: proposed modified quantization scheme (screenshot, 2024-10-04 12:02:35 AM)]

  • Quantize the input for every LSTM layer using the $q_x$ quantizer (previously, only the first layer's input was quantized).
  • Quantize the input hidden state for each LSTM layer using the $q_h$ quantizer (previously, this was done by output_quant).
  • Remove output_quant from the output hidden state.
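The three changes can be sketched as a toy multi-layer LSTM forward pass in plain Python with pluggable quantizers. This is an illustration under stated assumptions, not Brevitas's QuantLSTM: the function names, the dict-of-lists weight layout, the random weights and the fixed-scale 2-bit stand-in used for both $q_x$ and $q_h$ are all hypothetical, and the gate, cell-state and accumulator quantizers are omitted.

```python
import math
import random


def matvec(W, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gate(p, name, x, h, act):
    """Element-wise act(W_name @ x + U_name @ h + b_name)."""
    wx = matvec(p["W" + name], x)
    uh = matvec(p["U" + name], h)
    return [act(a + b + c) for a, b, c in zip(wx, uh, p["b" + name])]


def lstm_cell(x, h, c, p):
    """Standard (unquantized) LSTM cell update."""
    i = gate(p, "i", x, h, sigmoid)
    f = gate(p, "f", x, h, sigmoid)
    g = gate(p, "g", x, h, math.tanh)
    o = gate(p, "o", x, h, sigmoid)
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


def make_layer(in_size, hidden_size, rng):
    """Hypothetical random weights and biases for gates i, f, g, o."""
    u = lambda: rng.uniform(-0.5, 0.5)
    p = {}
    for name in "ifgo":
        p["W" + name] = [[u() for _ in range(in_size)] for _ in range(hidden_size)]
        p["U" + name] = [[u() for _ in range(hidden_size)] for _ in range(hidden_size)]
        p["b" + name] = [u() for _ in range(hidden_size)]
    return p


def proposed_forward(xs, layers, q_x, q_h, hidden_size):
    """Multi-layer LSTM with the proposed quantizer placement."""
    seq = xs
    for p in layers:
        h = [0.0] * hidden_size
        c = [0.0] * hidden_size
        outs = []
        for x in seq:
            # q_x on every layer's input; q_h on the incoming hidden state.
            h, c = lstm_cell(q_x(x), q_h(h), c, p)
            outs.append(h)  # the output hidden state itself is NOT quantized
        seq = outs
    return seq


def q2(v, scale=0.5):
    """Hypothetical 2-bit signed stand-in quantizer: integer grid [-2, 1]."""
    return [min(max(round(x / scale), -2), 1) * scale for x in v]


rng = random.Random(0)
layers = [make_layer(3, 4, rng), make_layer(4, 4, rng)]
xs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
out = proposed_forward(xs, layers, q_x=q2, q_h=q2, hidden_size=4)
print(len(out), len(out[0]))  # 5 4
```

Because `h` is kept unquantized and only passes through `q_h` when it is fed back (or through the next layer's `q_x`), the last layer hands full-precision hidden states to the FC classifier.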

With this change, some quick tests show the new scheme reaching 53% accuracy, much higher than the previous 12.5%.

Next, I plan to run more experiments on datasets such as PTB and to explore other LSTM and RNN variants. Any feedback or suggestions are welcome, and I'll update this thread with the results.

P.S. The weight and activation quantizers are defined as follows:

```python
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject.enum import (
    BitWidthImplType,
    FloatToIntImplType,
    QuantType,
    RestrictValueType,
    ScalingImplType,
    StatsOp,
)
from brevitas.quant.solver import ActQuantSolver, WeightQuantSolver


class Int8WeightPerTensorFloatScratch(WeightQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.STATS # scale based on statistics
    scaling_stats_op = StatsOp.MAX # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.FP # scale factor is a floating point value
    scaling_per_output_channel = False # scale is per tensor
    bit_width = 8 # bit width is 8
    signed = True # quantization range is signed
    narrow_range = True # quantization range is [-127, 127] rather than [-128, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.
    scaling_min_val = 1e-10 # minimum value for the scale factor


class Int8ActPerTensorFloatScratch(ActQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.STATS # scale based on runtime statistics (PARAMETER_FROM_STATS would make it a learned parameter initialized from statistics)
    scaling_stats_op = StatsOp.PERCENTILE # scale statistics is a percentile of the abs value
    high_percentile_q = 99.999 # percentile is 99.999
    collect_stats_steps = 300  # stats collection steps before switching to a learned parameter (used with PARAMETER_FROM_STATS)
    restrict_scaling_type = RestrictValueType.FP # scale is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True # quantization range is signed
    narrow_range = False # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.
    scaling_min_val = 1e-10 # minimum value for the scale factor


class Int2ActPerTensorFloatScratch(Int8ActPerTensorFloatScratch):
    bit_width = 2  # only the bit width changes; narrow_range stays False, so the range is [-2, 1]
```
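For reference, the per-tensor absmax scaling selected by `scaling_stats_op = StatsOp.MAX` together with `narrow_range = True` in the weight quantizer above can be sketched as follows. This is a simplified illustration with hypothetical helper names and made-up weights, not Brevitas's actual code path.

```python
# Sketch (hypothetical helpers): per-tensor absmax weight quantization
# on a narrow signed grid [-qmax, qmax], zero point 0, round to nearest.

def absmax_scale(weights, bit_width=8):
    """Per-tensor scale so that max|w| lands exactly on the largest code."""
    qmax = 2 ** (bit_width - 1) - 1  # 127 for 8 bits
    return max(abs(w) for w in weights) / qmax


def quantize_weights(weights, scale, bit_width=8):
    """Round to nearest and clamp to the narrow signed range [-qmax, qmax]."""
    qmax = 2 ** (bit_width - 1) - 1
    return [min(max(round(w / scale), -qmax), qmax) for w in weights]


weights = [0.02, -0.31, 0.127, -0.5, 0.26]  # made-up per-tensor weights
scale = absmax_scale(weights)               # 0.5 / 127
w_int = quantize_weights(weights, scale)
w_deq = [q * scale for q in w_int]          # dequantized weights

print(w_int)  # [5, -79, 32, -127, 66]
```

With a narrow range the grid is symmetric, so the largest-magnitude weight maps to -127 or +127 with equal resolution on both sides; under absmax scaling the extra code -128 of a non-narrow grid would simply go unused.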

Labels: enhancement (New feature or request)
