Description
Hi,
I've tested QuantLSTM on the EMG dataset with the following bit-width settings:
- `weight_quant`: 8
- `io_quant`: 2
- `sigmoid`: 8
- `tanh_quant`: 8
- `cell_state_quant`: 8
- `accumulation_quant`: 16
However, the quantized model failed to learn: classification accuracy on the 8-class task is only 12.5%, which is chance level (1/8).
After reviewing the code and visualizing the quantization process (see the figure below), I noticed that with `io_quant` set to 2 bits, the output of the final LSTM layer, which feeds the fully connected (FC) layer, is also quantized to 2 bits. In my experience, extremely low-bit-width models generally need a higher bit width at the final classifier layer (i.e., the FC layer), which could explain the poor performance.
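To make the 2-bit bottleneck concrete, here is a minimal NumPy sketch of symmetric fake quantization with a zero point of 0. It assumes the signed, `narrow_range = False` settings of the `Int2ActPerTensorFloatScratch` quantizer listed at the end of this post and a hypothetical fixed scale of 0.5 (in Brevitas the scale would come from percentile statistics). `int_bounds` and `fake_quant` are illustrative helpers, not Brevitas functions. Because an LSTM hidden state `o * tanh(c)` lies in (-1, 1), the FC layer can see at most four distinct values per feature at 2 bits.

```python
import numpy as np


def int_bounds(bit_width, signed=True, narrow_range=False):
    """Integer range of a zero-point-0 quantizer (illustrative helper)."""
    if signed:
        qmax = 2 ** (bit_width - 1) - 1
        qmin = -qmax if narrow_range else -(2 ** (bit_width - 1))
    else:
        qmin, qmax = 0, 2 ** bit_width - 1
    return qmin, qmax


def fake_quant(x, bit_width, scale, signed=True, narrow_range=False):
    """Round to nearest, clamp to the integer range, then rescale to float."""
    qmin, qmax = int_bounds(bit_width, signed, narrow_range)
    return np.clip(np.round(x / scale), qmin, qmax) * scale


# LSTM hidden states are bounded in (-1, 1); sample that range densely.
h = np.linspace(-0.99, 0.99, 1001)

h_2bit = fake_quant(h, bit_width=2, scale=0.5)      # what the FC layer sees with io_quant = 2
h_8bit = fake_quant(h, bit_width=8, scale=1 / 127)  # same signal through an 8-bit quantizer

print(np.unique(h_2bit))       # the only values that can reach the classifier
print(len(np.unique(h_8bit)))  # far more distinct levels at 8 bits
```

The 2-bit grid is {-2, -1, 0, 1} times the scale, so every hidden feature entering the classifier collapses to one of four values, whatever the scale is.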
I propose the following modification, shown in the figure below:
- Quantize the input to every LSTM layer with the $q_x$ quantizer (previously, only the first layer's input was quantized).
- Quantize the incoming hidden state of each LSTM layer with the $q_h$ quantizer (previously, this was done by `output_quant`).
- Remove `output_quant` from the output hidden state.
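The proposed placement can be sketched in plain NumPy for a stacked LSTM. This is a simplified float model, assuming PyTorch's (i, f, g, o) gate ordering and omitting the sigmoid, tanh, cell-state and accumulator quantizers; `fake_quant`, `lstm_layer` and `stacked_lstm` are hypothetical helpers, not Brevitas code. What it illustrates is where the quantizers sit: $q_x$ on every layer's input, $q_h$ on the recurrent hidden state, and nothing on the hidden state that leaves the layer, so the last layer's output reaches the FC layer at full precision.

```python
import numpy as np


def fake_quant(x, bit_width=2, scale=0.5):
    """Signed, zero-point-0 fake quantization onto [-2^(b-1), 2^(b-1) - 1] * scale."""
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    return np.clip(np.round(x / scale), qmin, qmax) * scale


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_layer(x_seq, W, U, b, q_x, q_h):
    """One LSTM layer over x_seq of shape (T, in_dim); returns (T, H)."""
    H = U.shape[1]
    h, c = np.zeros(H), np.zeros(H)
    outputs = []
    for x_t in x_seq:
        x_q = q_x(x_t)  # proposed: quantize the input of every layer with q_x
        h_q = q_h(h)    # proposed: quantize the incoming hidden state with q_h
        z = W @ x_q + U @ h_q + b
        i, f = sigmoid(z[:H]), sigmoid(z[H:2 * H])
        g, o = np.tanh(z[2 * H:3 * H]), sigmoid(z[3 * H:])
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)  # proposed: no output_quant on the emitted hidden state
    return np.stack(outputs)


def stacked_lstm(x_seq, layers, q_x, q_h):
    """Feed each layer's unquantized output sequence into the next layer."""
    out = x_seq
    for W, U, b in layers:
        out = lstm_layer(out, W, U, b, q_x, q_h)
    return out


rng = np.random.default_rng(0)
T, in_dim, H, num_layers = 20, 8, 16, 2
layers = []
for layer_idx in range(num_layers):
    d = in_dim if layer_idx == 0 else H
    layers.append((rng.normal(0, 0.5, (4 * H, d)),
                   rng.normal(0, 0.5, (4 * H, H)),
                   np.zeros(4 * H)))

x = rng.normal(size=(T, in_dim))
features = stacked_lstm(x, layers, q_x=fake_quant, q_h=fake_quant)[-1]  # last step -> FC layer
print(features.shape, len(np.unique(features)))
```

Under the previous scheme, `features` would have passed through a 2-bit output quantizer first and collapsed to at most four values per element; here the classifier receives the full-precision hidden state while every LSTM input remains quantized.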
In some quick tests with this change, the new scheme reached 53% accuracy, much higher than the previous 12.5%.
Next, I plan to run more experiments on datasets like PTB and explore different LSTM and RNN variations. Any feedback or suggestions are welcome, and I'll update this thread with results.
P.S. The weight and activation quantizers are defined as follows:
```python
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject.enum import (BitWidthImplType, FloatToIntImplType, QuantType,
                                  RestrictValueType, ScalingImplType, StatsOp)
from brevitas.quant.solver import ActQuantSolver, WeightQuantSolver


class Int8WeightPerTensorFloatScratch(WeightQuantSolver):
    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    scaling_impl_type = ScalingImplType.STATS  # scale based on statistics
    scaling_stats_op = StatsOp.MAX  # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.FP  # scale factor is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True  # quantization range is signed
    narrow_range = True  # quantization range is [-127, 127] rather than [-128, 127]
    zero_point_impl = ZeroZeroPoint  # zero point is 0
    scaling_min_val = 1e-10  # minimum value for the scale factor


class Int8ActPerTensorFloatScratch(ActQuantSolver):
    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND  # round to nearest
    scaling_impl_type = ScalingImplType.STATS  # scale is a parameter initialized from statistics
    scaling_stats_op = StatsOp.PERCENTILE  # scale statistics is a percentile of the abs value
    high_percentile_q = 99.999  # percentile is 99.999
    collect_stats_steps = 300  # statistics are collected for 300 forward steps before switching to a learned parameter
    restrict_scaling_type = RestrictValueType.FP  # scale is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True  # quantization range is signed
    narrow_range = False  # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint  # zero point is 0
    scaling_min_val = 1e-10  # minimum value for the scale factor


class Int2ActPerTensorFloatScratch(Int8ActPerTensorFloatScratch):
    bit_width = 2  # same as above, but with a 2-bit range of [-2, 1]
```
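As a cross-check of the range comments in these classes, the sketch below computes the integer grid each quantizer implies from its `bit_width`, `signed` and `narrow_range` settings, then quantizes a toy weight tensor per tensor. It assumes that absmax (`StatsOp.MAX`) scaling maps max|w| onto the largest positive integer, i.e. scale = max|w| / qmax, clamped below by `scaling_min_val`; this is my reading, not a statement of Brevitas internals. `implied_range` and `absmax_scale` are hypothetical helpers.

```python
import numpy as np

# (bit_width, signed, narrow_range) copied from the quantizer classes above
CONFIGS = {
    "Int8WeightPerTensorFloatScratch": (8, True, True),
    "Int8ActPerTensorFloatScratch": (8, True, False),
    "Int2ActPerTensorFloatScratch": (2, True, False),
}


def implied_range(bit_width, signed, narrow_range):
    """Integer (qmin, qmax) for a zero-point-0 quantizer."""
    if not signed:
        return 0, 2 ** bit_width - 1
    qmax = 2 ** (bit_width - 1) - 1
    qmin = -qmax if narrow_range else -(2 ** (bit_width - 1))
    return qmin, qmax


def absmax_scale(w, bit_width, scaling_min_val=1e-10):
    """Per-tensor scale mapping max|w| to qmax (assumed absmax behaviour)."""
    qmax = 2 ** (bit_width - 1) - 1
    return max(float(np.max(np.abs(w))) / qmax, scaling_min_val)


for name, cfg in CONFIGS.items():
    qmin, qmax = implied_range(*cfg)
    print(f"{name}: [{qmin}, {qmax}], {qmax - qmin + 1} levels")

# Quantize a toy weight tensor with the 8-bit narrow-range weight settings.
w = np.array([0.3, -1.27, 0.05, 0.9])
s = absmax_scale(w, bit_width=8)
q = np.clip(np.round(w / s), *implied_range(8, True, True))  # integer codes
w_hat = q * s                                                # dequantized weights
print(s, q, np.max(np.abs(w - w_hat)))
```

With round-to-nearest and no clipping of the absmax value, the reconstruction error per weight stays within half a quantization step (scale / 2).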