Skip to content

Commit ffa577e

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents 5882232 + af10490 commit ffa577e

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,44 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
111111

112112
}
113113

114+
// Full softmax over all experts of a row, for the case where every expert is
// selected and no top-k search is required.
//
// Layout expected by the indexing below: each threadIdx.y of a block handles one
// row (row = blockIdx.x * blockDim.y + threadIdx.y) and the threadIdx.x lanes
// stride over the experts in steps of WARP_SIZE. The launcher in this file uses
// blockDim.x == WARP_SIZE, i.e. one warp per row.
//
// For row r and expert i in [0, n_experts) the kernel writes
//   ids[r * n_experts + i]     = i
//   weights[r * n_experts + i] = exp(logits[i] - row_max) / sum_j exp(logits[j] - row_max)
//
// NOTE(review): relies on warp_reduce_max / warp_reduce_sum (defined elsewhere)
// leaving the reduced value in every participating lane -- confirm.
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
                                                                     float *       weights,
                                                                     int32_t *     ids,
                                                                     const int     n_rows,
                                                                     const int     n_experts) {
    const int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (row_idx >= n_rows) {
        // The row depends only on blockIdx.x / threadIdx.y, so all lanes that
        // share this threadIdx.y leave together.
        return;
    }

    // Per-row views of the three buffers (all rows are n_experts wide).
    const int     row_offset  = n_experts * row_idx;
    const float * row_logits  = logits + row_offset;
    float *       row_weights = weights + row_offset;
    int32_t *     row_ids     = ids + row_offset;

    const int lane = threadIdx.x;

    // Pass 1: per-lane running maximum; also emit the identity expert ids.
    float row_max = -INFINITY;
#pragma unroll
    for (int e = lane; e < n_experts; e += WARP_SIZE) {
        row_max    = max(row_max, row_logits[e]);
        row_ids[e] = e;
    }
    row_max = warp_reduce_max(row_max);

    // Pass 2: store the shifted exponentials and accumulate the denominator.
    float denom = 0;
#pragma unroll
    for (int e = lane; e < n_experts; e += WARP_SIZE) {
        const float shifted_exp = expf(row_logits[e] - row_max);
        row_weights[e]          = shifted_exp;
        denom += shifted_exp;
    }
    denom = warp_reduce_sum(denom);

    // Pass 3: scale by the reciprocal of the sum so the row adds up to one.
    const float inv_denom = 1.0f / denom;
#pragma unroll
    for (int e = lane; e < n_experts; e += WARP_SIZE) {
        row_weights[e] = row_weights[e] * inv_denom;
    }
}
151+
114152
template <bool normalize>
115153
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
116154
const float * logits,
@@ -124,6 +162,11 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
124162
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
125163
cudaStream_t stream = ctx.stream();
126164

165+
if (n_expert_used == n_expert) {
166+
simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
167+
return;
168+
}
169+
127170
switch (n_expert) {
128171
case 1:
129172
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);

0 commit comments

Comments
 (0)