
Commit e649c0b

Merge branch 'ikawrakow:main' into main
2 parents 777a3bf + f792373 commit e649c0b

2 files changed: +168 -77 lines changed

ggml/src/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
@@ -3429,6 +3429,7 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend,
 
     GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
 
+    ggml_cuda_set_device(cuda_ctx->device);
     CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
 }
 
@@ -3530,6 +3531,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
         }
 
         // record event on src stream after the copy
+        ggml_cuda_set_device(cuda_ctx_src->device);
         if (!cuda_ctx_src->copy_event) {
            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
        }
@@ -3547,6 +3549,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
        }
     } else {
         // src and dst are on the same backend
+        printf("Why is this being invoked?\n");
         CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
     }
     return true;
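
Note: the two added ggml_cuda_set_device(...) calls make the CUDA runtime's current device match the device that owns the stream and event being used; most runtime calls are issued against the current device, so switching back before touching another device's resources avoids failures or work landing on the wrong GPU. Below is a minimal, self-contained sketch of that pattern. It is not part of the commit; the device index and buffer names are illustrative.

    // Minimal sketch (illustrative only): keep the current device in sync with the
    // device whose stream/event you are about to use. Assumes at least 2 visible GPUs.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    #define CHECK(x) do { cudaError_t e = (x); if (e != cudaSuccess) { \
        fprintf(stderr, "%s failed: %s\n", #x, cudaGetErrorString(e)); return 1; } } while (0)

    int main() {
        const int dev = 1;                    // illustrative: a device other than the default
        const size_t n = 1 << 20;
        std::vector<float> host(n, 1.0f);

        CHECK(cudaSetDevice(dev));            // make `dev` current *before* creating its resources
        cudaStream_t stream;
        cudaEvent_t  done;
        float * dbuf = nullptr;
        CHECK(cudaStreamCreate(&stream));     // stream belongs to the current device
        CHECK(cudaEventCreateWithFlags(&done, cudaEventDisableTiming));
        CHECK(cudaMalloc(&dbuf, n * sizeof(float)));

        CHECK(cudaSetDevice(0));              // pretend unrelated work ran on device 0 in between
        CHECK(cudaSetDevice(dev));            // switch back before touching dev's stream/event
        CHECK(cudaMemcpyAsync(dbuf, host.data(), n * sizeof(float), cudaMemcpyHostToDevice, stream));
        CHECK(cudaEventRecord(done, stream)); // event and stream were created on the same device
        CHECK(cudaEventSynchronize(done));
        printf("copy on device %d finished\n", dev);
        return 0;
    }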

ggml/src/ggml-cuda/reduce.cu

Lines changed: 165 additions & 77 deletions
@@ -78,85 +78,26 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
 
     auto & info = ggml_cuda_info();
 #ifdef GGML_USE_NCCL
-    if (info.have_nccl && nhave == nreduce) { // somehow I'm not able to figure out how to use NCCL when not all GPUs participate in the reduce op
+    // Somehow I'm not able to figure out how to use NCCL correctly.
+    // It does not work at all if not all GPUs participate in the reduce op, and we
+    // get suboptimal prompt processing performance when we have more than 2 GPUs.
+    // Hence, if enabled, we use NCCL only for the cases where it works and performs well.
+    if (info.have_nccl && nhave == nreduce && (nhave == 2 || dst->ne[1] < 32)) {
         GGML_ASSERT(info.have_nccl);
         GGML_ASSERT(info.device_count == nreduce);
-        auto type = dst->type;
-        //int device = ctx.device;
-        if (nreduce != info.device_count) {
-            GGML_ABORT("Not implemented");
-        }
-        //auto tim1 = std::chrono::steady_clock::now();
-        auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
-        if (nreduce == 4 && dst->ne[1] > 32) {
-            auto com = info.nccl_coms + info.device_count;
-            static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
-            for (int ip = 0; ip < 4; ++ip) {
-                ncclGroupStart();
-                ggml_cuda_set_device(devs[2*ip+0]);
-                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
-                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
-                ggml_cuda_set_device(devs[2*ip+1]);
-                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
-                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
-                ncclGroupEnd();
-                if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-        }
-        else if (nreduce == 3 && dst->ne[1] > 32) {
-            auto com = info.nccl_coms + info.device_count;
-            static const int devs[4] = {0,1, 0,2};
-            for (int ip = 0; ip < 2; ++ip) {
-                ncclGroupStart();
-                ggml_cuda_set_device(devs[2*ip+0]);
-                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
-                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
-                ggml_cuda_set_device(devs[2*ip+1]);
-                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
-                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
-                ncclGroupEnd();
-                if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            ncclGroupStart();
-            ggml_cuda_set_device(0);
-            auto status1 = ncclSend(dst->src[0]->data, ggml_nelements(dst), data_type, 1, com[0], info.all_ctx[0]->stream());
-            ggml_cuda_set_device(1);
-            auto status2 = ncclRecv(dst->src[1]->data, ggml_nelements(dst), data_type, 0, com[1], info.all_ctx[1]->stream());
-            ncclGroupEnd();
-            if (status1 != ncclSuccess || status2 != ncclSuccess) {
-                fprintf(stderr, "%s: ncclSend/Recv failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
+        auto data_type = dst->type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
+        ncclGroupStart();
+        for (int i = 0; i < nreduce; ++i) {
+            ggml_cuda_set_device(i);
+            auto status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
+                                        dst->src[i] ? dst->src[i]->data : nullptr,
+                                        ggml_nelements(dst), data_type, ncclSum, info.nccl_coms[i], info.all_ctx[i]->stream());
+            if (status != ncclSuccess) {
+                fprintf(stderr, "%s: ncclAllReduce failed with status %d\n", __func__, (int)status);
                 GGML_ABORT("Fatal error");
             }
         }
-        else {
-            ncclGroupStart();
-            for (int i = 0; i < nreduce; ++i) {
-                ncclComm_t this_comm;
-                if (nhave == nreduce) {
-                    this_comm = info.nccl_coms[i];
-                } else {
-                    auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 0 : NCCL_SPLIT_NOCOLOR, i, &this_comm, NULL);
-                    GGML_ASSERT(status == ncclSuccess);
-                }
-                ggml_cuda_set_device(i);
-                auto stream = info.all_ctx[i]->stream();
-                GGML_ASSERT(stream);
-                auto status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
-                                            dst->src[i] ? dst->src[i]->data : nullptr,
-                                            ggml_nelements(dst), data_type, ncclSum, this_comm, stream);
-                if (status != ncclSuccess) {
-                    fprintf(stderr, "%s: ncclAllReduce failed with status %d\n", __func__, (int)status);
-                    GGML_ABORT("Fatal error");
-                }
-            }
-            ncclGroupEnd();
-        }
+        ncclGroupEnd();
         ggml_cuda_set_device(ctx.device);
         return;
     }
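
For reference, the replacement NCCL path above is the standard single-process, one-communicator-per-GPU pattern: every per-device ncclAllReduce call is issued inside a single ncclGroupStart()/ncclGroupEnd() pair so NCCL can launch them as one collective. A stripped-down sketch of that pattern, independent of the ggml-specific info/all_ctx structures (buffer sizes and names are illustrative, error checking omitted for brevity):

    // Standalone sketch of a grouped in-place all-reduce across all visible GPUs (illustrative only).
    #include <nccl.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    int main() {
        int ndev = 0;
        cudaGetDeviceCount(&ndev);
        if (ndev < 2) { printf("need at least 2 GPUs\n"); return 0; }

        std::vector<ncclComm_t>   comms(ndev);
        std::vector<cudaStream_t> streams(ndev);
        std::vector<float *>      bufs(ndev);
        const size_t count = 1 << 20;

        ncclCommInitAll(comms.data(), ndev, nullptr);   // one communicator per visible device
        for (int i = 0; i < ndev; ++i) {
            cudaSetDevice(i);
            cudaStreamCreate(&streams[i]);
            cudaMalloc(&bufs[i], count * sizeof(float));
        }

        // All per-device calls go inside one group so NCCL can issue them without deadlocking.
        ncclGroupStart();
        for (int i = 0; i < ndev; ++i) {
            cudaSetDevice(i);
            ncclAllReduce(bufs[i], bufs[i], count, ncclFloat, ncclSum, comms[i], streams[i]);
        }
        ncclGroupEnd();

        for (int i = 0; i < ndev; ++i) { cudaSetDevice(i); cudaStreamSynchronize(streams[i]); }
        for (int i = 0; i < ndev; ++i) ncclCommDestroy(comms[i]);
        return 0;
    }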
@@ -176,6 +117,149 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         GGML_ASSERT(ii == nhave);
         GGML_ASSERT(have_this_device);
     }
+    //
+    // For prompt processing, the objective is to minimize the amount of data being exchanged between
+    // the GPUs, even if this means we need to launch a larger number of kernels (we are bandwidth
+    // bound rather than latency bound).
+    // The following implements a ring communication+reduction that achieves this goal.
+    // I would have thought that this is automatically done by NCCL, but it doesn't look that
+    // way (or I simply don't understand how to use NCCL) as the ring implementation below achieves quite a bit
+    // better performance compared to what I get with NCCL.
+    //
+    // We do the data reduction in stages. Let N be the number of GPUs.
+    // In each stage, each GPU sends 1/N'th of the data to a peer GPU in a ring fashion
+    // (i.e. 0->1, 1->2, 2->3, ..., N-1 -> 0). Each GPU then performs the addition with the
+    // portion just received. After N-1 stages, each GPU ends up having the full sum for 1/N'th
+    // of the data. We then do a second round of N-1 stages where each GPU sends a fully reduced
+    // portion to its peer. The following shows how all this works for 2, 3, and 4 GPUs.
+    // Worth noting that because in each round each GPU sends and receives data, we use the
+    // bidirectional p2p bandwidth, which tends to be 2X the unidirectional bandwidth.
+    //
+    // Examples
+    //
+    // ======================== 2 devices:
+    // stage 0:
+    // i = 0, peer = 1, ichunk = 0 -> copy part 0 from device 1, add -> device 0 has part 0 complete
+    // i = 1, peer = 0, ichunk = 1 -> copy part 1 from device 0, add -> device 1 has part 1 complete
+    // second loop
+    // stage 0
+    // i = 0, peer = 1, ichunk = 1 -> copy part 1 from device 1 -> device 0 has parts 0, 1 complete
+    // i = 1, peer = 0, ichunk = 0 -> copy part 0 from device 0 -> device 1 has parts 0, 1 complete
+    //
+    // ======================== 3 devices
+    // stage 0
+    // i = 0, peer = 1, ichunk = 0 -> copy part 0 from device 1, add -> part 0 = 0+1
+    // i = 1, peer = 2, ichunk = 1 -> copy part 1 from device 2, add -> part 1 = 1+2
+    // i = 2, peer = 0, ichunk = 2 -> copy part 2 from device 0, add -> part 2 = 0+2
+    // stage 1
+    // i = 0, peer = 1, ichunk = 1 -> copy part 1 from device 1, add -> part 1 = 0+1+2
+    // i = 1, peer = 2, ichunk = 2 -> copy part 2 from device 2, add -> part 2 = 0+1+2
+    // i = 2, peer = 0, ichunk = 0 -> copy part 0 from device 0, add -> part 0 = 0+1+2
+    // second loop
+    // stage 0
+    // i = 0, peer = 1, ichunk = 2 -> copy part 2 from device 1, device 0 now has parts 1, 2 complete
+    // i = 1, peer = 2, ichunk = 0 -> copy part 0 from device 2, device 1 now has parts 0, 2 complete
+    // i = 2, peer = 0, ichunk = 1 -> copy part 1 from device 0, device 2 now has parts 0, 1 complete
+    // stage 1
+    // i = 0, peer = 1, ichunk = 0 -> copy part 0 from device 1, device 0 now has parts 0, 1, 2 complete
+    // i = 1, peer = 2, ichunk = 1 -> copy part 1 from device 2, device 1 now has parts 0, 1, 2 complete
+    // i = 2, peer = 0, ichunk = 2 -> copy part 2 from device 0, device 2 now has parts 0, 1, 2 complete
+    //
+    // ======================== 4 devices
+    // stage 0
+    // i = 0, peer = 1, ichunk = 0 -> copy part 0 from device 1, add -> part 0 = 0+1
+    // i = 1, peer = 2, ichunk = 1 -> copy part 1 from device 2, add -> part 1 = 1+2
+    // i = 2, peer = 3, ichunk = 2 -> copy part 2 from device 3, add -> part 2 = 2+3
+    // i = 3, peer = 0, ichunk = 3 -> copy part 3 from device 0, add -> part 3 = 0+3
+    // stage 1
+    // i = 0, peer = 1, ichunk = 1 -> copy part 1 from device 1, add -> part 1 = 0+1+2
+    // i = 1, peer = 2, ichunk = 2 -> copy part 2 from device 2, add -> part 2 = 1+2+3
+    // i = 2, peer = 3, ichunk = 3 -> copy part 3 from device 3, add -> part 3 = 0+2+3
+    // i = 3, peer = 0, ichunk = 0 -> copy part 0 from device 0, add -> part 0 = 0+1+3
+    // stage 2
+    // i = 0, peer = 1, ichunk = 2 -> copy part 2 from device 1, add -> part 2 = 0+1+2+3
+    // i = 1, peer = 2, ichunk = 3 -> copy part 3 from device 2, add -> part 3 = 0+1+2+3
+    // i = 2, peer = 3, ichunk = 0 -> copy part 0 from device 3, add -> part 0 = 0+1+2+3
+    // i = 3, peer = 0, ichunk = 1 -> copy part 1 from device 0, add -> part 1 = 0+1+2+3
+    // second loop
+    // stage 0
+    // i = 0, peer = 1, ichunk = 3 -> copy part 3 from device 1, device 0 now has parts 2, 3
+    // i = 1, peer = 2, ichunk = 0 -> copy part 0 from device 2, device 1 now has parts 3, 0
+    // i = 2, peer = 3, ichunk = 1 -> copy part 1 from device 3, device 2 now has parts 0, 1
+    // i = 3, peer = 0, ichunk = 2 -> copy part 2 from device 0, device 3 now has parts 1, 2
+    // stage 1
+    // i = 0, peer = 1, ichunk = 0 -> copy part 0 from device 1, device 0 now has parts 0, 2, 3
+    // i = 1, peer = 2, ichunk = 1 -> copy part 1 from device 2, device 1 now has parts 3, 0, 1
+    // i = 2, peer = 3, ichunk = 2 -> copy part 2 from device 3, device 2 now has parts 0, 1, 2
+    // i = 3, peer = 0, ichunk = 3 -> copy part 3 from device 0, device 3 now has parts 1, 2, 3
+    // stage 2
+    // i = 0, peer = 1, ichunk = 1 -> copy part 1 from device 1, device 0 now has parts 0, 1, 2, 3
+    // etc.
+    //
+    if (dst->ne[1] >= 32) {
+        auto nelem = ggml_nelements(dst);
+        auto elem_size = ggml_element_size(dst);
+        auto nelem_per_device = (nelem + nhave - 1)/nhave;
+        auto required_size = nelem_per_device*elem_size;
+        for (int ii = 0; ii < nhave; ++ii) {
+            int i = idx[ii];
+            auto this_ctx = info.all_ctx[i];
+            if (!this_ctx->copy_event) {
+                ggml_cuda_set_device(this_ctx->device);
+                CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->copy_event, cudaEventDisableTiming));
+            }
+            if (required_size > this_ctx->copy_size) {
+                ggml_cuda_set_device(this_ctx->device);
+                if (this_ctx->copy_buffer) {
+                    CUDA_CHECK(cudaFree(this_ctx->copy_buffer));
+                }
+                CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, required_size, this_ctx->device));
+                this_ctx->copy_size = required_size;
+            }
+        }
+        for (int stage = 0; stage < nhave-1; ++stage) {
+            int ichunk = stage;
+            for (int ii = 0; ii < nhave; ++ii) {
+                int i = idx[ii];
+                int peer = idx[(ii+1)%nhave];
+                auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
+                ggml_cuda_set_device(info.all_ctx[peer]->device);
+                CUDA_CHECK(cudaMemcpyPeerAsync(info.all_ctx[i]->copy_buffer, info.all_ctx[i]->device,
+                            (const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
+                            this_nelem*elem_size, info.all_ctx[peer]->stream()));
+                CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
+                ggml_cuda_set_device(info.all_ctx[i]->device);
+                CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[peer]->copy_event, 0));
+                int num_blocks = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
+                if (dst->type == GGML_TYPE_F16) {
+                    k_add<half, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
+                            (const half *)info.all_ctx[i]->copy_buffer, (half *)dst->src[i]->data + ichunk*nelem_per_device);
+                } else {
+                    k_add<float, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(this_nelem,
+                            (const float *)info.all_ctx[i]->copy_buffer, (float *)dst->src[i]->data + ichunk*nelem_per_device);
+                }
+                ichunk = (ichunk + 1)%nhave;
+            }
+        }
+        for (int stage = 0; stage < nhave-1; ++stage) {
+            int ichunk = (nhave - 1 + stage)%nhave;
+            for (int ii = 0; ii < nhave; ++ii) {
+                int i = idx[ii];
+                int peer = idx[(ii+1)%nhave];
+                auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
+                ggml_cuda_set_device(info.all_ctx[peer]->device);
+                CUDA_CHECK(cudaMemcpyPeerAsync((char *)dst->src[i]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[i]->device,
+                            (const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
+                            this_nelem*elem_size, info.all_ctx[peer]->stream()));
+                CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
+                ggml_cuda_set_device(info.all_ctx[i]->device);
+                CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[peer]->copy_event, 0));
+                ichunk = (ichunk + 1)%nhave;
+            }
+        }
+        ggml_cuda_set_device(ctx.device);
+        return;
+    }
     if (nhave == 4 && dst->ne[1] <= 8 && ctx.p2p_enabled) {
         for (int ii = 0; ii < nhave; ++ii) {
             int i = idx[ii];
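
The chunk/peer schedule spelled out in the comment above (reduce-scatter ring followed by an all-gather ring, peer = next device in the ring, chunk index rotating by one per device) can be sanity-checked with a small host-side model: track which contributions each chunk holds on each device and verify every device ends up with the full sum. This is an illustrative sketch only (plain C++, no CUDA), not part of the commit:

    // Host-side model of the ring schedule described in the comment (illustrative only).
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
        const int N = 4;                          // number of GPUs; try 2, 3, 4, ...
        // sum[d][c] = value chunk c currently holds on device d;
        // device d starts out holding only its own contribution (value d+1) in every chunk.
        std::vector<std::vector<int>> sum(N, std::vector<int>(N));
        for (int d = 0; d < N; ++d)
            for (int c = 0; c < N; ++c) sum[d][c] = d + 1;
        const int full = N * (N + 1) / 2;         // 1 + 2 + ... + N

        // First loop: reduce-scatter. After N-1 stages each device owns one fully reduced chunk.
        for (int stage = 0; stage < N - 1; ++stage) {
            std::vector<std::vector<int>> next = sum;   // all copies in a stage run "in parallel"
            int ichunk = stage;
            for (int i = 0; i < N; ++i) {
                int peer = (i + 1) % N;                 // receive chunk `ichunk` from the next device
                next[i][ichunk] = sum[i][ichunk] + sum[peer][ichunk];
                ichunk = (ichunk + 1) % N;
            }
            sum = next;
        }

        // Second loop: all-gather. Fully reduced chunks travel around the ring, no addition.
        for (int stage = 0; stage < N - 1; ++stage) {
            std::vector<std::vector<int>> next = sum;
            int ichunk = (N - 1 + stage) % N;
            for (int i = 0; i < N; ++i) {
                int peer = (i + 1) % N;
                next[i][ichunk] = sum[peer][ichunk];
                ichunk = (ichunk + 1) % N;
            }
            sum = next;
        }

        for (int d = 0; d < N; ++d)
            for (int c = 0; c < N; ++c) assert(sum[d][c] == full);
        printf("all %d devices hold the complete sum in all %d chunks\n", N, N);
        return 0;
    }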
@@ -189,15 +273,16 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         auto nelem = ggml_nelements(dst);
         for (int ii = 0; ii < nhave/2; ++ii) {
             int i = idx[2*ii+0];
-            ggml_cuda_set_device(i);
             int nblocks = (nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
             copy_task task;
             task.nptr = nhave/2;
             task.nelem = nelem;
             task.ptrs[0] = (char *)dst->src[i]->data;
             int j = idx[2*ii+1];
+            ggml_cuda_set_device(j);
             CUDA_CHECK(cudaEventRecord(info.all_ctx[j]->copy_event, info.all_ctx[j]->stream()));
             task.ptrs[1] = (char *)dst->src[j]->data;
+            ggml_cuda_set_device(i);
             CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
             if (dst->type == GGML_TYPE_F16) {
                 k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
@@ -212,14 +297,14 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         }
         for (int ii = 0; ii < nhave/2; ++ii) {
             int i = idx[2*ii+1];
-            ggml_cuda_set_device(i);
             int nblocks = (nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
             copy_task task;
             task.nptr = nhave/2;
             task.nelem = nelem;
             task.ptrs[0] = (char *)dst->src[i]->data;
             int j = idx[(2*ii+2)%nhave];
             task.ptrs[1] = (char *)dst->src[j]->data;
+            ggml_cuda_set_device(i);
             CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
             if (dst->type == GGML_TYPE_F16) {
                 k_reduce_add_T<half, CUDA_REDUCE_BLOCK_SIZE, 2><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
@@ -258,6 +343,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     auto elem_size = ggml_element_size(dst);
     for (int ii = 0; ii < nhave; ++ii) {
         int i = idx[ii];
+        ggml_cuda_set_device(i);
         int this_nelem = std::min(nelem_per_device, nelem - ii*nelem_per_device);
         copy_task task;
         task.nptr = nhave;
@@ -304,18 +390,20 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     //printf("Submitted kernels\n");
     for (int ii = 0; ii < nhave; ++ii) {
         int i = idx[ii];
+        ggml_cuda_set_device(i);
         CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
     }
     //printf("Recorded events again\n");
     for (int ii = 0; ii < nhave; ++ii) {
         int i = idx[ii];
+        ggml_cuda_set_device(i);
         for (int jj = 0; jj < nhave; ++jj) {
             if (jj == ii) continue;
             int j = idx[jj];
             CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
         }
     }
-    //printf("All good so far\n");
+    ggml_cuda_set_device(ctx.device);
     return;
 }
 auto required_size = nbytes*(nhave-1);
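
Throughout reduce.cu the cross-device handshake is the same: the sending side enqueues cudaMemcpyPeerAsync on its own stream and records an event, and the receiving side makes its stream wait on that event before consuming the data. A minimal two-device sketch of that handshake, illustrative only and not from the commit (buffer names are made up, two visible GPUs assumed, a real reduction kernel would replace the final synchronize):

    // Sketch of the producer/consumer handshake used in reduce.cu (illustrative only).
    #include <cuda_runtime.h>
    #include <cstdio>

    #define CHECK(x) do { cudaError_t e = (x); if (e != cudaSuccess) { \
        fprintf(stderr, "%s: %s\n", #x, cudaGetErrorString(e)); return 1; } } while (0)

    int main() {
        const size_t nbytes = 1 << 20;
        float *src = nullptr, *dst_buf = nullptr;
        cudaStream_t s0, s1;
        cudaEvent_t  copy_done;

        CHECK(cudaSetDevice(1));                 // producer: device 1 owns the source buffer
        CHECK(cudaMalloc(&src, nbytes));
        CHECK(cudaStreamCreate(&s1));
        CHECK(cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));

        CHECK(cudaSetDevice(0));                 // consumer: device 0 owns the destination buffer
        CHECK(cudaMalloc(&dst_buf, nbytes));
        CHECK(cudaStreamCreate(&s0));

        // Producer side: enqueue the peer copy on device 1's stream, record an event after it.
        CHECK(cudaSetDevice(1));
        CHECK(cudaMemcpyPeerAsync(dst_buf, /*dstDevice=*/0, src, /*srcDevice=*/1, nbytes, s1));
        CHECK(cudaEventRecord(copy_done, s1));

        // Consumer side: device 0's stream waits for the copy before using dst_buf
        // (a reduction kernel would be launched on s0 here).
        CHECK(cudaSetDevice(0));
        CHECK(cudaStreamWaitEvent(s0, copy_done, 0));
        CHECK(cudaStreamSynchronize(s0));
        printf("peer copy visible to device 0's stream\n");
        return 0;
    }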
