Skip to content

Commit e0c7425

Browse files
jammmjeffdaily
authored andcommitted
[ROCm] Fix int4mm device memcpy error on Windows (pytorch#175410)
On Windows with HIP/ROCm, std::memcpy is a __host__ function and cannot be called from __device__ code. Use raw memcpy (which the HIP compiler provides as a device builtin) when building on Windows. This will allow builds for of pytorch for gfx942 on Windows. gfx950 is yet to be tested but it should likely build as well. Pull Request resolved: pytorch#175410 Approved by: https://github.com/jeffdaily
1 parent 8543095 commit e0c7425

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

aten/src/ATen/native/cuda/int4mm.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,14 @@ struct BLayout_TC_int4 {
576576
// type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and
577577
// can't be used as a 32-bit asm register argument for `mma`
578578
static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
579+
// On Windows with ROCm, std::memcpy resolves to a __host__-only
580+
// function and cannot be called from __device__ code. Use the raw
581+
// memcpy which the HIP compiler provides as a __device__ builtin.
582+
#if defined(_WIN32) && defined(USE_ROCM)
583+
memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
584+
#else
579585
std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
586+
#endif
580587
}
581588
}
582589
}

0 commit comments

Comments
 (0)