66 transpose_load, transpose_store,
77 load_ir_segments, store_ir_segments,
88 declare_smem_variables,
9- set_launch_bound_variables with context %}
9+ set_launch_bound_variables, launch_bounds
10+ with context %}
1011
1112#define THREADS_PER_WARP {{ forward_schedule.launch_config .warp_size }} // Warp size should be the same for forward and backward
1213#define FULL_MASK 0xffffffff
@@ -30,8 +31,11 @@ struct ConvData {
3031};
3132
3233
33- {%- macro generate_fixup_kernel (name, warp_size, dim, fixup_offset) %}
34- __global__ void {{name}}(void * workspace, IRREP_T * dst_ptr) {
34+ {%- macro generate_fixup_kernel (name, schedule, dim, fixup_offset) %}
35+ {%- set warp_size = schedule.launch_config .warp_size %}
36+ __global__ void
37+ {{ launch_bounds (schedule) }}
38+ {{name}}(void * workspace, IRREP_T * dst_ptr) {
3539 /*
3640 * Workspace consists of:
3741 * fixup_dim * warps_launched * sizeof(IRREP_T): Data
@@ -61,7 +65,7 @@ __global__ void {{name}}(void* workspace, IRREP_T* dst_ptr) {
6165}
6266{%- endmacro %}
6367
64- {{ generate_fixup_kernel (" fixup_forward" , forward_schedule. launch_config . warp_size , forward_schedule.L3 .dim , forward_workspace_offset) }}
68+ {{ generate_fixup_kernel (" fixup_forward" , forward_schedule, forward_schedule.L3 .dim , forward_workspace_offset) }}
6569
6670template <int ROW_LEN >
6771__device__ __forceinline__ void kahanAdd (IRREP_T * c_arr, IRREP_T * sum_arr, int lane_id) {
@@ -88,7 +92,9 @@ __device__ __forceinline__ void kahanAdd(IRREP_T* c_arr, IRREP_T* sum_arr, int l
8892 }
8993}
9094
91- __global__ void forward (
95+ __global__ void
96+ {{ launch_bounds (forward_schedule) }}
97+ forward (
9298 IRREP_T * L1_in,
9399 IRREP_T * L2_in,
94100 WEIGHT_T * weights,
@@ -174,10 +180,11 @@ __global__ void forward(
174180{{ generate_segment_kernel_backward (i, segment, backward_schedule.launch_config .warp_size ) }}
175181{%- endfor %}
176182
177- {{ generate_fixup_kernel (" fixup_backward" , backward_schedule. launch_config . warp_size , backward_schedule.L1 .dim , backward_workspace_offset) }}
183+ {{ generate_fixup_kernel (" fixup_backward" , backward_schedule, backward_schedule.L1 .dim , backward_workspace_offset) }}
178184
179- __global__ void backward (
180- IRREP_T * L1_in, IRREP_T * L1_grad,
185+ __global__ void
186+ {{ launch_bounds (backward_schedule) }}
187+ backward (IRREP_T * L1_in, IRREP_T * L1_grad,
181188 IRREP_T * L2_in, IRREP_T * L2_grad,
182189 WEIGHT_T * weights, WEIGHT_T * weights_grad,
183190 IRREP_T * L3_grad, ConvData c, void * workspace_raw,
@@ -284,8 +291,9 @@ __global__ void backward(
284291}
285292
286293
287- __global__ void double_backward_A (
288- IRREP_T * L1_in, IRREP_T * L2_in, WEIGHT_T * W, IRREP_T * L3_grad,
294+ __global__ void
295+ {{ launch_bounds (forward_schedule) }}
296+ double_backward_A (IRREP_T * L1_in, IRREP_T * L2_in, WEIGHT_T * W, IRREP_T * L3_grad,
289297 IRREP_T * L1_dgrad, IRREP_T * L2_dgrad, IRREP_T * W_dgrad,
290298 IRREP_T * L1_grad, IRREP_T * L2_grad, WEIGHT_T * W_grad, IRREP_T * L3_dgrad,
291299 ConvData c, void * workspace_raw, unsigned {{idx_type}}* transpose_perm) {
@@ -391,16 +399,17 @@ __global__ void double_backward_A(
391399 } {%- endfor %}
392400}
393401
394- {{ generate_fixup_kernel (" fixup_double_backwardB" , double_backward_schedule. launch_config . warp_size , double_backward_schedule.L1 .dim , double_backwardB_offset) }}
402+ {{ generate_fixup_kernel (" fixup_double_backwardB" , double_backward_schedule, double_backward_schedule.L1 .dim , double_backwardB_offset) }}
395403
396404{%- for i, segment in enumerate (double_backward_schedule.segments ) %}
397405{{ generate_segment_kernel_backward (i, segment, double_backward_schedule.launch_config .warp_size , double_bwd=True) }}
398406{%- endfor %}
399407
400408{% set schedule = double_backward_schedule %}
401409
402- __global__ void double_backward_B (
403- IRREP_T * L1_in, IRREP_T * L2_in, WEIGHT_T * W, IRREP_T * L3_grad,
410+ __global__ void
411+ {{ launch_bounds (double_backward_schedule) }}
412+ double_backward_B (IRREP_T * L1_in, IRREP_T * L2_in, WEIGHT_T * W, IRREP_T * L3_grad,
404413 IRREP_T * L1_dgrad, IRREP_T * L2_dgrad, IRREP_T * W_dgrad,
405414 IRREP_T * L1_grad, IRREP_T * L2_grad, WEIGHT_T * W_grad, IRREP_T * L3_dgrad,
406415 ConvData c, void * workspace_raw, unsigned {{idx_type}}* transpose_perm) {
0 commit comments