diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index 7c0689b4..396b4fa1 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -18,6 +18,16 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, return __shfl_down_sync(mask, var.operator __half(), delta); } +__device__ __inline__ at::Half __shfl_up(const at::Half var, + const unsigned int delta) { + return __shfl_up(var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_down(const at::Half var, + const unsigned int delta) { + return __shfl_down(var.operator __half(), delta); +} + #ifdef USE_ROCM __device__ __inline__ at::Half __ldg(const at::Half* ptr) { return __ldg(reinterpret_cast(ptr));