Description
Note: I'll work on seeing if this reproduces with a non-torchchat example.
While migrating torchchat's WeightOnlyInt8Quantizer to AO's quantize_(model, int8_weight_only()) API, I ran into an issue where values go to NaN after a few layers if the model's dtype is initially float16. This occurs across multiple platforms (tested on MPS, Mac CPU, and x86 CPU), so it does not appear to be a hardware-specific issue. Interestingly, setting the model dtype to bfloat16 does not hit this error.
To repro, you can check out this PR with the migration in torchchat and run a model using:
python3 torchchat.py generate llama3.1 --quantize '{"linear:int8": {"groupsize": 256}, "executor":{"accelerator":"mps"}}' --prompt "King in the castle, king in the castle, i have a chair." --num-samples 3 --dtype float16
You'll notice the model outputs only "!" tokens, which indicates NaN values. If you add a debug hook to the model, you can see that some values in the intermediate tensors get very close to 0 just before NaN values are detected.
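For anyone who wants to do the same, the debug hook can be as simple as a forward hook on each Linear module that reports NaNs and near-zero magnitudes. This is only an illustrative sketch (the 1e-6 threshold and the print format are arbitrary choices, not part of torchchat), with model being whatever module you are inspecting:

```python
import torch

def make_debug_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            has_nan = output.isnan().any().item()
            frac_tiny = (output.abs() < 1e-6).float().mean().item()
            if has_nan or frac_tiny > 0.5:
                print(f"{name}: nan={has_nan}, frac_near_zero={frac_tiny:.3f}")
    return hook

# `model` is the module under inspection (e.g. the quantized transformer).
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_hook(make_debug_hook(name))
```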