Commit 2746fc1

Merge branch 'ikawrakow:main' into main
2 parents 4b2f955 + f90d1fd commit 2746fc1

File tree: 10 files changed, +216 −117 lines changed

10 files changed

+216
-117
lines changed

ggml/src/ggml-backend.cpp

Lines changed: 12 additions & 15 deletions
@@ -1639,7 +1639,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg

         // check if we should start a new split based on the sources of the current node
         bool need_new_split = false;
-        if (node->op == GGML_OP_ADD && node->op_params[0] == 0xff) {
+        if ((node->op == GGML_OP_ADD && node->op_params[0] == 0xff) ||
+            node->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] == 0xff) {
             need_new_split = true;
         }
         else if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
@@ -1882,6 +1883,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
 static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_backend_sched_split * split, std::array<bool, GGML_SCHED_MAX_BACKENDS> & needs_sync,
         std::vector<int32_t> & ids, std::vector<uint32_t> & unique_ids, ggml_tensor * last_ids_tensor) {
     if (split->n_inputs < 1) return;
+    constexpr bool k_set_sync = false;
     int split_backend_id = split->backend_id;
     ggml_backend_t split_backend = sched->backends[split_backend_id];
     ggml_backend_t last_input_backend = nullptr;
@@ -1892,13 +1894,10 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back

         if (input->flags & GGML_TENSOR_FLAG_INPUT) {
             // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
-            if (needs_sync[split_backend_id]) {
-                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
-                    ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
-                } else {
-                    ggml_backend_synchronize(split_backend);
-                }
-                needs_sync[split_backend_id] = false;
+            if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+            } else {
+                ggml_backend_synchronize(split_backend);
             }
             ggml_backend_tensor_copy(input, input_cpy);
         } else {
@@ -1909,7 +1908,7 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back
             } else {
                 ggml_backend_synchronize(split_backend);
             }
-            needs_sync[split_backend_id] = false;
+            needs_sync[split_backend_id] = k_set_sync;
         }

         ggml_tensor * node = split->graph.nodes[0];
@@ -1923,7 +1922,6 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back
             last_input_backend = input_backend;
         }

-        //printf("node: %s have %d inputs, processing input %d\n", node->name, split->n_inputs, j);
         ggml_tensor * ids_tensor = node->op == GGML_OP_MUL_MAT_ID ? node->src[2] : node->src[3];
         auto ids_backend = split_backend;

@@ -1945,7 +1943,7 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back
             ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));

             ggml_backend_synchronize(ids_backend);
-            needs_sync[tensor_backend_id(ids_tensor)] = false;
+            needs_sync[tensor_backend_id(ids_tensor)] = k_set_sync;

             unique_ids.resize((n_expert + 31)/32);
             std::memset(unique_ids.data(), 0, unique_ids.size()*sizeof(uint32_t));
@@ -2005,15 +2003,15 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back
             int input_backend_id = tensor_backend_id(input);
             if (needs_sync[input_backend_id]) {
                 ggml_backend_synchronize(input_backend);
-                needs_sync[input_backend_id] = false;
+                needs_sync[input_backend_id] = k_set_sync;
             }
             if (needs_sync[split_backend_id]) {
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
                 } else {
                     ggml_backend_synchronize(split_backend);
                 }
-                needs_sync[split_backend_id] = false;
+                needs_sync[split_backend_id] = k_set_sync;
             }
             ggml_backend_tensor_copy(input, input_cpy);
         }
@@ -2034,7 +2032,6 @@ static ggml_status ggml_backend_sched_compute_splits_sm_graph(ggml_backend_sched
     for (int i = 0; i < sched->n_splits; ++i) {
         auto split_i = &splits[i];
         this_split.clear();
-        //auto& this_split = all_splits.emplace_back();
         this_split.push_back(split_i);
         for (int j = i+1; j < sched->n_splits; ++j) {
             auto split_j = &splits[j];
@@ -2092,7 +2089,7 @@ static ggml_status ggml_backend_sched_compute_splits_sm_graph(ggml_backend_sched

 static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {

-    if (sched->split_mode_graph) {
+    if (false && sched->split_mode_graph) {
         return ggml_backend_sched_compute_splits_sm_graph(sched);
     }

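For orientation: the first hunk above reserves the last int32 slot of a node's op_params as an out-of-band marker, and the scheduler now starts a new split whenever that slot holds 0xff. A minimal sketch of how a producer could tag a node (helper name hypothetical, not part of this commit):

    // op_params is an array of GGML_MAX_OP_PARAMS/sizeof(int32_t) int32 slots;
    // the scheduler treats 0xff in the last slot as a "start a new split" marker.
    static void sched_mark_force_split(struct ggml_tensor * node) {
        node->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
    }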
ggml/src/ggml-cuda.cu

Lines changed: 22 additions & 16 deletions
@@ -3411,46 +3411,44 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
     }

     if (backend_src != backend_dst) {
+        ggml_cuda_pool_alloc<half> tmp_src(cuda_ctx_src->pool());
+        ggml_cuda_pool_alloc<half> tmp_dst(cuda_ctx_dst->pool());
+        bool needs_f16_f32_copy = false;
         // copy on src stream
         if (cuda_ctx_src->device == cuda_ctx_dst->device) {
             CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
         } else {
 #ifdef GGML_CUDA_NO_PEER_COPY
             return false;
 #else
-            if (false && src->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (false && src->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->ne[1] >= 32) {
                 //
                 // The goal here is to reduce traffic between GPU's, which is entirely non-negligible
                 // for prompt processing.
                 // We cast the tensor to be copied to f16, copy the f16 data peer-to-peer
                 // and then cast back to f32 on the destination side.
-                // The cost for converting to/from f16 is much ower than the cost of copying
+                // The cost for converting to/from f16 is much lower than the cost of copying
                 // two times more data over PCI-E (well, at least the 30 GB/s PCI-E I have).
-                // iBut for some reason the following is not working.
+                // But for some reason the following is slower.
                 // Can somebody tell me why?
                 //
-                ggml_cuda_pool_alloc<half> tmp_src(cuda_ctx_src->pool(), ggml_nelements(src));
-                ggml_cuda_pool_alloc<half> tmp_dst(cuda_ctx_dst->pool(), ggml_nelements(dst));
+
+                ggml_cuda_set_device(cuda_ctx_dst->device);
+                tmp_dst.alloc(ggml_nelements(dst));
+
+                ggml_cuda_set_device(cuda_ctx_src->device);
+                tmp_src.alloc(ggml_nelements(src));

                 auto src_f16 = *src;
                 src_f16.type = GGML_TYPE_F16;
                 for (int i = 0; i < 4; ++i) src_f16.nb[i] /= 2;
                 src_f16.data = tmp_src.get();

-                auto dst_f16 = *dst;
-                dst_f16.type = GGML_TYPE_F16;
-                for (int i = 0; i < 4; ++i) dst_f16.nb[i] /= 2;
-                dst_f16.data = tmp_dst.get();
-
-                ggml_cuda_set_device(cuda_ctx_src->device);
                 ggml_cuda_cpy(*cuda_ctx_src, src, &src_f16, true);
-                CUDA_CHECK(cudaStreamSynchronize(cuda_ctx_src->stream()));

-                CUDA_CHECK(cudaMemcpyPeerAsync(dst_f16.data, cuda_ctx_dst->device, src_f16.data, cuda_ctx_src->device, ggml_nbytes(&dst_f16), cuda_ctx_src->stream()));
+                CUDA_CHECK(cudaMemcpyPeerAsync(tmp_dst.ptr, cuda_ctx_dst->device, src_f16.data, cuda_ctx_src->device, ggml_nbytes(&src_f16), cuda_ctx_src->stream()));

-                ggml_cuda_set_device(cuda_ctx_dst->device);
-                CUDA_CHECK(cudaStreamSynchronize(cuda_ctx_dst->stream()));
-                ggml_cuda_cpy(*cuda_ctx_dst, &dst_f16, dst, true);
+                needs_f16_f32_copy = true;

             } else {
                 CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
@@ -3467,7 +3465,15 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
         CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));

         // wait on dst stream for the copy to complete
+        ggml_cuda_set_device(cuda_ctx_dst->device);
         CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+        if (needs_f16_f32_copy) {
+            auto dst_f16 = *dst;
+            dst_f16.type = GGML_TYPE_F16;
+            for (int i = 0; i < 4; ++i) dst_f16.nb[i] /= 2;
+            dst_f16.data = tmp_dst.get();
+            ggml_cuda_cpy(*cuda_ctx_dst, &dst_f16, dst, true);
+        }
     } else {
         // src and dst are on the same backend
         CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
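
The comment block in the hunk above spells out the intent: cast f32 to f16 on the source GPU, move half as many bytes over PCI-E, and cast back to f32 on the destination. A self-contained sketch of that cast-copy-cast pattern (kernel and helper names are illustrative, not the ggml implementation):

    #include <cuda_fp16.h>

    // Convert f32 -> f16 on the source GPU so only half the bytes cross PCI-E.
    __global__ void f32_to_f16(const float * x, half * y, int64_t n) {
        int64_t i = (int64_t)blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) y[i] = __float2half(x[i]);
    }

    // Convert back to f32 on the destination GPU.
    __global__ void f16_to_f32(const half * x, float * y, int64_t n) {
        int64_t i = (int64_t)blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) y[i] = __half2float(x[i]);
    }

    // src_f32 lives on dev_src, dst_f32 on dev_dst; tmp_src/tmp_dst are half
    // buffers pre-allocated on the respective devices. Error checks omitted.
    void copy_f32_via_f16(const float * src_f32, half * tmp_src, half * tmp_dst,
                          float * dst_f32, int64_t n, int dev_src, int dev_dst,
                          cudaStream_t stream_src, cudaStream_t stream_dst) {
        const int block = 256;
        const int grid  = (int)((n + block - 1)/block);
        cudaSetDevice(dev_src);
        f32_to_f16<<<grid, block, 0, stream_src>>>(src_f32, tmp_src, n);
        cudaMemcpyPeerAsync(tmp_dst, dev_dst, tmp_src, dev_src, n*sizeof(half), stream_src);
        // The real code makes the dst stream wait on an event recorded on the
        // src stream; a plain synchronize keeps the sketch simple.
        cudaStreamSynchronize(stream_src);
        cudaSetDevice(dev_dst);
        f16_to_f32<<<grid, block, 0, stream_dst>>>(tmp_dst, dst_f32, n);
    }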

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 48 additions & 0 deletions
@@ -321,6 +321,15 @@ static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, c
     z[i] = x[i] + y[i % ne0];
 }

+template <typename src1_t, typename src2_t, typename dst_t>
+static __global__ void k_fast_add_2(int64_t ne0, int64_t nelem, const src1_t * x, const src2_t * y, dst_t * z) {
+    int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= nelem) {
+        return;
+    }
+    z[i] = (dst_t)((float)x[i] + (float)y[i]);
+}
+
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
         dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
@@ -332,6 +341,45 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
         return;
     }
+    if (ggml_is_contiguous(dst->src[0]) && ggml_are_same_shape(dst->src[0], dst->src[1]) && ggml_is_contiguous(dst)) {
+        constexpr int kBlockSize = 256;
+        auto nelem = ggml_nelements(dst);
+        int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
+        if (dst->type == GGML_TYPE_F16) {
+            if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F16) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const half *)dst->src[0]->data, (const half *)dst->src[1]->data, (half *)dst->data);
+            }
+            else if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F32) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const half *)dst->src[0]->data, (const float *)dst->src[1]->data, (half *)dst->data);
+            }
+            else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (half *)dst->data);
+            } else {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (half *)dst->data);
+            }
+        } else {
+            if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F16) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const half *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);
+            }
+            else if (dst->src[0]->type == GGML_TYPE_F16 && dst->src[1]->type == GGML_TYPE_F32) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const half *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
+            }
+            else if (dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32) {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
+            } else {
+                k_fast_add_2<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                        (const float *)dst->src[0]->data, (const half *)dst->src[1]->data, (float *)dst->data);
+            }
+        }
+        return;
+    }
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
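To sanity-check the new mixed-precision path, a small standalone harness can drive k_fast_add_2 directly, e.g. the f16 + f32 -> f32 case (assumes the kernel above is compiled into the same translation unit; buffer names are illustrative):

    #include <cuda_fp16.h>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t n = 1024;
        std::vector<half>  hx(n);
        std::vector<float> hy(n), hz(n);
        for (int64_t i = 0; i < n; ++i) { hx[i] = __float2half(1.0f); hy[i] = 2.0f; }

        half * dx; float * dy; float * dz;
        cudaMalloc(&dx, n*sizeof(half));
        cudaMalloc(&dy, n*sizeof(float));
        cudaMalloc(&dz, n*sizeof(float));
        cudaMemcpy(dx, hx.data(), n*sizeof(half),  cudaMemcpyHostToDevice);
        cudaMemcpy(dy, hy.data(), n*sizeof(float), cudaMemcpyHostToDevice);

        // Template arguments deduced as <half, float, float>; ne0 is unused by
        // this kernel but kept for signature compatibility.
        const int block = 256;
        k_fast_add_2<<<(int)((n + block - 1)/block), block>>>(n, n, dx, dy, dz);

        cudaMemcpy(hz.data(), dz, n*sizeof(float), cudaMemcpyDeviceToHost);
        printf("z[0] = %g (expect 3)\n", hz[0]);
        cudaFree(dx); cudaFree(dy); cudaFree(dz);
        return 0;
    }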

ggml/src/ggml-cuda/cpy.cu

Lines changed: 2 additions & 2 deletions
@@ -542,7 +542,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
+    if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) {
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
     }
@@ -651,7 +651,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
+    if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) {
         ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
     }
 #else
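
The reordering in both hunks matters because && evaluates left to right and short-circuits: ctx.cuda_graph must be null-checked before it is dereferenced. A minimal illustration of the pattern (types are illustrative):

    struct graph_ctx { bool use_cpy_indirection; };

    // Safe: the null test runs first, so g->use_cpy_indirection is never read
    // through a null pointer. With the old operand order, the dereference
    // happened before any null check could intervene.
    static bool wants_indirection(const graph_ctx * g, bool disabled) {
        return !disabled && g != nullptr && g->use_cpy_indirection;
    }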

ggml/src/ggml-cuda/norm.cu

Lines changed: 14 additions & 9 deletions
@@ -1,14 +1,14 @@
 #include "norm.cuh"

-template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+template <int block_size, typename T>
+static __global__ void norm_f32(const T * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

     float2 mean_var = make_float2(0.f, 0.f);

     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = (float)x[row*ncols + col];
         mean_var.x += xi;
         mean_var.y += xi * xi;
     }
@@ -32,7 +32,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const float inv_std = rsqrtf(var + eps);

     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+        dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std);
     }
 }

@@ -261,14 +261,15 @@ static __global__ void fused_rms_norm_f32_nc(
     }
 }

-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+template <typename T>
+static void norm_f32_cuda(const T * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE, T><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024, T><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }

@@ -364,7 +365,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(ggml_is_contiguous(src0));

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

     const int64_t ne00 = src0->ne[0];
@@ -373,7 +374,11 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    if (src0->type == GGML_TYPE_F32) {
+        norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    } else {
+        norm_f32_cuda((const half *)src0_d, dst_d, ne00, nrows, eps, stream);
+    }
 }

 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
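
A host-side reference of the row-wise normalization the templated kernel computes (accumulate the sum and sum of squares per row, then apply (x - mean) * rsqrt(var + eps)); useful for validating the new f16 input path against f32 results. Sketch only:

    #include <cmath>
    #include <vector>

    static void norm_reference(const std::vector<float> & x, std::vector<float> & dst,
                               int ncols, int nrows, float eps) {
        for (int row = 0; row < nrows; ++row) {
            const float * xr = x.data() + (size_t)row*ncols;
            float sum = 0.f, sum2 = 0.f;
            for (int col = 0; col < ncols; ++col) {
                sum  += xr[col];
                sum2 += xr[col]*xr[col];
            }
            const float mean = sum/ncols;
            const float var  = sum2/ncols - mean*mean;   // E[x^2] - E[x]^2
            const float inv_std = 1.0f/std::sqrt(var + eps);
            for (int col = 0; col < ncols; ++col) {
                dst[(size_t)row*ncols + col] = (xr[col] - mean)*inv_std;
            }
        }
    }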

ggml/src/ggml.c

Lines changed: 6 additions & 1 deletion
@@ -7232,7 +7232,12 @@ static struct ggml_tensor * ggml_norm_impl(
         is_node = true;
     }

-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    if (inplace && a->type != GGML_TYPE_F32) {
+        GGML_ABORT("Fatal error");
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : a->type == GGML_TYPE_F32 ? ggml_dup_tensor(ctx, a)
+                                          : ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);

     ggml_set_op_params(result, &eps, sizeof(eps));
