
Preserve token dtype in CompositeEncodings (fix fp16/bf16 forward) #10

Merged
gabrieltseng merged 2 commits into allenai:main from calebrob6:fix/dtype-safe-encodings
Apr 30, 2026

Conversation

@calebrob6 (Contributor) commented Apr 30, 2026

From me: I want to run the OlmoEarth models in fp16 mode to benchmark throughput, but I was running into the error described below. The problem appears to be that the modality_embed dtype in the flexi_vit definition does not get converted when you call .half().

From Copilot:

Problem

Calling model.encoder(...) (or model(...)) after model.half() /
.to(torch.bfloat16) fails with:

RuntimeError: expected scalar type Float but found Half

at the very first LayerNorm inside the transformer blocks
(olmoearth_pretrain_v1/nn/attention.py:537, self.norm1(x)).
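For context, the same class of mismatch reproduces in isolation (a standalone illustration, not the repository's code):

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(8).half()   # fp16 parameters, as after model.half()
x = torch.randn(2, 8)           # fp32 input, like the silently upcast tokens

try:
    norm(x)
except RuntimeError as e:
    print(e)  # dtype-mismatch error; exact wording varies by device/version
```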

Root cause

In CompositeEncodings._apply_encodings_per_modality
(olmoearth_pretrain_v1/nn/flexi_vit.py:951):

```python
modality_embed = torch.zeros(modality_tokens.shape, device=device)
```

No dtype= argument, so this always allocates float32 regardless of
modality_tokens.dtype. Channel / time / month / spatial encodings are then
accumulated into it (lines 968, 973, 980, 996), and line 997 returns
modality_tokens + modality_embed. Under PyTorch type promotion, float16 + float32 → float32, so the encoded tokens are silently upcast back to fp32
even though the model and input were converted to a low-precision dtype. The
fp32 tokens then enter the first transformer block, whose norm1.weight is
fp16 — boom.
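
The promotion is easy to reproduce in isolation (a minimal standalone sketch; shapes here are arbitrary):

```python
import torch

tokens = torch.randn(2, 4, 8, dtype=torch.float16)
embed = torch.zeros(tokens.shape, device=tokens.device)  # no dtype= kwarg: float32, the global default
print(embed.dtype)             # torch.float32
print((tokens + embed).dtype)  # torch.float32: the fp16 tokens were upcast
```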

This is reproducible end-to-end on CPU and CUDA with the public
OlmoEarth-v1-Nano configuration (it does not require pre-trained weights
or GPU hardware).

Fix

```python
modality_embed = torch.zeros(
    modality_tokens.shape, device=device, dtype=modality_tokens.dtype
)
```

fp32 behavior is unchanged. fp16 / bf16 forward passes now succeed and
produce outputs with a dtype matching the input.
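
As an aside, torch.zeros_like offers the same guarantee in one call; a minimal sketch of the equivalent spelling (just an alternative, not a request to change the patch):

```python
# equivalent allocation: zeros_like inherits shape, device, and dtype
# from modality_tokens, so it cannot fall out of sync with the input
modality_embed = torch.zeros_like(modality_tokens)
```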

Verification

Verified locally with the OlmoEarth-v1-Nano architecture:

| dtype | device   | before                                                               | after              |
| ----- | -------- | -------------------------------------------------------------------- | ------------------ |
| fp32  | CPU/CUDA | ✅ pass                                                               | ✅ pass (unchanged) |
| fp16  | CPU      | ❌ `mixed dtype (CPU): expect parameter to have scalar type of Float` | ✅ pass             |
| fp16  | CUDA     | ❌ `expected scalar type Float but found Half`                        | ✅ pass             |
| bf16  | CUDA     | ❌ same                                                               | ✅ pass             |

All existing non-slow tests still pass:

$ pytest tests/ -m "not slow" --ignore=tests/test_model_equivalence.py
====================== 10 passed, 4 deselected in 13.61s =======================

(Excluded test_model_equivalence.py only because it depends on
olmoearth-pretrain==0.1.0 which I do not have installed locally — CI will
exercise it.)

Regression test

Adds tests/test_dtype_consistency.py with a parametrised test that runs the
nano encoder forward pass at fp16 and bf16 on CPU and asserts that the
encoder output dtype matches the input dtype; a distilled sketch of the core
assertion follows the list. The test:

- Constructs the model directly via OlmoEarthPretrain_v1(...) so it does not
  need network access or HF auth in CI.
- Uses patch_size == max_patch_size == 8 to avoid a separate, unrelated CPU
  PyTorch limit (F.interpolate(..., mode="bicubic", antialias=True) is not
  implemented for fp16/bf16 on CPU, hit by the FlexiPatchEmbed resize path
  when patch sizes differ). That limit is orthogonal to this PR: fp16/bf16
  work fine on CUDA at any patch size after this fix.
- Runs in ~1.8s on CPU and confirms both fp16 and bf16.
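
For reference, here is the core assertion distilled into a self-contained snippet (raw tensors stand in for the real nano encoder; this is an illustration, not the test file itself):

```python
import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_encoding_add_preserves_dtype(dtype):
    tokens = torch.randn(2, 4, 16, dtype=dtype)
    # mirrors the fixed allocation in _apply_encodings_per_modality
    embed = torch.zeros(tokens.shape, device=tokens.device, dtype=tokens.dtype)
    assert (tokens + embed).dtype == dtype
```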

Same bug upstream

The identical line exists in
allenai/olmoearth_pretrain at
olmoearth_pretrain/nn/flexi_vit.py.
Given the equivalence test added in #6, you may want to mirror the change
there. Happy to open a parallel PR if useful.

Related defense-in-depth (not in this PR)

get_2d_sincos_pos_encoding_with_resolution in
olmoearth_pretrain_v1/nn/encodings.py unconditionally returns fp32
because it builds torch.arange(...) and sin/cos without a dtype kwarg.
The result is currently in-place added (+=) into modality_embed, so
once modality_embed has the right dtype (this PR), the inplace op
preserves it and the upcast does not propagate. But it is a latent
hazard that would re-introduce the bug if the inplace op were ever
refactored to a non-inplace add. A follow-up could thread
dtype=modality_tokens.dtype (or a final cast) through there for
defense-in-depth — let me know if you would like that included here.
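
The in-place/out-of-place distinction is observable in a few lines (a standalone illustration of PyTorch's in-place casting rule; the fp32 tensor stands in for the sincos helper's output):

```python
import torch

embed = torch.zeros(4, dtype=torch.float16)
pos = torch.arange(4, dtype=torch.float32)  # fp32, like the sincos encoding

embed += pos        # in-place: the fp32 result is cast back to embed's fp16
print(embed.dtype)  # torch.float16

out = embed + pos   # out-of-place: type promotion applies
print(out.dtype)    # torch.float32
```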

The modality_embed accumulator in
`CompositeEncodings._apply_encodings_per_modality` was allocated via
`torch.zeros(modality_tokens.shape, device=device)` with no `dtype=`
kwarg, so it defaulted to fp32 regardless of the input. The final
`return modality_tokens + modality_embed` then triggered PyTorch type
promotion (fp16 + fp32 -> fp32), silently upcasting the encoded tokens
back to fp32 even when the caller had cast the model to fp16/bf16.

Downstream this surfaced as a dtype mismatch at the very first
LayerNorm in each transformer block:

    RuntimeError: expected scalar type Float but found Half

Allocating `modality_embed` with `dtype=modality_tokens.dtype` fixes the
issue and makes the encoder work end-to-end in fp16 and bf16 (CUDA and
CPU). fp32 behavior is unchanged.

Adds a regression test that runs the nano encoder forward pass at fp16
and bf16 on CPU and asserts that the output dtype matches the input.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@gabrieltseng gabrieltseng merged commit 3e6d0b2 into allenai:main Apr 30, 2026
2 checks passed