Commit 3e6d0b2

Merge pull request #10 from calebrob6/fix/dtype-safe-encodings
Preserve token dtype in CompositeEncodings (fix fp16/bf16 forward)
2 parents: ba887c7 + 4aae988

2 files changed: 79 additions & 1 deletion

olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -948,7 +948,9 @@ def _apply_encodings_per_modality(
             raise ValueError(f"Unsupported tokens shape: {modality_tokens.shape}")
 
         device = modality_tokens.device
-        modality_embed = torch.zeros(modality_tokens.shape, device=device)
+        modality_embed = torch.zeros(
+            modality_tokens.shape, device=device, dtype=modality_tokens.dtype
+        )
         n = self.embedding_dim_per_embedding_type
         actual_bandsets = modality_tokens.shape[-2]
 
```
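The root cause is PyTorch's standard type promotion: adding a half-precision tensor to an fp32 tensor yields fp32. A minimal standalone sketch of the before/after behavior (tensor names here are illustrative, not taken from the repo):

```python
import torch

# Stand-in for one modality's tokens during a half-precision forward pass.
tokens = torch.randn(2, 4, 8, dtype=torch.float16)

# Pre-fix: torch.zeros uses the global default dtype (fp32 unless changed),
# and fp16 + fp32 promotes to fp32, silently upcasting every encoded token.
embed_old = torch.zeros(tokens.shape, device=tokens.device)
assert (tokens + embed_old).dtype == torch.float32

# Post-fix: allocate the accumulator in the tokens' own dtype so the sum
# stays in half precision.
embed_new = torch.zeros(tokens.shape, device=tokens.device, dtype=tokens.dtype)
assert (tokens + embed_new).dtype == torch.float16
```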

tests/test_dtype_consistency.py

Lines changed: 76 additions & 0 deletions
```python
"""Regression tests for low-precision (fp16/bf16) forward passes.

Prior to the fix in `flexi_vit.py::_apply_encodings_per_modality`, the
`modality_embed` accumulator was unconditionally allocated as fp32 (via
`torch.zeros(modality_tokens.shape, device=device)` with no `dtype=` kwarg).
PyTorch's type promotion then upcast the encoded tokens back to fp32 even if
the model and input had been converted to fp16/bf16, which broke the very
first `LayerNorm` inside the transformer blocks with:

    RuntimeError: expected scalar type Float but found Half

These tests run the encoder forward pass at fp16 and bf16 on CPU and assert
that:
1. no dtype mismatch is raised, and
2. the encoder output dtype matches the input dtype (i.e. it was not
   silently upcast back to fp32 by an internal allocation).

We construct the model directly (no Hugging Face download) and request a
patch size equal to the model's base patch size, which avoids unrelated CPU
bicubic-interpolation limits in `F.interpolate`.
"""

import pytest
import torch

from olmoearth_pretrain_minimal import OlmoEarthPretrain_v1
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import (
    MaskedOlmoEarthSample,
)

PATCH_SIZE = 8


def _make_sample(dtype: torch.dtype) -> MaskedOlmoEarthSample:
    B, H, W, T, num_s2_bands = 1, 16, 16, 3, 12
    sentinel2_l2a = torch.randn((B, H, W, T, num_s2_bands), dtype=dtype)
    sentinel2_l2a_mask = torch.zeros((B, H, W, T, num_s2_bands), dtype=torch.long)

    days = torch.randint(0, 25, (B, T, 1), dtype=torch.long)
    months = torch.randint(0, 12, (B, T, 1), dtype=torch.long)
    years = torch.randint(2018, 2020, (B, T, 1), dtype=torch.long)
    timestamps = torch.cat([days, months, years], dim=-1)

    return MaskedOlmoEarthSample(
        timestamps=timestamps,
        sentinel2_l2a=sentinel2_l2a,
        sentinel2_l2a_mask=sentinel2_l2a_mask,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_nano_encoder_forward_low_precision(dtype: torch.dtype) -> None:
    """Encoder forward pass must not raise a dtype mismatch in fp16 / bf16."""
    model = OlmoEarthPretrain_v1(
        model_size="nano",
        supported_modality_names=["sentinel2_l2a"],
        max_patch_size=PATCH_SIZE,
        max_sequence_length=3,
    )
    model = model.to(dtype).eval()

    sample = _make_sample(dtype)

    with torch.inference_mode():
        out = model.encoder(
            sample, patch_size=PATCH_SIZE, input_res=10, fast_pass=True
        )

    # Encoder output dtype should match the model/input dtype, not have been
    # silently upcast by an internal fp32 allocation.
    tokens = out["tokens_and_masks"].sentinel2_l2a
    assert tokens.dtype == dtype, (
        f"Expected encoder output dtype {dtype}, got {tokens.dtype}"
    )
```
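As background for the `RuntimeError` quoted in the docstring, the failure can be reproduced in isolation by feeding fp32 activations to fp16 parameters. This is a hedged sketch, and the exact error text may vary across PyTorch versions and backends:

```python
import torch

# fp16 LayerNorm weights (as after model.half()) receiving fp32 activations,
# mimicking tokens that an fp32 accumulator silently upcast.
ln = torch.nn.LayerNorm(8).half()
x = torch.randn(2, 8)  # fp32

try:
    ln(x)
except RuntimeError as err:
    print(err)  # e.g. "expected scalar type Float but found Half"
```

The tests above catch both this hard failure and the quieter one where the forward succeeds but the output has been upcast; they can be run with `pytest tests/test_dtype_consistency.py`.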
