Skip to content

Commit 81ecf9a

Browse files
jeffdailyjammm
andauthored
[release/2.9] Fix int4mm device memcpy error on Windows (pytorch#175410) (#3166)
On Windows with HIP/ROCm, std::memcpy is a __host__ function and cannot be called from __device__ code. Use raw memcpy (which the HIP compiler provides as a device builtin) when building on Windows. This will allow builds for of pytorch for gfx942 on Windows. gfx950 is yet to be tested but it should likely build as well. Pull Request resolved: pytorch#175410 Approved by: https://github.com/jeffdaily Co-authored-by: Aaryaman Vasishta <aaryaman.vasishta@amd.com>
1 parent 813feb6 commit 81ecf9a

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

aten/src/ATen/native/cuda/int4mm.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,14 @@ struct BLayout_TC_int4 {
581581
// type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and
582582
// can't be used as a 32-bit asm register argument for `mma`
583583
static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
584+
// On Windows with ROCm, std::memcpy resolves to a __host__-only
585+
// function and cannot be called from __device__ code. Use the raw
586+
// memcpy which the HIP compiler provides as a __device__ builtin.
587+
#if defined(_WIN32) && defined(USE_ROCM)
588+
memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
589+
#else
584590
std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
591+
#endif
585592
}
586593
}
587594
}

0 commit comments

Comments
 (0)