diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 923f1152c9d..8f7be2eacb2 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -143,9 +143,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32 : xla::PrimitiveType::S16; case xla::PrimitiveType::S64: - return xla::PrimitiveType::S64; + return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32 + : xla::PrimitiveType::S64; case xla::PrimitiveType::U64: - return xla::PrimitiveType::U64; + return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32 + : xla::PrimitiveType::U64; case xla::PrimitiveType::C128: return xla::PrimitiveType::C128; default: