
Conversation


Xiaoming-AMD commented on Oct 23, 2025

Summary

This PR fixes a precision mismatch where nn.Embedding outputs remain in fp32
even when AMP (autocast) is enabled with bf16 or fp16. Embedding lookups are not
on autocast's cast list, so the output inherits the fp32 weight dtype; downstream
ops then pay for unnecessary dtype conversions and increased memory usage.
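
A minimal reproduction of the mismatch (a sketch, assuming PyTorch >= 1.10; CPU autocast is used here only so it runs without a GPU):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 8)            # weight is fp32 by default
lin = nn.Linear(8, 8)
tokens = torch.randint(0, 10, (4,))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    e = emb(tokens)                  # embedding lookup is not an autocast op -> stays fp32
    y = lin(e)                       # linear runs in bf16, so an fp32 -> bf16 cast happens here
    print(e.dtype, y.dtype)          # torch.float32 torch.bfloat16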


✨ What’s Changed

  • Globally monkey-patches nn.Embedding.__init__ to register a forward hook on every embedding instance (see the sketch below).
  • The hook:
    • Checks whether AMP is active via torch.is_autocast_enabled().
    • If so, casts the embedding output to the current autocast dtype (e.g., bf16).
  • Controlled via a configuration flag:
 --primus_turbo.enable_embedding_autocast=false  # disables the patch
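
A minimal sketch of this approach, assuming CUDA autocast and illustrative names; the actual Primus-Turbo implementation may differ in structure and naming:

import torch
import torch.nn as nn

_original_embedding_init = nn.Embedding.__init__

def _embedding_autocast_hook(module, inputs, output):
    # Embedding lookups are not autocast ops, so the output keeps the fp32
    # weight dtype. If autocast is active, cast it to the autocast dtype.
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
        if output.dtype != target_dtype:
            return output.to(target_dtype)
    return output

def _patched_embedding_init(self, *args, **kwargs):
    _original_embedding_init(self, *args, **kwargs)
    self.register_forward_hook(_embedding_autocast_hook)

# Apply the patch globally; every nn.Embedding created afterwards gets the hook.
nn.Embedding.__init__ = _patched_embedding_init

Because register_forward_hook replaces the module output when the hook returns a tensor, the cast happens transparently for any model that instantiates nn.Embedding after the patch is applied.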

Xiaoming-AMD merged commit 9ee2c51 into main on Oct 23, 2025 (3 checks passed).
Xiaoming-AMD deleted the fix/titan-amp/force-embedding-bf16 branch on Oct 27, 2025.