temperature as numpy.float64 is captured as a fake tensor during torch… #2692

Open
umechand-amd wants to merge 1 commit into pytorch:main from umechand-amd:umechand-amd/speech-transformer-export-temperature
@umechand-amd

Description:

What

speech_transformer fails the --export accuracy benchmark with fail_accuracy.

Why

The attention temperature is set with np.power(d_k, 0.5), which returns a numpy.float64. Under torch.export(strict=True), that numpy scalar gets captured as a fake (meta) constant, so the exported module returns a FakeTensor instead of a real output, and the accuracy check rejects it.
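
The return type is easy to confirm outside the model. A minimal check, assuming d_k = 64 purely for illustration (the value is not taken from the model config):

```python
import numpy as np

d_k = 64  # assumed head dimension, for illustration only
temperature = np.power(d_k, 0.5)

# np.power on plain Python numbers returns a NumPy scalar,
# not a built-in Python float.
is_numpy_scalar = isinstance(temperature, np.generic)
is_builtin_float = type(temperature) is float

print(type(temperature).__name__, is_numpy_scalar, is_builtin_float)
```

Note that the check uses `type(...) is float` rather than `isinstance(..., float)`: numpy.float64 subclasses the built-in float, so an isinstance check would not tell the two apart.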

Fix

Cast the temperature to a plain Python float:

temperature=float(np.power(d_k, 0.5))

Value is unchanged; it just traces as a normal scalar.
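
A quick sanity check that the cast changes only the type, not the value. This is a sketch with an assumed d_k = 64 (illustrative, not from the model config):

```python
import math
import numpy as np

d_k = 64  # assumed head dimension, for illustration only
raw = np.power(d_k, 0.5)   # numpy.float64
temperature = float(raw)   # built-in Python float

# The cast yields a plain float holding the same number.
is_plain_float = type(temperature) is float
same_value = bool(temperature == raw)
matches_sqrt = math.isclose(temperature, math.sqrt(d_k))

print(is_plain_float, same_value, matches_sqrt)
```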
                                                                                                                                                                                                                                                                                                                                                                                   
Testing

- --export accuracy config now passes (previously fail_accuracy).
- python test.py -v -k speech_transformer: all pass.
- Not device-specific: confirmed on NVIDIA A100, AMD MI300/MI350X, and CPU.

@meta-cla meta-cla Bot added the cla signed label Jun 11, 2026