
[Needs more investigation] int8_weight_only via quantize_() API on torch.float16 models results in NaN values across multiple CPU architectures #1662


Note: I'll work on seeing if this reproduces with a non-torchchat example.

While migrating torchchat's WeightOnlyInt8Quantizer to AO's quantize_(model, int8_weight_only()) API, I ran into an issue where activations go to NaN after a few layers if the model's dtype is initially float16. This occurs across multiple platforms (tested on MPS, Mac CPU, and x86 CPU), so it does not appear to be hardware-specific.

Interestingly, the error does not occur if the model dtype is set to bfloat16.
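As a starting point for a standalone (non-torchchat) repro, something along these lines should exercise the same path. This is only a sketch: the TinyMLP model, its sizes, and the input shape are hypothetical, and only quantize_ and int8_weight_only come from the issue itself.

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

class TinyMLP(nn.Module):
    # Hypothetical toy model, not torchchat's llama3.1.
    def __init__(self, dim: int = 256):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.layers(x)

# float16 is the dtype that reportedly produces NaNs; bfloat16 reportedly does not.
model = TinyMLP().to(dtype=torch.float16)
quantize_(model, int8_weight_only())

x = torch.randn(8, 256, dtype=torch.float16)
out = model(x)
print("NaNs in output:", torch.isnan(out).any().item())
```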

To reproduce, check out this PR with the migration in torchchat and run a model using:

```
python3 torchchat.py generate llama3.1 --quantize '{"linear:int8": {"groupsize": 256}, "executor":{"accelerator":"mps"}}' --prompt "King in the castle, king in the castle, i have a chair." --num-samples 3 --dtype float16
```

You'll notice the model just outputs "!" tokens, which correspond to NaN. If you add a debug hook to the model, you can see that some values in the intermediate tensors get very close to 0 just before NaN values are detected.
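A forward hook along these lines can surface where the NaNs first appear. This is a minimal sketch, not the hook used in the PR; the near-zero threshold and the printout format are my own choices.

```python
import torch
import torch.nn as nn

def register_nan_hooks(model: nn.Module):
    """Attach a forward hook to every submodule that reports NaNs and
    near-zero activation magnitudes (threshold is an arbitrary choice)."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                if torch.isnan(output).any():
                    print(f"NaN detected in output of {name}")
                else:
                    min_abs = output.abs().min().item()
                    if min_abs < 1e-6:
                        print(f"{name}: min |activation| = {min_abs:.3e}")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call handle.remove() on each to detach the hooks
```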

