
Remove missed .float() cast and add dtype-safe handling for torch.linalg.inv in transfuser model #477

Merged
kamalrajkannan78 merged 1 commit into main from kkannan/feb11_transfuser_bfp16 on Feb 12, 2026

Conversation


kamalrajkannan78 (Contributor) commented Feb 11, 2026

Ticket

Problem description

  • An earlier change missed removing a .float() cast, which caused the entire post-processing part of the model to run in fp32.

What's changed

  • .float() is replaced with .to(scores.dtype), so topk_xs now inherits its dtype from the input scores tensor.
  • torch.linalg.inv does not support low-precision dtypes, so its input is cast to fp32 before the operation and the result is cast back to the original dtype (a minimal sketch of both changes follows this list).
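
A minimal sketch of the two patterns described above, assuming a placeholder post-processing helper; the function and variable names below are illustrative, not the actual identifiers in the transfuser code:

```python
import torch

def decode_topk_xs(scores: torch.Tensor, topk_inds: torch.Tensor, width: int) -> torch.Tensor:
    # Before: `(topk_inds % width).float()` forced fp32, and everything
    # downstream in post-processing followed suit.
    # After: inherit the dtype of `scores`, so a bf16 model stays in bf16.
    return (topk_inds % width).to(scores.dtype)

def safe_inverse(mat: torch.Tensor) -> torch.Tensor:
    # torch.linalg.inv does not support low-precision dtypes such as bfloat16,
    # so compute the inverse in fp32 and cast the result back to the input dtype.
    return torch.linalg.inv(mat.float()).to(mat.dtype)
```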

Checklist

  • Verified the changes through local testing

Logs

CPU runs

TT-XLA runs

kamalrajkannan78 force-pushed the kkannan/feb11_transfuser_bfp16 branch from 0182c5e to 9efc316 on February 12, 2026 09:10
kamalrajkannan78 merged commit f2a6550 into main on Feb 12, 2026
2 checks passed
kamalrajkannan78 deleted the kkannan/feb11_transfuser_bfp16 branch on February 12, 2026 09:12

Development

Successfully merging this pull request may close these issues.

Enable bfloat16 support in transfuser model
