Skip to content

Commit a6f3278

Browse files
committed
TL/CUDA: Add int datatype support for NVLS
This PR adds support for the following additional integer data types in the NVLS (NVLink SHARP) collective operations allreduce and reduce_scatter: INT32 (s32), a 32-bit signed integer with v4 vectorization; INT64 (s64), a 64-bit signed integer with v2 vectorization; UINT32 (u32), a 32-bit unsigned integer with v4 vectorization; and UINT64 (u64), a 64-bit unsigned integer with v2 vectorization. The change adds PTX multimem.ld_reduce and multimem.st instructions for each type, creates NvlsOps structs for the type-specific operations, updates the allreduce and reduce_scatter kernels with the new data-type handling, and modifies the validation logic to accept the new data types. Signed-off-by: Juee14Desai <jueehimalbha@nvidia.com>
1 parent ec0bc8a commit a6f3278

File tree

5 files changed

+294
-8
lines changed

5 files changed

+294
-8
lines changed

src/components/tl/cuda/allreduce/allreduce_nvls.c

Lines changed: 6 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -187,11 +187,15 @@ ucc_status_t ucc_tl_cuda_allreduce_nvls_init(
187187

188188
if (buf_size < 1024 || coll_args->args.op != UCC_OP_SUM ||
189189
(coll_args->args.dst.info.datatype != UCC_DT_FLOAT32 &&
190-
coll_args->args.dst.info.datatype != UCC_DT_BFLOAT16)) {
190+
coll_args->args.dst.info.datatype != UCC_DT_BFLOAT16 &&
191+
coll_args->args.dst.info.datatype != UCC_DT_INT32 &&
192+
coll_args->args.dst.info.datatype != UCC_DT_UINT32 &&
193+
coll_args->args.dst.info.datatype != UCC_DT_INT64 &&
194+
coll_args->args.dst.info.datatype != UCC_DT_UINT64)) {
191195
tl_debug(
192196
UCC_TL_TEAM_LIB(team),
193197
"NVLS allreduce is supported only with SUM operation "
194-
"and float32 or bfloat16 datatype, with message size >= 1024 "
198+
"and float32, bfloat16, int32, uint32, int64, or uint64 datatype, with message size >= 1024 "
195199
"bytes");
196200
return UCC_ERR_NOT_SUPPORTED;
197201
}

src/components/tl/cuda/kernels/allreduce_kernel.cu

Lines changed: 91 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -48,6 +48,72 @@ __global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
4848
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
4949
}
5050

