Skip to content

Commit 4aae988

Browse files
authored
Add float32 to test -- manually :O
1 parent 2ec5604 commit 4aae988

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

tests/test_dtype_consistency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _make_sample(dtype: torch.dtype) -> MaskedOlmoEarthSample:
4848
)
4949

5050

51-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
51+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
5252
def test_nano_encoder_forward_low_precision(dtype: torch.dtype) -> None:
5353
"""Encoder forward pass must not raise a dtype mismatch in fp16 / bf16."""
5454
model = OlmoEarthPretrain_v1(

0 commit comments

Comments
 (0)