Skip to content

Commit d949707

Browse files
committed
TL/CUDA: Add int datatype support for NVLS
This PR adds support for the following additional integer data types in the NVLS (NVLink SHARP) collective operations allreduce and reduce_scatter:

- INT32 (s32): 32-bit signed integer
- INT64 (s64): 64-bit signed integer
- UINT32 (u32): 32-bit unsigned integer
- UINT64 (u64): 64-bit unsigned integer

Changes:
- Added PTX multimem.ld_reduce and multimem.st instructions for each type
- Created NvlsOps structs for type-specific operations
- Updated the allreduce and reduce_scatter kernels with new data type handling
- Modified the validation logic to accept the new data types

Signed-off-by: Juee14Desai <jueehimalbha@nvidia.com>
1 parent b7d4f76 commit d949707

File tree

6 files changed

+321
-23
lines changed

6 files changed

+321
-23
lines changed

src/components/tl/cuda/allreduce/allreduce.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ ucc_status_t ucc_tl_cuda_allreduce_nvls_init(ucc_base_coll_args_t *coll_args,
3636
ucc_base_team_t *team,
3737
ucc_coll_task_t **task_h);
3838

39+
static inline int
40+
ucc_tl_cuda_allreduce_nvls_dt_supported(ucc_datatype_t dt)
41+
{
42+
switch (dt) {
43+
case UCC_DT_FLOAT32:
44+
case UCC_DT_BFLOAT16:
45+
case UCC_DT_INT32:
46+
case UCC_DT_UINT32:
47+
case UCC_DT_INT64:
48+
case UCC_DT_UINT64:
49+
return 1;
50+
default:
51+
return 0;
52+
}
53+
}
54+
3955
static inline int ucc_tl_cuda_allreduce_alg_from_str(const char *str)
4056
{
4157
int i;

src/components/tl/cuda/allreduce/allreduce_nvls.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,13 @@ ucc_status_t ucc_tl_cuda_allreduce_nvls_init(
186186
ucc_status_t status;
187187

188188
if (buf_size < 1024 || coll_args->args.op != UCC_OP_SUM ||
189-
(coll_args->args.dst.info.datatype != UCC_DT_FLOAT32 &&
190-
coll_args->args.dst.info.datatype != UCC_DT_BFLOAT16)) {
189+
!ucc_tl_cuda_allreduce_nvls_dt_supported(
190+
coll_args->args.dst.info.datatype)) {
191191
tl_debug(
192192
UCC_TL_TEAM_LIB(team),
193193
"NVLS allreduce is supported only with SUM operation "
194-
"and float32 or bfloat16 datatype, with message size >= 1024 "
195-
"bytes");
194+
"and float32, bfloat16, int32, uint32, int64, or uint64 "
195+
"datatype, with message size >= 1024 bytes");
196196
return UCC_ERR_NOT_SUPPORTED;
197197
}
198198
if (ucc_unlikely(

src/components/tl/cuda/kernels/allreduce_kernel.cu

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,72 @@ __global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
4848
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
4949
}
5050

51+
template <typename NvlsOps>
52+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
53+
allreduce_kernel_scalar32(ucc_tl_cuda_nvls_control_t *mc_bar,
54+
ucc_tl_cuda_nvls_control_t *uc_bar,
55+
const uint32_t total_blocks,
56+
uint64_t launch_counter,
57+
uint32_t *base_u32, size_t count_u32, uint32_t rank,
58+
uint32_t tsize)
59+
{
60+
// pre barrier
61+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 1));
62+
63+
// Kernel execution
64+
size_t chunk_start = ((int64_t)count_u32 * (int64_t)rank) / (int64_t)tsize;
65+
size_t chunk_end = ((int64_t)count_u32 * (int64_t)(rank + 1)) / (int64_t)tsize;
66+
67+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
68+
size_t stride = blockDim.x * gridDim.x * 4;
69+
70+
for (size_t idx = chunk_start + thread_offset; idx < chunk_end; idx += stride) {
71+
typename NvlsOps::value_type v0, v1, v2, v3;
72+
NvlsOps::ld(v0, base_u32 + idx + 0);
73+
NvlsOps::ld(v1, base_u32 + idx + 1);
74+
NvlsOps::ld(v2, base_u32 + idx + 2);
75+
NvlsOps::ld(v3, base_u32 + idx + 3);
76+
NvlsOps::st(v0, base_u32 + idx + 0);
77+
NvlsOps::st(v1, base_u32 + idx + 1);
78+
NvlsOps::st(v2, base_u32 + idx + 2);
79+
NvlsOps::st(v3, base_u32 + idx + 3);
80+
}
81+
82+
// post barrier
83+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
84+
}
85+
86+
template <typename NvlsOps>
87+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
88+
allreduce_kernel_scalar64(ucc_tl_cuda_nvls_control_t *mc_bar,
89+
ucc_tl_cuda_nvls_control_t *uc_bar,
90+
const uint32_t total_blocks,
91+
uint64_t launch_counter,
92+
uint64_t *base_u64, size_t count_u64, uint32_t rank,
93+
uint32_t tsize)
94+
{
95+
// pre barrier
96+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 1));
97+
98+
// Kernel execution
99+
size_t chunk_start = ((int64_t)count_u64 * (int64_t)rank) / (int64_t)tsize;
100+
size_t chunk_end = ((int64_t)count_u64 * (int64_t)(rank + 1)) / (int64_t)tsize;
101+
102+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 2;
103+
size_t stride = blockDim.x * gridDim.x * 2;
104+
105+
for (size_t idx = chunk_start + thread_offset; idx < chunk_end; idx += stride) {
106+
typename NvlsOps::value_type v0, v1;
107+
NvlsOps::ld(v0, base_u64 + idx + 0);
108+
NvlsOps::ld(v1, base_u64 + idx + 1);
109+
NvlsOps::st(v0, base_u64 + idx + 0);
110+
NvlsOps::st(v1, base_u64 + idx + 1);
111+
}
112+
113+
// post barrier
114+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
115+
}
116+
51117
#ifdef __cplusplus
52118
extern "C" {
53119
#endif
@@ -69,17 +135,40 @@ ucc_status_t post_allreduce_kernel(cudaStream_t stream, uint32_t sm_count,
69135
ucc_tl_cuda_nvls_control_t *uc_bar = reinterpret_cast<ucc_tl_cuda_nvls_control_t *>(uc_control_addr);
70136
uint32_t expected_blocks = sm_count * tsize; // total num of blocks in the multicast group, num gpus * num blocks per gpu, used for barrier synchronization
71137

138+
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
72139
switch (datatype) {
73140
case UCC_DT_FLOAT32:
74-
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
75141
allreduce_kernel_vec32<NvlsFp32Ops><<<sm_count, threads, 0, stream>>>(
76142
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
77143
break;
78144
case UCC_DT_BFLOAT16:
79-
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
80145
allreduce_kernel_vec32<NvlsBf16Ops><<<sm_count, threads, 0, stream>>>(
81146
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
82147
break;
148+
case UCC_DT_INT32:
149+
allreduce_kernel_scalar32<NvlsInt32Ops><<<sm_count, threads, 0, stream>>>(
150+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
151+
break;
152+
case UCC_DT_UINT32:
153+
allreduce_kernel_scalar32<NvlsUint32Ops><<<sm_count, threads, 0, stream>>>(
154+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
155+
break;
156+
case UCC_DT_INT64:
157+
{
158+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
159+
size_t count_u64 = src_size_bytes / sizeof(uint64_t);
160+
allreduce_kernel_scalar64<NvlsInt64Ops><<<sm_count, threads, 0, stream>>>(
161+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u64, count_u64, rank, tsize);
162+
}
163+
break;
164+
case UCC_DT_UINT64:
165+
{
166+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
167+
size_t count_u64 = src_size_bytes / sizeof(uint64_t);
168+
allreduce_kernel_scalar64<NvlsUint64Ops><<<sm_count, threads, 0, stream>>>(
169+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u64, count_u64, rank, tsize);
170+
}
171+
break;
83172
default:
84173
return UCC_ERR_NOT_SUPPORTED;
85174
}

src/components/tl/cuda/kernels/nvls.cuh

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,67 @@ struct NvlsBf16Ops {
8181
MULTIMEM_ST_BF16(v, ptr);
8282
}
8383
};
84+
85+
struct NvlsInt32Ops {
86+
using value_type = int;
87+
__device__ static inline void ld(int &v, const uint32_t *ptr) {
88+
asm("multimem.ld_reduce.global.add.s32 %0, [%1];"
89+
: "=r"(v)
90+
: "l"(ptr)
91+
: "memory");
92+
}
93+
__device__ static inline void st(const int &v, uint32_t *ptr) {
94+
asm volatile("multimem.st.global.s32 [%0], %1;" ::"l"(ptr),
95+
"r"(v)
96+
: "memory");
97+
}
98+
};
99+
100+
struct NvlsUint32Ops {
101+
using value_type = unsigned int;
102+
__device__ static inline void ld(unsigned int &v, const uint32_t *ptr) {
103+
asm("multimem.ld_reduce.global.add.u32 %0, [%1];"
104+
: "=r"(v)
105+
: "l"(ptr)
106+
: "memory");
107+
}
108+
__device__ static inline void st(const unsigned int &v, uint32_t *ptr) {
109+
asm volatile("multimem.st.global.u32 [%0], %1;" ::"l"(ptr),
110+
"r"(v)
111+
: "memory");
112+
}
113+
};
114+
115+
// PTX does not support s64 with add operation, so we use u64 instead
116+
struct NvlsInt64Ops {
117+
using value_type = unsigned long long;
118+
__device__ static inline void ld(unsigned long long &v, const uint64_t *ptr) {
119+
asm("multimem.ld_reduce.global.add.u64 %0, [%1];"
120+
: "=l"(v)
121+
: "l"(ptr)
122+
: "memory");
123+
}
124+
__device__ static inline void st(const unsigned long long &v, uint64_t *ptr) {
125+
asm volatile("multimem.st.global.u64 [%0], %1;" ::"l"(ptr),
126+
"l"(v)
127+
: "memory");
128+
}
129+
};
130+
131+
struct NvlsUint64Ops {
132+
using value_type = unsigned long long;
133+
__device__ static inline void ld(unsigned long long &v, const uint64_t *ptr) {
134+
asm("multimem.ld_reduce.global.add.u64 %0, [%1];"
135+
: "=l"(v)
136+
: "l"(ptr)
137+
: "memory");
138+
}
139+
__device__ static inline void st(const unsigned long long &v, uint64_t *ptr) {
140+
asm volatile("multimem.st.global.u64 [%0], %1;" ::"l"(ptr),
141+
"l"(v)
142+
: "memory");
143+
}
144+
};
84145
#endif // __cplusplus
85146

86147
#endif // UCC_TL_CUDA_NVLS_CUH_

src/components/tl/cuda/kernels/reduce_scatter_kernel.cu

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ __global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
3232
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
3333
size_t stride = blockDim.x * gridDim.x * 4;
3434

35-
for (size_t idx = offset + thread_offset; idx < offset + count;
35+
for (size_t idx = offset + thread_offset; idx + 3 < offset + count;
3636
idx += stride) {
3737
uint4 val;
3838
NvlsOps::ld(val, base_u32 + idx);
@@ -47,6 +47,74 @@ __global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
4747
total_blocks * (launch_counter * 2 + 2));
4848
}
4949

50+
template <typename NvlsOps>
51+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
52+
reduce_scatter_kernel_scalar32(
53+
ucc_tl_cuda_nvls_control_t *mc_bar, ucc_tl_cuda_nvls_control_t *uc_bar,
54+
const uint32_t total_blocks, uint64_t launch_counter,
55+
uint32_t *base_u32, size_t offset, size_t count, uint32_t *dst_u32)
56+
{
57+
// pre barrier
58+
nvls_bar(
59+
&(mc_bar->arrival_counter),
60+
&(uc_bar->arrival_counter),
61+
total_blocks * (launch_counter * 2 + 1));
62+
63+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
64+
size_t stride = blockDim.x * gridDim.x * 4;
65+
66+
for (size_t idx = offset + thread_offset; idx + 3 < offset + count;
67+
idx += stride) {
68+
typename NvlsOps::value_type v0, v1, v2, v3;
69+
NvlsOps::ld(v0, base_u32 + idx + 0);
70+
NvlsOps::ld(v1, base_u32 + idx + 1);
71+
NvlsOps::ld(v2, base_u32 + idx + 2);
72+
NvlsOps::ld(v3, base_u32 + idx + 3);
73+
dst_u32[idx - offset + 0] = v0;
74+
dst_u32[idx - offset + 1] = v1;
75+
dst_u32[idx - offset + 2] = v2;
76+
dst_u32[idx - offset + 3] = v3;
77+
}
78+
79+
// post barrier
80+
nvls_bar(
81+
&(mc_bar->arrival_counter),
82+
&(uc_bar->arrival_counter),
83+
total_blocks * (launch_counter * 2 + 2));
84+
}
85+
86+
template <typename NvlsOps>
87+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
88+
reduce_scatter_kernel_scalar64(
89+
ucc_tl_cuda_nvls_control_t *mc_bar, ucc_tl_cuda_nvls_control_t *uc_bar,
90+
const uint32_t total_blocks, uint64_t launch_counter,
91+
uint64_t *base_u64, size_t offset, size_t count, uint64_t *dst_u64)
92+
{
93+
// pre barrier
94+
nvls_bar(
95+
&(mc_bar->arrival_counter),
96+
&(uc_bar->arrival_counter),
97+
total_blocks * (launch_counter * 2 + 1));
98+
99+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 2;
100+
size_t stride = blockDim.x * gridDim.x * 2;
101+
102+
for (size_t idx = offset + thread_offset; idx + 1 < offset + count;
103+
idx += stride) {
104+
typename NvlsOps::value_type v0, v1;
105+
NvlsOps::ld(v0, base_u64 + idx + 0);
106+
NvlsOps::ld(v1, base_u64 + idx + 1);
107+
dst_u64[idx - offset + 0] = v0;
108+
dst_u64[idx - offset + 1] = v1;
109+
}
110+
111+
// post barrier
112+
nvls_bar(
113+
&(mc_bar->arrival_counter),
114+
&(uc_bar->arrival_counter),
115+
total_blocks * (launch_counter * 2 + 2));
116+
}
117+
50118
#ifdef __cplusplus
51119
extern "C" {
52120
#endif
@@ -101,6 +169,64 @@ ucc_status_t post_reduce_scatter_kernel(
101169
count,
102170
reinterpret_cast<uint32_t *>(dst_ptr));
103171
break;
172+
case UCC_DT_INT32:
173+
reduce_scatter_kernel_scalar32<NvlsInt32Ops>
174+
<<<sm_count, threads, 0, stream>>>(
175+
mc_bar,
176+
uc_bar,
177+
expected_blocks,
178+
launch_counter,
179+
base_u32,
180+
offset,
181+
count,
182+
reinterpret_cast<uint32_t *>(dst_ptr));
183+
break;
184+
case UCC_DT_UINT32:
185+
reduce_scatter_kernel_scalar32<NvlsUint32Ops>
186+
<<<sm_count, threads, 0, stream>>>(
187+
mc_bar,
188+
uc_bar,
189+
expected_blocks,
190+
launch_counter,
191+
base_u32,
192+
offset,
193+
count,
194+
reinterpret_cast<uint32_t *>(dst_ptr));
195+
break;
196+
case UCC_DT_INT64:
197+
{
198+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
199+
size_t offset_u64 = offset / 2;
200+
size_t count_u64 = count / 2;
201+
reduce_scatter_kernel_scalar64<NvlsInt64Ops>
202+
<<<sm_count, threads, 0, stream>>>(
203+
mc_bar,
204+
uc_bar,
205+
expected_blocks,
206+
launch_counter,
207+
base_u64,
208+
offset_u64,
209+
count_u64,
210+
reinterpret_cast<uint64_t *>(dst_ptr));
211+
}
212+
break;
213+
case UCC_DT_UINT64:
214+
{
215+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
216+
size_t offset_u64 = offset / 2;
217+
size_t count_u64 = count / 2;
218+
reduce_scatter_kernel_scalar64<NvlsUint64Ops>
219+
<<<sm_count, threads, 0, stream>>>(
220+
mc_bar,
221+
uc_bar,
222+
expected_blocks,
223+
launch_counter,
224+
base_u64,
225+
offset_u64,
226+
count_u64,
227+
reinterpret_cast<uint64_t *>(dst_ptr));
228+
}
229+
break;
104230
default:
105231
return UCC_ERR_NOT_SUPPORTED;
106232
}

0 commit comments

Comments
 (0)