"""Regression tests for low-precision (fp16/bf16) forward passes.

Prior to the fix in `flexi_vit.py::_apply_encodings_per_modality`, the
`modality_embed` accumulator was unconditionally allocated as fp32 (via
`torch.zeros(modality_tokens.shape, device=device)` with no `dtype=` kwarg).
PyTorch's type promotion then upcast the encoded tokens back to fp32 even if
the model and input had been converted to fp16/bf16, which broke the very
first `LayerNorm` inside the transformer blocks with:

    RuntimeError: expected scalar type Float but found Half

These tests run the encoder forward pass at fp16 and bf16 on CPU (with fp32
as a control case) and assert that:
  1. no dtype mismatch is raised, and
  2. the encoder output dtype matches the input dtype (i.e. it was not
     silently upcast back to fp32 by an internal allocation).

We construct the model directly (no Hugging Face download) and request a
patch size equal to the model's base patch size, which avoids unrelated CPU
bicubic-interpolation limits in `F.interpolate`.
"""

import pytest
import torch

from olmoearth_pretrain_minimal import OlmoEarthPretrain_v1
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import (
    MaskedOlmoEarthSample,
)

# Equal to the model's base patch size; the module docstring explains why
# requesting the base patch size sidesteps CPU bicubic-interpolation limits
# in `F.interpolate`.
PATCH_SIZE = 8


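# A minimal, self-contained sanity check of the promotion rule described in
# the module docstring. It touches only public `torch` ops, not the model
# code, and assumes the process-wide default dtype is the stock fp32.
def test_zeros_accumulator_type_promotion() -> None:
    """`torch.zeros` without `dtype=` allocates fp32 and upcasts fp16 inputs."""
    tokens = torch.randn(4, 8, dtype=torch.float16)

    # Buggy pattern: no `dtype=` kwarg, so the accumulator is fp32 and the
    # sum is promoted back to fp32.
    accum = torch.zeros(tokens.shape, device=tokens.device)
    assert (accum + tokens).dtype == torch.float32

    # Fixed pattern: inherit the tokens' dtype, so the sum stays fp16.
    accum = torch.zeros(tokens.shape, device=tokens.device, dtype=tokens.dtype)
    assert (accum + tokens).dtype == torch.float16

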
def _make_sample(dtype: torch.dtype) -> MaskedOlmoEarthSample:
    # Batch, height, width, timesteps, and Sentinel-2 L2A band count.
    B, H, W, T, num_s2_bands = 1, 16, 16, 3, 12
    sentinel2_l2a = torch.randn((B, H, W, T, num_s2_bands), dtype=dtype)
    # All-zeros mask matching the band tensor's shape.
    sentinel2_l2a_mask = torch.zeros((B, H, W, T, num_s2_bands), dtype=torch.long)

    # Per-timestep (day, month, year) triples, stacked along the last axis.
    days = torch.randint(0, 25, (B, T, 1), dtype=torch.long)
    months = torch.randint(0, 12, (B, T, 1), dtype=torch.long)
    years = torch.randint(2018, 2020, (B, T, 1), dtype=torch.long)
    timestamps = torch.cat([days, months, years], dim=-1)

    return MaskedOlmoEarthSample(
        timestamps=timestamps,
        sentinel2_l2a=sentinel2_l2a,
        sentinel2_l2a_mask=sentinel2_l2a_mask,
    )


# float32 is included as a control alongside the two low-precision dtypes.
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_nano_encoder_forward_low_precision(dtype: torch.dtype) -> None:
    """Encoder forward pass must not raise a dtype mismatch in fp16 / bf16."""
    model = OlmoEarthPretrain_v1(
        model_size="nano",
        supported_modality_names=["sentinel2_l2a"],
        max_patch_size=PATCH_SIZE,
        max_sequence_length=3,
    )
    model = model.to(dtype).eval()

    sample = _make_sample(dtype)

    with torch.inference_mode():
        out = model.encoder(
            sample, patch_size=PATCH_SIZE, input_res=10, fast_pass=True
        )

    # The encoder output dtype should match the model/input dtype, not have
    # been silently upcast by an internal fp32 allocation.
    tokens = out["tokens_and_masks"].sentinel2_l2a
    assert tokens.dtype == dtype, (
        f"Expected encoder output dtype {dtype}, got {tokens.dtype}"
    )
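
# For reference, the shape of the fix the module docstring describes. This is
# a sketch of the pattern, not the verbatim code in
# `flexi_vit.py::_apply_encodings_per_modality`:
#
#   before: modality_embed = torch.zeros(modality_tokens.shape, device=device)
#   after:  modality_embed = torch.zeros(
#               modality_tokens.shape, device=device, dtype=modality_tokens.dtype
#           )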