@@ -118,6 +118,17 @@ def generate_solutions(
118118 gpu_target_info : iree_gpu .TargetInfo ,
119119 ** pipeline_constraint_options ,
120120 ) -> Iterator [list [common .TuningConfiguration ]]:
121+ # Filter use_direct_load for unsupported configurations.
122+ codegen_pipeline = iree_codegen .DispatchLoweringPassPipeline .LLVMGPUTileAndFuse
123+ pipeline_constraint_options [
124+ "allowed_use_direct_load"
125+ ] = rocm_common .filter_use_direct_load (
126+ pipeline_constraint_options .get ("allowed_use_direct_load" , [False ]),
127+ codegen_pipeline ,
128+ gpu_target_info .arch ,
129+ rocm_common .ConvolutionStrategy .igemm , # Contraction uses IGEMM-like path.
130+ )
131+
121132 return rocm_solutions .generate_generic_contraction_solutions (
122133 tuner_ctx = tuner_context ,
123134 gpu_target_info = gpu_target_info ,
@@ -128,7 +139,7 @@ def generate_solutions(
128139 res_type = self .op_info .res_type ,
129140 dispatch_kind = common .DispatchKind .contraction ,
130141 indexing_maps = self .op_info .indexing_maps ,
131- codegen_pipeline = iree_codegen . DispatchLoweringPassPipeline . LLVMGPUTileAndFuse ,
142+ codegen_pipeline = codegen_pipeline ,
132143 ** pipeline_constraint_options ,
133144 )
134145
@@ -164,11 +175,25 @@ def generate_solutions(
164175 self .op_info .convolution_dims is not None
165176 ), "convolution_dims must be set for convolution operations"
166177
178+ codegen_pipeline = iree_codegen .DispatchLoweringPassPipeline .LLVMGPUTileAndFuse
179+
167180 # Generate IGEMM candidates.
168181 if conv_strategy & rocm_common .ConvolutionStrategy .igemm :
169182 tuner_context .logger .info (
170183 "Generating convolution candidates using IGEMM strategy"
171184 )
185+
186+ # Filter use_direct_load for IGEMM strategy.
187+ igemm_options = pipeline_constraint_options .copy ()
188+ igemm_options [
189+ "allowed_use_direct_load"
190+ ] = rocm_common .filter_use_direct_load (
191+ igemm_options .get ("allowed_use_direct_load" , [False ]),
192+ codegen_pipeline ,
193+ gpu_target_info .arch ,
194+ rocm_common .ConvolutionStrategy .igemm ,
195+ )
196+
172197 yield from rocm_solutions .generate_generic_contraction_solutions (
173198 tuner_ctx = tuner_context ,
174199 gpu_target_info = gpu_target_info ,
@@ -179,11 +204,11 @@ def generate_solutions(
179204 res_type = self .op_info .res_type ,
180205 dispatch_kind = common .DispatchKind .conv ,
181206 indexing_maps = self .op_info .indexing_maps ,
182- codegen_pipeline = iree_codegen . DispatchLoweringPassPipeline . LLVMGPUTileAndFuse ,
207+ codegen_pipeline = codegen_pipeline ,
183208 igemm_details = self .op_info .igemm_details ,
184209 conv_to_igemm_info = self .op_info .conv_to_igemm_info ,
185210 convolution_dims = self .op_info .convolution_dims ,
186- ** pipeline_constraint_options ,
211+ ** igemm_options ,
187212 )
188213
189214 # Generate direct convolution candidates if supported.
@@ -192,6 +217,18 @@ def generate_solutions(
192217 tuner_context .logger .info (
193218 "Generating convolution candidates using direct strategy"
194219 )
220+
221+ # Filter use_direct_load for direct conv strategy.
222+ direct_options = pipeline_constraint_options .copy ()
223+ direct_options [
224+ "allowed_use_direct_load"
225+ ] = rocm_common .filter_use_direct_load (
226+ direct_options .get ("allowed_use_direct_load" , [False ]),
227+ codegen_pipeline ,
228+ gpu_target_info .arch ,
229+ rocm_common .ConvolutionStrategy .direct ,
230+ )
231+
195232 direct_dims , direct_sizes = self ._compute_direct_conv_dimensions ()
196233 # Pass filter loop info so solution generator can add them with tile size 1.
197234 direct_conv_info : rocm_solutions .DirectConvInfo = {
@@ -210,11 +247,11 @@ def generate_solutions(
210247 res_type = self .op_info .res_type ,
211248 dispatch_kind = common .DispatchKind .conv ,
212249 indexing_maps = self .op_info .indexing_maps ,
213- codegen_pipeline = iree_codegen . DispatchLoweringPassPipeline . LLVMGPUTileAndFuse ,
250+ codegen_pipeline = codegen_pipeline ,
214251 igemm_details = None ,
215252 conv_to_igemm_info = None ,
216253 direct_conv_info = direct_conv_info ,
217- ** pipeline_constraint_options ,
254+ ** direct_options ,
218255 )
219256
220257 def _supports_direct_convolution (self , tuner_context : common .TunerContext ) -> bool :
0 commit comments