@@ -124,12 +124,25 @@ class ConvKernelConfig:
124124 vector_size_b : int = 8
125125 vector_size_c : int = 8
126126
127- # Fixed parameters
127+ # Occupancy parameters
128128 block_per_cu : int = 1
129129 num_wave_groups : int = 1
130+ num_groups_to_merge : int = 1 # For group merged convolution
131+
132+ # Double buffering
133+ double_smem_buffer : bool = False
130134
131135 def name (self , datatype : str ) -> str :
132- """Generate kernel name"""
136+ """
137+ Generate kernel name that uniquely identifies the kernel configuration.
138+
139+ Format: conv_{variant}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}
140+ _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}
141+ _{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
142+ [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}]
143+
144+ All parameters that affect kernel behavior are included.
145+ """
133146 t = self .tile
134147 tr = self .trait
135148
@@ -139,12 +152,42 @@ def name(self, datatype: str) -> str:
139152 ConvVariant .BACKWARD_WEIGHT : "bwdw" ,
140153 }[self .variant ]
141154
155+ # Core identity: variant, dtype, dims
142156 name = f"conv_{ variant_str } _{ datatype } _{ self .ndim_spatial } d"
157+
158+ # Pipeline configuration
143159 name += f"_{ tr .pipeline } _{ tr .epilogue } _{ tr .scheduler } "
160+
161+ # Block tile dimensions (M_Tile x N_Tile x K_Tile)
144162 name += f"_{ t .tile_m } x{ t .tile_n } x{ t .tile_k } "
163+
164+ # Wave distribution (M_Warp x N_Warp x K_Warp)
145165 name += f"_{ t .warp_m } x{ t .warp_n } x{ t .warp_k } "
146166
147- # Add padding suffix if not all enabled
167+ # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile)
168+ name += f"_{ t .warp_tile_m } x{ t .warp_tile_n } x{ t .warp_tile_k } "
169+
170+ # Vector sizes (only if non-default)
171+ if (self .vector_size_a , self .vector_size_b , self .vector_size_c ) != (4 , 8 , 8 ):
172+ name += (
173+ f"_vec{ self .vector_size_a } _{ self .vector_size_b } _{ self .vector_size_c } "
174+ )
175+
176+ # Occupancy hints (only if non-default)
177+ if self .block_per_cu != 1 :
178+ name += f"_bpc{ self .block_per_cu } "
179+
180+ if self .num_wave_groups != 1 :
181+ name += f"_wg{ self .num_wave_groups } "
182+
183+ if self .num_groups_to_merge != 1 :
184+ name += f"_gm{ self .num_groups_to_merge } "
185+
186+ # Double SMEM buffer (for compute V4+)
187+ if self .double_smem_buffer or tr .double_smem_buffer :
188+ name += "_dsb"
189+
190+ # Padding suffix (only if not all enabled)
148191 if not (tr .pad_m and tr .pad_n and tr .pad_k ):
149192 name += f"_pad{ int (tr .pad_m )} { int (tr .pad_n )} { int (tr .pad_k )} "
150193
@@ -786,6 +829,44 @@ def main():
786829 help = "List configurations without generating" ,
787830 )
788831
832+ # Individual kernel configuration (when not using predefined configs)
833+ parser .add_argument ("--tile-m" , type = int , help = "Block tile M dimension" )
834+ parser .add_argument ("--tile-n" , type = int , help = "Block tile N dimension" )
835+ parser .add_argument ("--tile-k" , type = int , help = "Block tile K dimension" )
836+ parser .add_argument ("--warp-m" , type = int , help = "Wave distribution M" )
837+ parser .add_argument ("--warp-n" , type = int , help = "Wave distribution N" )
838+ parser .add_argument ("--warp-k" , type = int , default = 1 , help = "Wave distribution K" )
839+ parser .add_argument ("--warp-tile-m" , type = int , help = "Warp tile M" )
840+ parser .add_argument ("--warp-tile-n" , type = int , help = "Warp tile N" )
841+ parser .add_argument ("--warp-tile-k" , type = int , default = 16 , help = "Warp tile K" )
842+ parser .add_argument (
843+ "--pipeline" ,
844+ type = str ,
845+ choices = ["mem" , "compv3" , "compv4" , "compv5" ],
846+ help = "Pipeline type" ,
847+ )
848+ parser .add_argument (
849+ "--scheduler" ,
850+ type = str ,
851+ choices = ["intrawave" , "interwave" ],
852+ help = "Scheduler type" ,
853+ )
854+ parser .add_argument (
855+ "--epilogue" ,
856+ type = str ,
857+ default = "cshuffle" ,
858+ choices = ["cshuffle" , "default" ],
859+ help = "Epilogue type" ,
860+ )
861+ parser .add_argument ("--pad-m" , type = bool , default = True , help = "Pad M dimension" )
862+ parser .add_argument ("--pad-n" , type = bool , default = True , help = "Pad N dimension" )
863+ parser .add_argument ("--pad-k" , type = bool , default = True , help = "Pad K dimension" )
864+ parser .add_argument ("--vector-a" , type = int , default = 4 , help = "Vector size A" )
865+ parser .add_argument ("--vector-b" , type = int , default = 8 , help = "Vector size B" )
866+ parser .add_argument ("--vector-c" , type = int , default = 8 , help = "Vector size C" )
867+ parser .add_argument ("--block-per-cu" , type = int , default = 1 , help = "Blocks per CU" )
868+ parser .add_argument ("--num-wave-groups" , type = int , default = 1 , help = "Wave groups" )
869+
789870 args = parser .parse_args ()
790871
791872 if args .verbose :
@@ -799,11 +880,53 @@ def main():
799880 }
800881 requested_variants = [variant_map [v ] for v in args .variant ]
801882
802- # Get configurations for target arch with requested variants and ndims
803- filtered_configs = get_default_configs (
804- arch = args .arch , variants = requested_variants , ndims = args .ndim
883+ # Check if user specified custom configuration
884+ custom_config = (
885+ args .tile_m is not None or args .tile_n is not None or args . pipeline is not None
805886 )
806887
888+ if custom_config :
889+ # Build custom config from CLI arguments
890+ tile = TileConfig (
891+ tile_m = args .tile_m or 128 ,
892+ tile_n = args .tile_n or 128 ,
893+ tile_k = args .tile_k or 64 ,
894+ warp_m = args .warp_m or 2 ,
895+ warp_n = args .warp_n or 2 ,
896+ warp_k = args .warp_k or 1 ,
897+ warp_tile_m = args .warp_tile_m or 32 ,
898+ warp_tile_n = args .warp_tile_n or 32 ,
899+ warp_tile_k = args .warp_tile_k or 16 ,
900+ )
901+ trait = TraitConfig (
902+ pipeline = args .pipeline or "compv4" ,
903+ scheduler = args .scheduler or "intrawave" ,
904+ epilogue = args .epilogue or "cshuffle" ,
905+ pad_m = args .pad_m ,
906+ pad_n = args .pad_n ,
907+ pad_k = args .pad_k ,
908+ )
909+ config = ConvKernelConfig (
910+ tile = tile ,
911+ trait = trait ,
912+ variant = requested_variants [0 ]
913+ if requested_variants
914+ else ConvVariant .FORWARD ,
915+ ndim_spatial = args .ndim [0 ] if args .ndim else 2 ,
916+ arch = args .arch ,
917+ vector_size_a = args .vector_a ,
918+ vector_size_b = args .vector_b ,
919+ vector_size_c = args .vector_c ,
920+ block_per_cu = args .block_per_cu ,
921+ num_wave_groups = args .num_wave_groups ,
922+ )
923+ filtered_configs = [config ]
924+ else :
925+ # Get predefined configurations for target arch with requested variants and ndims
926+ filtered_configs = get_default_configs (
927+ arch = args .arch , variants = requested_variants , ndims = args .ndim
928+ )
929+
807930 if args .list_configs :
808931 print (f"Convolution configurations for { args .arch } :" )
809932 print (f" Datatypes: { args .datatype } " )
0 commit comments