This repository holds a minimal PyTorch script that trains a fully connected network on
MNIST while quantizing its weights to a low-bit integer grid. The
quantization is simulated by scaling the parameters to the available integer range before each
forward pass, so weights still update in float32 but the network computes with the
low-precision approximation.
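For intuition, here is a minimal sketch of that simulated (fake) quantization step; the function name and rounding scheme below are illustrative, not the repository's API:

```python
import torch

def fake_quantize(weight: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Scale to the signed integer range, round, then rescale back to float32."""
    qmax = 2 ** (num_bits - 1) - 1                      # e.g. 127 for 8 bits
    scale = weight.abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax, qmax)
    return q * scale                                    # float values on an integer grid

w = torch.randn(128, 784)
print((w - fake_quantize(w, num_bits=4)).abs().max())   # error bounded by roughly scale / 2
```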
- Install dependencies:
  pip install -r requirements.txt
- Run training (uses CPU, or CUDA if available):
  python train_int_mnist.py --epochs 5 --batch-size 256
- To speed up experiments, restrict the dataset size:
  python train_int_mnist.py --limit-train-samples 5000 --limit-test-samples 1000
Optional flags can adjust the hidden dimension, learning rate, or dataset directory; run
python train_int_mnist.py --help for the full list.
By default training runs in fp32 for the first few epochs and then switches to integer-quantized
forward passes at the epoch given by --quantization-start-epoch. The --quantization flag accepts plain
integers between 1 and 8 (e.g. --quantization 6 for 6-bit symmetric quantization) as well as the
special values int8 and none, which select 8-bit math or keep fp32 throughout. After training
completes, a serialized state dict is saved to --export-path (default mnist_mlp.pt).
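A rough sketch of how that flag and schedule might be interpreted (the parsing below is an assumption, not the actual code in train_int_mnist.py):

```python
def resolve_quantization(flag: str):
    """Map a --quantization value to a bit width, or None for fp32 throughout."""
    if flag == "none":
        return None
    if flag == "int8":
        return 8
    bits = int(flag)
    if not 1 <= bits <= 8:
        raise ValueError(f"unsupported bit width: {flag}")
    return bits

bits = resolve_quantization("6")
quantization_start_epoch = 2
for epoch in range(5):
    quantized = bits is not None and epoch >= quantization_start_epoch
    print(f"epoch {epoch}: {f'{bits}-bit' if quantized else 'fp32'} forward pass")
```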
Model code lives in mnist_model.py, which exposes both the QuantizedLinear wrapper used for
quantized training and an SCLinear building block that the stochastic runner can substitute when needed.
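The wrapper pattern might look roughly like the sketch below; the class is illustrative and not the repository's QuantizedLinear:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantLinear(nn.Module):
    """Wraps an nn.Linear and fake-quantizes its weight at forward time.
    A sketch of the wrapper pattern, not the repository's QuantizedLinear."""

    def __init__(self, in_features: int, out_features: int, num_bits: int = 8):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features)
        self.num_bits = num_bits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qmax = 2 ** (self.num_bits - 1) - 1
        scale = self.inner.weight.abs().max().clamp(min=1e-8) / qmax
        w_q = torch.clamp(torch.round(self.inner.weight / scale), -qmax, qmax) * scale
        return F.linear(x, w_q, self.inner.bias)

layer = FakeQuantLinear(784, 128, num_bits=4)
print(layer(torch.randn(32, 784)).shape)    # torch.Size([32, 128])
```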
Once you export the checkpoint, run convert_to_stochastic.py to build stochastic bit streams
for each floating-point tensor (--num-bits controls the stream length, default 256). You can choose
between the default random streams and a deterministic unary stream (--stream-mode unary)
where each weight is encoded as a block of 1s followed by 0s. While converting, the script
prints each layer’s mean/std error as a percentage and writes histogram.png
next to the exported archive so you can inspect the overall relative error distribution. The output
bundle (mnist_bitstreams.pt) includes metadata (scales, stream mode, seed) so you can
reproduce the same streams later with either run_unary_model.py or run_stochastic_model.py.
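As a rough picture of what a stream encodes (assuming values have already been mapped into [0, 1] via the stored scales; the helper below is illustrative, not convert_to_stochastic.py's actual layout):

```python
import torch

def to_bitstream(x: torch.Tensor, num_bits: int = 256, mode: str = "random") -> torch.Tensor:
    """Encode values in [0, 1] as bit streams of length num_bits.
    'random': each bit is 1 with probability x; 'unary': a block of ones then zeros."""
    x = x.clamp(0.0, 1.0).unsqueeze(-1)                        # (..., 1)
    if mode == "unary":
        idx = torch.arange(num_bits)
        return (idx < torch.round(x * num_bits)).to(torch.uint8)
    noise = torch.rand(*x.shape[:-1], num_bits)
    return (noise < x).to(torch.uint8)

w = torch.rand(4)
streams = to_bitstream(w, num_bits=256, mode="random")
print(w)
print(streams.float().mean(dim=-1))                            # stream mean approximates the value
```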
Run python test_unary_conversion.py whenever you change unary mode to confirm that 4-bit
and 8-bit tensors are losslessly reconstructed from 16-bit streams.
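For the 4-bit case the lossless property boils down to counting ones; a toy round trip (not the test itself) looks like this:

```python
import torch

# Every 4-bit level 0..15 becomes a 16-bit unary stream (ones then zeros),
# and summing the ones recovers the level exactly.
levels = torch.arange(16)
streams = (torch.arange(16) < levels.unsqueeze(-1)).to(torch.uint8)
recovered = streams.sum(dim=-1)
assert torch.equal(recovered, levels)
```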
After conversion, run_stochastic_model.py loads the bit streams, swaps every linear layer for
an SCLinear, and applies the precomputed stochastic data so you can evaluate the MNIST MLP
with the quantized bitstream representation. Point it at the archive with --bitstreams
(defaults to mnist_bitstreams.pt) and use the same test flags as you would for training.
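The layer swap follows the usual PyTorch module-replacement pattern; the sketch below uses a trivial stand-in factory because SCLinear's constructor and the bit-stream wiring are not shown here:

```python
import torch.nn as nn

def swap_linear_layers(model: nn.Module, make_replacement) -> nn.Module:
    """Recursively replace every nn.Linear with a substitute module.
    The factory signature is an assumption; run_stochastic_model.py builds
    SCLinear layers from the precomputed bit streams instead."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, make_replacement(child))
        else:
            swap_linear_layers(child, make_replacement)
    return model

def copy_linear(old: nn.Linear) -> nn.Linear:
    # Trivial stand-in factory: a fresh Linear with the old weights copied over.
    new = nn.Linear(old.in_features, old.out_features, bias=old.bias is not None)
    new.load_state_dict(old.state_dict())
    return new

mlp = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
swap_linear_layers(mlp, copy_linear)
print(mlp)
```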
If you just want to re-run the exported checkpoint with the QuantizedLinear wrapper (no
stochastic streams), use evaluate_quantized.py to load the checkpoint, reapply the saved
quantization bits, and compute MNIST accuracy on the test set.
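The accuracy computation is a standard evaluation loop along these lines (a sketch assuming a plain ToTensor transform; evaluate_quantized.py may normalize or load data differently):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def mnist_test_accuracy(model: torch.nn.Module, data_dir: str = "./data",
                        batch_size: int = 256) -> float:
    """Plain MNIST test-set accuracy loop."""
    test_set = datasets.MNIST(data_dir, train=False, download=True,
                              transform=transforms.ToTensor())
    loader = DataLoader(test_set, batch_size=batch_size)
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
    return correct / len(test_set)
```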
Run python validate_quantized_checkpoint.py --checkpoint mnist_mlp.pt --expected-bits 4 (or any
other bit width) to ensure the exported checkpoint actually recorded the requested quantization level.
Then run python verify_quantization_consistency.py --checkpoint mnist_mlp.pt --bits 4 to quantize
the checkpoint weights with quant_utils.symmetric_quantize and confirm the stored values
match what the quantization functions produce within a small relative-error tolerance.
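The consistency check amounts to re-quantizing the stored weights and comparing; here is a sketch with a stand-in quantizer (quant_utils.symmetric_quantize's exact signature is not shown here):

```python
import torch

def symmetric_quantize_sketch(w: torch.Tensor, bits: int) -> torch.Tensor:
    """Stand-in for quant_utils.symmetric_quantize: snap weights onto a
    symmetric integer grid and rescale back to float."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-8) / qmax
    return torch.clamp(torch.round(w / scale), -qmax, qmax) * scale

stored = symmetric_quantize_sketch(torch.randn(128, 784), bits=4)   # what the checkpoint holds
requant = symmetric_quantize_sketch(stored, bits=4)                 # re-quantizing should change nothing
max_rel_err = ((stored - requant).abs() / stored.abs().clamp(min=1e-8)).max()
assert max_rel_err < 1e-5                                           # small slack for float32 round-off
```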