51+
template <typename NvlsOps>
52+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
53+
allreduce_kernel_scalar32(ucc_tl_cuda_nvls_control_t *mc_bar,
54+
ucc_tl_cuda_nvls_control_t *uc_bar,
55+
const uint32_t total_blocks,
56+
uint64_t launch_counter,
57+
uint32_t *base_u32, size_t count_u32, uint32_t rank,
58+
uint32_t tsize)
59+
{
60+
// pre barrier
61+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 1));
62+
63+
// Kernel execution
64+
size_t chunk_start = ((int64_t)count_u32 * (int64_t)rank) / (int64_t)tsize;
65+
size_t chunk_end = ((int64_t)count_u32 * (int64_t)(rank + 1)) / (int64_t)tsize;
66+
67+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
68+
size_t stride = blockDim.x * gridDim.x * 4;
69+
70+
for (size_t idx = chunk_start + thread_offset; idx < chunk_end; idx += stride) {
71+
typename NvlsOps::value_type v0, v1, v2, v3;
72+
NvlsOps::ld(v0, base_u32 + idx + 0);
73+
NvlsOps::ld(v1, base_u32 + idx + 1);
74+
NvlsOps::ld(v2, base_u32 + idx + 2);
75+
NvlsOps::ld(v3, base_u32 + idx + 3);
76+
NvlsOps::st(v0, base_u32 + idx + 0);
77+
NvlsOps::st(v1, base_u32 + idx + 1);
78+
NvlsOps::st(v2, base_u32 + idx + 2);
79+
NvlsOps::st(v3, base_u32 + idx + 3);
80+
}
81+
82+
// post barrier
83+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
84+
}
85+
86+
template <typename NvlsOps>
87+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
88+
allreduce_kernel_scalar64(ucc_tl_cuda_nvls_control_t *mc_bar,
89+
ucc_tl_cuda_nvls_control_t *uc_bar,
90+
const uint32_t total_blocks,
91+
uint64_t launch_counter,
92+
uint64_t *base_u64, size_t count_u64, uint32_t rank,
93+
uint32_t tsize)
94+
{
95+
// pre barrier
96+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 1));
97+
98+
// Kernel execution
99+
size_t chunk_start = ((int64_t)count_u64 * (int64_t)rank) / (int64_t)tsize;
100+
size_t chunk_end = ((int64_t)count_u64 * (int64_t)(rank + 1)) / (int64_t)tsize;
101+
102+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 2;
103+
size_t stride = blockDim.x * gridDim.x * 2;
104+
105+
for (size_t idx = chunk_start + thread_offset; idx < chunk_end; idx += stride) {
106+
typename NvlsOps::value_type v0, v1;
107+
NvlsOps::ld(v0, base_u64 + idx + 0);
108+
NvlsOps::ld(v1, base_u64 + idx + 1);
109+
NvlsOps::st(v0, base_u64 + idx + 0);
110+
NvlsOps::st(v1, base_u64 + idx + 1);
111+
}
112+
113+
// post barrier
114+
nvls_bar(&(mc_bar->arrival_counter), &(uc_bar->arrival_counter), total_blocks * (launch_counter * 2 + 2));
115+
}
116+
51117
#ifdef __cplusplus
52118
extern "C" {
53119
#endif
@@ -69,17 +135,40 @@ ucc_status_t post_allreduce_kernel(cudaStream_t stream, uint32_t sm_count,
69135
ucc_tl_cuda_nvls_control_t *uc_bar = reinterpret_cast<ucc_tl_cuda_nvls_control_t *>(uc_control_addr);
70136
uint32_t expected_blocks = sm_count * tsize; // total num of blocks in the multicast group, num gpus * num blocks per gpu, used for barrier synchronization
71137

138+
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
72139
switch (datatype) {
73140
case UCC_DT_FLOAT32:
74-
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
75141
allreduce_kernel_vec32<NvlsFp32Ops><<<sm_count, threads, 0, stream>>>(
76142
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
77143
break;
78144
case UCC_DT_BFLOAT16:
79-
assert(((uintptr_t)(mc_base_addr) % 8) == 0);
80145
allreduce_kernel_vec32<NvlsBf16Ops><<<sm_count, threads, 0, stream>>>(
81146
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
82147
break;
148+
case UCC_DT_INT32:
149+
allreduce_kernel_scalar32<NvlsInt32Ops><<<sm_count, threads, 0, stream>>>(
150+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
151+
break;
152+
case UCC_DT_UINT32:
153+
allreduce_kernel_scalar32<NvlsUint32Ops><<<sm_count, threads, 0, stream>>>(
154+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u32, count_u32, rank, tsize);
155+
break;
156+
case UCC_DT_INT64:
157+
{
158+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
159+
size_t count_u64 = src_size_bytes / sizeof(uint64_t);
160+
allreduce_kernel_scalar64<NvlsInt64Ops><<<sm_count, threads, 0, stream>>>(
161+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u64, count_u64, rank, tsize);
162+
}
163+
break;
164+
case UCC_DT_UINT64:
165+
{
166+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
167+
size_t count_u64 = src_size_bytes / sizeof(uint64_t);
168+
allreduce_kernel_scalar64<NvlsUint64Ops><<<sm_count, threads, 0, stream>>>(
169+
mc_bar, uc_bar, expected_blocks, launch_counter, base_u64, count_u64, rank, tsize);
170+
}
171+
break;
83172
default:
84173
return UCC_ERR_NOT_SUPPORTED;
85174
}

src/components/tl/cuda/kernels/nvls.cuh

Lines changed: 61 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -81,6 +81,67 @@ struct NvlsBf16Ops {
8181
MULTIMEM_ST_BF16(v, ptr);
8282
}
8383
};
84+
85+
struct NvlsInt32Ops {
86+
using value_type = unsigned int;
87+
__device__ static inline void ld(unsigned int &v, const uint32_t *ptr) {
88+
asm("multimem.ld_reduce.global.add.u32 %0, [%1];"
89+
: "=r"(v)
90+
: "l"(ptr)
91+
: "memory");
92+
}
93+
__device__ static inline void st(const unsigned int &v, uint32_t *ptr) {
94+
asm volatile("multimem.st.global.u32 [%0], %1;" ::"l"(ptr),
95+
"r"(v)
96+
: "memory");
97+
}
98+
};
99+
100+
struct NvlsUint32Ops {
101+
using value_type = unsigned int;
102+
__device__ static inline void ld(unsigned int &v, const uint32_t *ptr) {
103+
asm("multimem.ld_reduce.global.add.u32 %0, [%1];"
104+
: "=r"(v)
105+
: "l"(ptr)
106+
: "memory");
107+
}
108+
__device__ static inline void st(const unsigned int &v, uint32_t *ptr) {
109+
asm volatile("multimem.st.global.u32 [%0], %1;" ::"l"(ptr),
110+
"r"(v)
111+
: "memory");
112+
}
113+
};
114+
115+
// PTX does not support s64 with add operation, so we use u64 instead
116+
struct NvlsInt64Ops {
117+
using value_type = unsigned long long;
118+
__device__ static inline void ld(unsigned long long &v, const uint64_t *ptr) {
119+
asm("multimem.ld_reduce.global.add.u64 %0, [%1];"
120+
: "=l"(v)
121+
: "l"(ptr)
122+
: "memory");
123+
}
124+
__device__ static inline void st(const unsigned long long &v, uint64_t *ptr) {
125+
asm volatile("multimem.st.global.u64 [%0], %1;" ::"l"(ptr),
126+
"l"(v)
127+
: "memory");
128+
}
129+
};
130+
131+
struct NvlsUint64Ops {
132+
using value_type = unsigned long long;
133+
__device__ static inline void ld(unsigned long long &v, const uint64_t *ptr) {
134+
asm("multimem.ld_reduce.global.add.u64 %0, [%1];"
135+
: "=l"(v)
136+
: "l"(ptr)
137+
: "memory");
138+
}
139+
__device__ static inline void st(const unsigned long long &v, uint64_t *ptr) {
140+
asm volatile("multimem.st.global.u64 [%0], %1;" ::"l"(ptr),
141+
"l"(v)
142+
: "memory");
143+
}
144+
};
84145
#endif // __cplusplus
85146

86147
#endif // UCC_TL_CUDA_NVLS_CUH_

src/components/tl/cuda/kernels/reduce_scatter_kernel.cu

Lines changed: 126 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -47,6 +47,74 @@ __global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
4747
total_blocks * (launch_counter * 2 + 2));
4848
}
4949

50+
template <typename NvlsOps>
51+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
52+
reduce_scatter_kernel_scalar32(
53+
ucc_tl_cuda_nvls_control_t *mc_bar, ucc_tl_cuda_nvls_control_t *uc_bar,
54+
const uint32_t total_blocks, uint64_t launch_counter,
55+
uint32_t *base_u32, size_t offset, size_t count, uint32_t *dst_u32)
56+
{
57+
// pre barrier
58+
nvls_bar(
59+
&(mc_bar->arrival_counter),
60+
&(uc_bar->arrival_counter),
61+
total_blocks * (launch_counter * 2 + 1));
62+
63+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
64+
size_t stride = blockDim.x * gridDim.x * 4;
65+
66+
for (size_t idx = offset + thread_offset; idx < offset + count;
67+
idx += stride) {
68+
typename NvlsOps::value_type v0, v1, v2, v3;
69+
NvlsOps::ld(v0, base_u32 + idx + 0);
70+
NvlsOps::ld(v1, base_u32 + idx + 1);
71+
NvlsOps::ld(v2, base_u32 + idx + 2);
72+
NvlsOps::ld(v3, base_u32 + idx + 3);
73+
dst_u32[idx - offset + 0] = v0;
74+
dst_u32[idx - offset + 1] = v1;
75+
dst_u32[idx - offset + 2] = v2;
76+
dst_u32[idx - offset + 3] = v3;
77+
}
78+
79+
// post barrier
80+
nvls_bar(
81+
&(mc_bar->arrival_counter),
82+
&(uc_bar->arrival_counter),
83+
total_blocks * (launch_counter * 2 + 2));
84+
}
85+
86+
template <typename NvlsOps>
87+
__global__ void __launch_bounds__(UCC_TL_CUDA_MAX_NVLS_THREADS)
88+
reduce_scatter_kernel_scalar64(
89+
ucc_tl_cuda_nvls_control_t *mc_bar, ucc_tl_cuda_nvls_control_t *uc_bar,
90+
const uint32_t total_blocks, uint64_t launch_counter,
91+
uint64_t *base_u64, size_t offset, size_t count, uint64_t *dst_u64)
92+
{
93+
// pre barrier
94+
nvls_bar(
95+
&(mc_bar->arrival_counter),
96+
&(uc_bar->arrival_counter),
97+
total_blocks * (launch_counter * 2 + 1));
98+
99+
size_t thread_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 2;
100+
size_t stride = blockDim.x * gridDim.x * 2;
101+
102+
for (size_t idx = offset + thread_offset; idx < offset + count;
103+
idx += stride) {
104+
typename NvlsOps::value_type v0, v1;
105+
NvlsOps::ld(v0, base_u64 + idx + 0);
106+
NvlsOps::ld(v1, base_u64 + idx + 1);
107+
dst_u64[idx - offset + 0] = v0;
108+
dst_u64[idx - offset + 1] = v1;
109+
}
110+
111+
// post barrier
112+
nvls_bar(
113+
&(mc_bar->arrival_counter),
114+
&(uc_bar->arrival_counter),
115+
total_blocks * (launch_counter * 2 + 2));
116+
}
117+
50118
#ifdef __cplusplus
51119
extern "C" {
52120
#endif
@@ -101,6 +169,64 @@ ucc_status_t post_reduce_scatter_kernel(
101169
count,
102170
reinterpret_cast<uint32_t *>(dst_ptr));
103171
break;
172+
case UCC_DT_INT32:
173+
reduce_scatter_kernel_scalar32<NvlsInt32Ops>
174+
<<<sm_count, threads, 0, stream>>>(
175+
mc_bar,
176+
uc_bar,
177+
expected_blocks,
178+
launch_counter,
179+
base_u32,
180+
offset,
181+
count,
182+
reinterpret_cast<uint32_t *>(dst_ptr));
183+
break;
184+
case UCC_DT_UINT32:
185+
reduce_scatter_kernel_scalar32<NvlsUint32Ops>
186+
<<<sm_count, threads, 0, stream>>>(
187+
mc_bar,
188+
uc_bar,
189+
expected_blocks,
190+
launch_counter,
191+
base_u32,
192+
offset,
193+
count,
194+
reinterpret_cast<uint32_t *>(dst_ptr));
195+
break;
196+
case UCC_DT_INT64:
197+
{
198+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
199+
size_t offset_u64 = offset / 2;
200+
size_t count_u64 = count / 2;
201+
reduce_scatter_kernel_scalar64<NvlsInt64Ops>
202+
<<<sm_count, threads, 0, stream>>>(
203+
mc_bar,
204+
uc_bar,
205+
expected_blocks,
206+
launch_counter,
207+
base_u64,
208+
offset_u64,
209+
count_u64,
210+
reinterpret_cast<uint64_t *>(dst_ptr));
211+
}
212+
break;
213+
case UCC_DT_UINT64:
214+
{
215+
uint64_t *base_u64 = reinterpret_cast<uint64_t *>(mc_base_addr);
216+
size_t offset_u64 = offset / 2;
217+
size_t count_u64 = count / 2;
218+
reduce_scatter_kernel_scalar64<NvlsUint64Ops>
219+
<<<sm_count, threads, 0, stream>>>(
220+
mc_bar,
221+
uc_bar,
222+
expected_blocks,
223+
launch_counter,
224+
base_u64,
225+
offset_u64,
226+
count_u64,
227+
reinterpret_cast<uint64_t *>(dst_ptr));
228+
}
229+
break;
104230
default:
105231
return UCC_ERR_NOT_SUPPORTED;
106232
}

src/components/tl/cuda/reduce_scatterv/reduce_scatterv_nvls.c

Lines changed: 10 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -175,10 +175,12 @@ ucc_status_t ucc_tl_cuda_reduce_scatterv_nvls_init_common(
175175
"NVLS reduce scatter(v) supported only with SUM operation");
176176
return UCC_ERR_NOT_SUPPORTED;
177177
}
178-
if (dt != UCC_DT_FLOAT32 && dt != UCC_DT_BFLOAT16) {
178+
if (dt != UCC_DT_FLOAT32 && dt != UCC_DT_BFLOAT16 &&
179+
dt != UCC_DT_INT32 && dt != UCC_DT_UINT32 &&
180+
dt != UCC_DT_INT64 && dt != UCC_DT_UINT64) {
179181
tl_debug(
180182
UCC_TL_TEAM_LIB(team),
181-
"NVLS reduce scatter(v) supported only with float32/bfloat16");
183+
"NVLS reduce scatter(v) supported only with float32/bfloat16/int32/uint32/int64/uint64");
182184
return UCC_ERR_NOT_SUPPORTED;
183185
}
184186

@@ -191,11 +193,15 @@ ucc_status_t ucc_tl_cuda_reduce_scatterv_nvls_init_common(
191193
count_elements = get_count(task, trank);
192194

193195
/* Convert from datatype elements to uint32_t indices for the kernel.
194-
* For float32: 1 element = 1 uint32_t
196+
* For float32/int32/uint32: 1 element = 1 uint32_t
197+
* For int64/uint64: 1 element = 2 uint32_t
195198
* For bfloat16: 2 elements = 1 uint32_t */
196-
if (dt == UCC_DT_FLOAT32) {
199+
if (dt == UCC_DT_FLOAT32 || dt == UCC_DT_INT32 || dt == UCC_DT_UINT32) {
197200
offset_u32 = offset_elements;
198201
count_u32 = count_elements;
202+
} else if (dt == UCC_DT_INT64 || dt == UCC_DT_UINT64) {
203+
offset_u32 = offset_elements * 2;
204+
count_u32 = count_elements * 2;
199205
} else { /* UCC_DT_BFLOAT16 */
200206
if (offset_elements % 2 != 0 || count_elements % 2 != 0) {
201207
tl_debug(

0 commit comments

Comments
 (0)