@@ -111,6 +111,44 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
 }
 
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void simple_moe_cuda(const float * logits,
+                                                                    float *       weights,
+                                                                    int32_t *     ids,
+                                                                    const int     n_rows,
+                                                                    const int     n_experts) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    logits  += n_experts * row;
+    weights += n_experts * row;
+    ids     += n_experts * row;
+
+    float max_val = -INFINITY;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        max_val = max(max_val, logits[i]);
+        ids[i]  = i;
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        weights[i] = expf(logits[i] - max_val);
+        sum += weights[i];
+    }
+
+    sum = warp_reduce_sum(sum);
+    float norm = 1/sum;
+#pragma unroll
+    for (int i = threadIdx.x; i < n_experts; i += WARP_SIZE) {
+        weights[i] *= norm;
+    }
+}
+
 template <bool normalize>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
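
Note on the new kernel: `simple_moe_cuda` handles the dense case where every expert is active. One warp processes one row: lanes stride over the `n_experts` logits, compute a numerically stable softmax (the row maximum is subtracted before `expf` so the exponentials cannot overflow), and write the expert indices `0..n_experts-1` straight into `ids`, since no top-k selection is needed. The kernel reuses ggml's existing `warp_reduce_max`/`warp_reduce_sum` helpers, which are not part of this diff; a typical butterfly-shuffle implementation looks roughly like the sketch below (illustrative only, names and details are assumptions, not the exact ggml code):

// Illustrative butterfly warp reductions in the spirit of ggml's
// warp_reduce_max/warp_reduce_sum (sketch only, not the library code).
static __device__ __forceinline__ float warp_reduce_max_sketch(float x) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        // exchange values across lanes; after log2(WARP_SIZE) steps
        // every lane holds the warp-wide maximum
        x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, WARP_SIZE));
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, WARP_SIZE);
    }
    return x;
}

Because the butterfly pattern leaves the reduced value in every lane, all 32 lanes hold the row-wide max and sum after the two reductions, so the final normalization loop needs no further synchronization.
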
@@ -124,6 +162,11 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
+    if (n_expert_used == n_expert) {
+        simple_moe_cuda<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert);
+        return;
+    }
+
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
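
Note on the launcher change: when `n_expert_used == n_expert`, selecting the top k of n experts is the identity, so the early return takes the dense kernel and skips the `switch` over the templated top-k instantiations entirely. For checking that fast path against the device output, a minimal host-side sketch of what it computes per row might look like this (`simple_moe_reference` is a hypothetical helper, not part of the PR):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical CPU reference for simple_moe_cuda: a numerically stable
// softmax over all expert logits, with ids filled as 0..n_experts-1.
static void simple_moe_reference(const float * logits, float * weights, int32_t * ids,
                                 int n_rows, int n_experts) {
    for (int row = 0; row < n_rows; ++row) {
        const float * x  = logits  + (size_t) row * n_experts;
        float *       w  = weights + (size_t) row * n_experts;
        int32_t *     id = ids     + (size_t) row * n_experts;

        float max_val = -INFINITY;
        for (int i = 0; i < n_experts; ++i) {
            max_val = std::max(max_val, x[i]);
            id[i]   = i; // every expert is used, so no top-k selection
        }
        float sum = 0.0f;
        for (int i = 0; i < n_experts; ++i) {
            w[i] = std::exp(x[i] - max_val); // subtract max for stability
            sum += w[i];
        }
        for (int i = 0; i < n_experts; ++i) {
            w[i] /= sum; // weights now sum to 1 per row
        }
    }
}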