diff --git a/src/alpamayo_r1/helper.py b/src/alpamayo_r1/helper.py index b5ae4da..74779f0 100644 --- a/src/alpamayo_r1/helper.py +++ b/src/alpamayo_r1/helper.py @@ -86,11 +86,12 @@ def to_device( ) -> Any: """Recursively cast data into the specified device, dtype.""" if isinstance(data, torch.Tensor): - data = data.to( - device=device, - dtype=dtype, - ) - return data + # Only apply dtype conversion to floating-point tensors. + # Integer tensors (e.g., input_ids, attention_mask) must preserve their dtype + # for compatibility with Hugging Face models during mixed-precision inference. + if dtype is not None and data.is_floating_point(): + return data.to(device=device, dtype=dtype) + return data.to(device=device) elif isinstance(data, collections.abc.Mapping): return {key: to_device(data[key], device=device, dtype=dtype) for key in data} elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):