diff --git a/src/alpamayo_r1/helper.py b/src/alpamayo_r1/helper.py index b5ae4da..53fa0f7 100644 --- a/src/alpamayo_r1/helper.py +++ b/src/alpamayo_r1/helper.py @@ -84,13 +84,15 @@ def to_device( device: str | torch.device | None = None, dtype: torch.dtype | None = None, ) -> Any: - """Recursively cast data into the specified device, dtype.""" + """Recursively cast data into the specified device, dtype. + + Note: dtype conversion is only applied to floating-point tensors. + Integer and boolean tensors preserve their original dtype. + """ if isinstance(data, torch.Tensor): - data = data.to( - device=device, - dtype=dtype, - ) - return data + if dtype is not None and data.is_floating_point(): + return data.to(device=device, dtype=dtype) + return data.to(device=device) elif isinstance(data, collections.abc.Mapping): return {key: to_device(data[key], device=device, dtype=dtype) for key in data} elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):