@@ -193,8 +193,9 @@ pub fn call_rms_norm(
193193 eps
194194 )
195195 ) ;
196+ let work_per_threadgroup = elements_to_sum;
196197
197- let out_length = length / elements_to_sum ;
198+ let out_length = length / work_per_threadgroup ;
198199
199200 let thread_group_count = MTLSize {
200201 width : out_length,
@@ -204,19 +205,17 @@ pub fn call_rms_norm(
204205
205206 let width = std:: cmp:: min (
206207 pipeline. max_total_threads_per_threadgroup ( ) ,
207- elements_to_sum,
208- )
209- . next_power_of_two ( ) ;
208+ ( work_per_threadgroup / 2 ) . next_power_of_two ( ) ,
209+ ) ;
210210
211211 let thread_group_size = MTLSize {
212212 width,
213213 height : 1 ,
214214 depth : 1 ,
215215 } ;
216-
217216 encoder. use_resource ( input, MTLResourceUsage :: Read ) ;
217+ encoder. use_resource ( alpha, MTLResourceUsage :: Read ) ;
218218 encoder. use_resource ( output, MTLResourceUsage :: Write ) ;
219- encoder. set_threadgroup_memory_length ( 0 , ( width * 4 ) . max ( 16 ) ) ;
220219 encoder. dispatch_thread_groups ( thread_group_count, thread_group_size) ;
221220 Ok ( ( ) )
222221}
@@ -256,7 +255,9 @@ pub fn call_layer_norm(
256255 )
257256 ) ;
258257
259- let out_length = length / elements_to_sum;
258+ let work_per_threadgroup = elements_to_sum;
259+
260+ let out_length = length / work_per_threadgroup;
260261
261262 let thread_group_count = MTLSize {
262263 width : out_length,
@@ -266,19 +267,18 @@ pub fn call_layer_norm(
266267
267268 let width = std:: cmp:: min (
268269 pipeline. max_total_threads_per_threadgroup ( ) ,
269- elements_to_sum,
270- )
271- . next_power_of_two ( ) ;
270+ ( work_per_threadgroup / 2 ) . next_power_of_two ( ) ,
271+ ) ;
272272
273273 let thread_group_size = MTLSize {
274274 width,
275275 height : 1 ,
276276 depth : 1 ,
277277 } ;
278-
279278 encoder. use_resource ( input, MTLResourceUsage :: Read ) ;
279+ encoder. use_resource ( alpha, MTLResourceUsage :: Read ) ;
280+ encoder. use_resource ( beta, MTLResourceUsage :: Read ) ;
280281 encoder. use_resource ( output, MTLResourceUsage :: Write ) ;
281- encoder. set_threadgroup_memory_length ( 0 , ( width * 8 ) . max ( 32 ) ) ;
282282 encoder. dispatch_thread_groups ( thread_group_count, thread_group_size) ;
283283 Ok ( ( ) )
284284}
0 commit comments