Preserve token dtype in CompositeEncodings (fix fp16/bf16 forward) #10
Merged
gabrieltseng merged 2 commits into allenai:main on Apr 30, 2026
Conversation
The `modality_embed` accumulator in
`CompositeEncodings._apply_encodings_per_modality` was allocated via
`torch.zeros(modality_tokens.shape, device=device)` with no `dtype=`
kwarg, so it defaulted to fp32 regardless of the input. The final
`return modality_tokens + modality_embed` then triggered PyTorch type
promotion (fp16 + fp32 -> fp32), silently upcasting the encoded tokens
back to fp32 even when the caller had cast the model to fp16/bf16.
Downstream this surfaced as a dtype mismatch at the very first
LayerNorm in each transformer block:
`RuntimeError: expected scalar type Float but found Half`
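For illustration, a minimal standalone sketch of the failure mode (toy shapes and a bare `LayerNorm`, not the project's code):

```python
import torch

tokens = torch.randn(2, 4, 8, dtype=torch.float16)  # encoder tokens after model.half()
embed = torch.zeros(tokens.shape)                    # no dtype= kwarg -> defaults to fp32
promoted = tokens + embed                            # fp16 + fp32 promotes to fp32
print(promoted.dtype)                                # torch.float32

norm = torch.nn.LayerNorm(8).half()                  # fp16 weights, like the transformer blocks
try:
    norm(promoted)                                   # fp32 activations meet fp16 weights
except RuntimeError as err:
    print(err)                                       # the dtype-mismatch error shown above
```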
Allocating `modality_embed` with `dtype=modality_tokens.dtype` fixes the
issue and makes the encoder work end-to-end in fp16 and bf16 (CUDA and
CPU). fp32 behavior is unchanged.
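A minimal sketch of the allocation after the change (dummy tensor; in the real method the `device` comes from the caller):

```python
import torch

modality_tokens = torch.randn(2, 4, 8, dtype=torch.float16)
modality_embed = torch.zeros(
    modality_tokens.shape,
    device=modality_tokens.device,
    dtype=modality_tokens.dtype,  # the added kwarg: keep the accumulator in the token dtype
)
assert (modality_tokens + modality_embed).dtype == torch.float16  # no upcast
```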
Adds a regression test that runs the nano encoder forward pass at fp16
and bf16 on CPU and asserts that the output dtype matches the input.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
gabrieltseng approved these changes Apr 30, 2026
From me: I want to run the olmoearth models in fp16 mode to benchmark throughput performance but was running into the error described below. Looks like the problem is that the `modality_embed` dtype in the flexi_vit definition doesn't get changed when you call `.half()`.

From copilot:
Problem
Calling `model.encoder(...)` (or `model(...)`) after `model.half()` / `.to(torch.bfloat16)` fails with `RuntimeError: expected scalar type Float but found Half` at the very first LayerNorm inside the transformer blocks (`olmoearth_pretrain_v1/nn/attention.py:537`, `self.norm1(x)`).

Root cause
In `CompositeEncodings._apply_encodings_per_modality` (`olmoearth_pretrain_v1/nn/flexi_vit.py:951`), the accumulator is allocated as `torch.zeros(modality_tokens.shape, device=device)` with no `dtype=` argument, so it always comes out as `float32` regardless of `modality_tokens.dtype`. Channel / time / month / spatial encodings are then accumulated into it (lines 968, 973, 980, 996), and line 997 returns `modality_tokens + modality_embed`. Under PyTorch type promotion, `float16 + float32` gives `float32`, so the encoded tokens are silently upcast back to fp32 even though the model and input were converted to a low-precision dtype. The fp32 tokens then enter the first transformer block, whose `norm1.weight` is fp16: boom.
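The promotion rule can be checked directly with a one-liner (standalone, independent of this repo):

```python
import torch

print(torch.promote_types(torch.float16, torch.float32))   # torch.float32
print(torch.promote_types(torch.bfloat16, torch.float32))  # torch.float32
```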
This is reproducible end-to-end on CPU and CUDA with the public OlmoEarth-v1-Nano configuration (it does not require pre-trained weights or GPU hardware).
Fix
Allocate `modality_embed` with `dtype=modality_tokens.dtype`. fp32 behavior is unchanged; fp16 / bf16 forward passes now succeed and produce outputs with a dtype matching the input.
Verification
Verified locally with the OlmoEarth-v1-Nano architecture:
Without the fix, the fp16/bf16 forward fails with `mixed dtype (CPU): expect parameter to have scalar type of Float` / `expected scalar type Float but found Half`; with the fix it completes and the output dtype matches the input. All existing non-slow tests still pass.
(Excluded `test_model_equivalence.py` only because it depends on `olmoearth-pretrain==0.1.0`, which I do not have installed locally; CI will exercise it.)
Regression test
Adds `tests/test_dtype_consistency.py` with a parametrised test that runs the nano encoder forward pass at fp16 and bf16 on CPU and asserts that the encoder output dtype matches the input dtype. The test (a toy sketch of its shape follows the list):

- instantiates `OlmoEarthPretrain_v1(...)` directly, so it does not need network access or HF auth in CI.
- uses `patch_size == max_patch_size == 8` to avoid a separate, unrelated CPU PyTorch limit (`F.interpolate(..., mode="bicubic", antialias=True)` is not implemented for fp16/bf16 on CPU, hit by the `FlexiPatchEmbed` resize path when patch sizes differ). That limit is orthogonal to this PR; fp16/bf16 work fine on CUDA at any patch size after this fix.
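For illustration only, a toy sketch of the test's shape; the real file builds the nano `OlmoEarthPretrain_v1` encoder rather than this stand-in module, and fp16 `LayerNorm` on CPU assumes a reasonably recent PyTorch:

```python
import pytest
import torch


class ToyEncoder(torch.nn.Module):
    """Stand-in for the real encoder: accumulate an embedding, then normalize."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulator allocated in the input dtype, mirroring the fix.
        embed = torch.zeros(x.shape, device=x.device, dtype=x.dtype)
        return self.norm(x + embed)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_encoder_output_dtype_matches_input(dtype):
    model = ToyEncoder().to(dtype)
    x = torch.randn(2, 4, 8, dtype=dtype)
    assert model(x).dtype == dtype
```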
Same bug upstream
The identical line exists in allenai/olmoearth_pretrain at `olmoearth_pretrain/nn/flexi_vit.py`. Given the equivalence test added in #6, you may want to mirror the change there. Happy to open a parallel PR if useful.
Related defense-in-depth (not in this PR)
`get_2d_sincos_pos_encoding_with_resolution` in `olmoearth_pretrain_v1/nn/encodings.py` unconditionally returns fp32 because it builds `torch.arange(...)` and sin/cos without a `dtype` kwarg. The result is currently in-place added (`+=`) into `modality_embed`, so once `modality_embed` has the right dtype (this PR), the in-place op preserves it and the upcast does not propagate. But it is a latent hazard that would re-introduce the bug if the in-place op were ever refactored to a non-inplace add. A follow-up could thread `dtype=modality_tokens.dtype` (or a final cast) through there for defense-in-depth; let me know if you would like that included here.
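A small sketch of the difference (standalone, not this repo's code): an in-place add keeps the target's dtype, while an out-of-place add would promote.

```python
import torch

embed = torch.zeros(4, dtype=torch.float16)
pos = torch.sin(torch.arange(4, dtype=torch.float32))  # fp32, like the sincos encoding

embed += pos                  # in-place: the fp32 rhs is cast into the fp16 target
print(embed.dtype)            # torch.float16

out = torch.zeros(4, dtype=torch.float16) + pos        # out-of-place add: promotes
print(out.dtype)              # torch.float32
```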