Skip to content

Commit db3d5d9

Browse files
authored
[Metal] improve normalization (huggingface#3283)
1 parent c3ed240 commit db3d5d9

7 files changed

Lines changed: 485 additions & 159 deletions

File tree

candle-metal-kernels/src/kernels/reduce.rs

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -193,8 +193,9 @@ pub fn call_rms_norm(
193193
eps
194194
)
195195
);
196+
let work_per_threadgroup = elements_to_sum;
196197

197-
let out_length = length / elements_to_sum;
198+
let out_length = length / work_per_threadgroup;
198199

199200
let thread_group_count = MTLSize {
200201
width: out_length,
@@ -204,19 +205,17 @@ pub fn call_rms_norm(
204205

205206
let width = std::cmp::min(
206207
pipeline.max_total_threads_per_threadgroup(),
207-
elements_to_sum,
208-
)
209-
.next_power_of_two();
208+
(work_per_threadgroup / 2).next_power_of_two(),
209+
);
210210

211211
let thread_group_size = MTLSize {
212212
width,
213213
height: 1,
214214
depth: 1,
215215
};
216-
217216
encoder.use_resource(input, MTLResourceUsage::Read);
217+
encoder.use_resource(alpha, MTLResourceUsage::Read);
218218
encoder.use_resource(output, MTLResourceUsage::Write);
219-
encoder.set_threadgroup_memory_length(0, (width * 4).max(16));
220219
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
221220
Ok(())
222221
}
@@ -256,7 +255,9 @@ pub fn call_layer_norm(
256255
)
257256
);
258257

259-
let out_length = length / elements_to_sum;
258+
let work_per_threadgroup = elements_to_sum;
259+
260+
let out_length = length / work_per_threadgroup;
260261

261262
let thread_group_count = MTLSize {
262263
width: out_length,
@@ -266,19 +267,18 @@ pub fn call_layer_norm(
266267

267268
let width = std::cmp::min(
268269
pipeline.max_total_threads_per_threadgroup(),
269-
elements_to_sum,
270-
)
271-
.next_power_of_two();
270+
(work_per_threadgroup / 2).next_power_of_two(),
271+
);
272272

273273
let thread_group_size = MTLSize {
274274
width,
275275
height: 1,
276276
depth: 1,
277277
};
278-
279278
encoder.use_resource(input, MTLResourceUsage::Read);
279+
encoder.use_resource(alpha, MTLResourceUsage::Read);
280+
encoder.use_resource(beta, MTLResourceUsage::Read);
280281
encoder.use_resource(output, MTLResourceUsage::Write);
281-
encoder.set_threadgroup_memory_length(0, (width * 8).max(32));
282282
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
283283
Ok(())
284284
}

0 commit comments

Comments (0)