Commit 540386b

[tuner] add use_direct_load (Global Load DMA) support to tuner
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent: a36ec8c

10 files changed: 439 additions & 32 deletions

amdsharktuner/amdsharktuner/candidate_gen.py

Lines changed: 4 additions & 1 deletion
@@ -95,8 +95,8 @@ def generate_solutions(
     num_subgroups: int = 4,  # GPU spec, used to determine candidate generation constraints.
     allowed_waves_per_eu: list[int] = [2],
     allowed_denorm_flushing: list[bool] = [False],
+    allowed_use_direct_load: list[bool] = [False],
     pipeline_options_search_space: rocm_dispatch_constraints.PipelineOptionsSearchSpace = rocm_dispatch_constraints.PipelineOptionsSearchSpace(),
-    codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
     conv_strategy: rocm_common.ConvolutionStrategy = rocm_common.ConvolutionStrategy.igemm
     | rocm_common.ConvolutionStrategy.direct,
 ) -> Iterator[list[common.TuningConfiguration]]:
@@ -112,6 +112,8 @@ def generate_solutions(
         target_info,
         num_subgroups=num_subgroups,
         allowed_waves_per_eu=allowed_waves_per_eu,
+        allowed_denorm_flushing=allowed_denorm_flushing,
+        allowed_use_direct_load=allowed_use_direct_load,
         pipeline_options_search_space=pipeline_options_search_space,
         conv_strategy=conv_strategy,
     )
@@ -122,6 +124,7 @@ def generate_solutions(
         num_subgroups=num_subgroups,
         allowed_waves_per_eu=allowed_waves_per_eu,
         allowed_denorm_flushing=allowed_denorm_flushing,
+        allowed_use_direct_load=allowed_use_direct_load,
         pipeline_options_search_space=pipeline_options_search_space,
     )
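Each allowed_* list above is one sweep axis: candidate generation enumerates combinations of the allowed values. A runnable toy showing the fan-out the new axis adds (illustrative only, not the tuner's actual enumeration code; the axis values are hypothetical):

    from itertools import product

    # Each allowed_* parameter of generate_solutions is one sweep axis.
    allowed_waves_per_eu = [1, 2]
    allowed_denorm_flushing = [False]
    allowed_use_direct_load = [False, True]  # the new axis added by this commit

    for waves, denorm, dma in product(
        allowed_waves_per_eu, allowed_denorm_flushing, allowed_use_direct_load
    ):
        print(f"waves_per_eu={waves} denorm_flushing={denorm} use_direct_load={dma}")
    # 4 combinations: sweeping use_direct_load doubles the candidate space.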

amdsharktuner/amdsharktuner/common.py

Lines changed: 18 additions & 0 deletions
@@ -337,6 +337,24 @@ def get_lowering_config(
                 assert (
                     False
                 ), f"Unsupported type for key '{key}': {type(value).__name__}"
+            case "promotion_types":
+                # Handle list of Attribute objects for use_direct_load.
+                if isinstance(value, Sequence):
+                    # Validate length matches promote_operands if present.
+                    if "promote_operands" in lowering_config_dict:
+                        promote_ops = lowering_config_dict["promote_operands"]
+                        if hasattr(promote_ops, "__len__") and len(value) != len(
+                            promote_ops
+                        ):
+                            assert False, (
+                                f"promotion_types length ({len(value)}) must match "
+                                f"promote_operands length ({len(promote_ops)})"
+                            )
+                    promoted_value = ir.ArrayAttr.get(list(value))
+                elif not isinstance(value, ir.ArrayAttr):
+                    assert (
+                        False
+                    ), f"Unsupported type for key '{key}': {type(value).__name__}"
             case _:
                 assert False, f"Unhandled key in lowering configuration: {key}"
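The new case enforces one promotion-type attribute per promoted operand. A standalone restatement of that length check, with plain strings standing in for MLIR ir.Attribute values (an assumption made for illustration):

    def check_promotion_types(promote_operands: list[int], promotion_types: list[str]) -> None:
        # Each entry in promote_operands must have a matching promotion type.
        assert len(promotion_types) == len(promote_operands), (
            f"promotion_types length ({len(promotion_types)}) must match "
            f"promote_operands length ({len(promote_operands)})"
        )

    check_promotion_types([0, 1], ["#iree_gpu.use_global_load_dma"] * 2)  # passes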

amdsharktuner/amdsharktuner/libtuner.py

Lines changed: 11 additions & 1 deletion
@@ -422,6 +422,16 @@ def parse_arguments(
         "denormals to zero. Only applicable to attention ops. "
         "Possible values: [True, False]",
     )
+    candidate_gen_args.add_argument(
+        "--use-direct-load-options",
+        type=lambda t: [s.strip().lower() == "true" for s in t.split(",")],
+        default=[False],
+        help="Comma-separated list of allowed values for use_direct_load. "
+        "When True, enables Global Load DMA mode for matmul operand loading. "
+        "Only supported on gfx950+ GPUs. Automatically sets "
+        "no_reduce_shared_memory_bank_conflicts=true. "
+        "Possible values: [True, False]. Default: [False].",
+    )
     candidate_gen_args.add_argument(
         "--codegen-pipeline",
         choices=[x.value for x in CodegenPipelines],
@@ -839,8 +849,8 @@ def generate_candidate_specs(
         num_subgroups=args.num_subgroups,
         allowed_waves_per_eu=args.waves_per_eu_options,
         allowed_denorm_flushing=allowed_denorm_flushing,
+        allowed_use_direct_load=args.use_direct_load_options,
         pipeline_options_search_space=pipeline_options_search_space,
-        codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline),
         conv_strategy=conv_strategy,
     )
     if args.enable_random_seed:
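The type= lambda converts the comma-separated flag value into a list of booleans; anything other than a case-insensitive "true" becomes False. A standalone run of the same parser (the lambda is copied from the diff; the sample inputs are hypothetical):

    parse_bool_list = lambda t: [s.strip().lower() == "true" for s in t.split(",")]

    print(parse_bool_list("false,true"))  # [False, True] -> tune both modes
    print(parse_bool_list(" True "))      # [True] (whitespace and case are ignored)
    print(parse_bool_list("yes"))         # [False] -- only "true" maps to True

So a run with --use-direct-load-options=false,true asks the tuner to generate candidates both with and without Global Load DMA.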

amdsharktuner/amdsharktuner/rocm/rocm_common.py

Lines changed: 70 additions & 0 deletions
@@ -25,6 +25,36 @@
 # List of tested ROCm architectures.
 ROCM_ARCHITECTURES = ["gfx942", "gfx950", "gfx1100", "gfx1201"]
 
+tune_logger = logging.getLogger("tune")
+
+
+def supports_global_load_dma(arch: str) -> bool:
+    """Check if architecture supports Global Load DMA (gfx950+).
+
+    CDNA4 is gfx950+ (majorVersion == 9 && minorVersion >= 5).
+    """
+    if not arch.startswith("gfx"):
+        return False
+    try:
+        version = int(arch[3:])
+        major = version // 100
+        minor = (version % 100) // 10
+        return major == 9 and minor >= 5
+    except ValueError:
+        return False
+
+
+def get_use_global_load_dma_attr() -> ir.Attribute:
+    """Get the UseGlobalLoadDMAAttr for direct load promotion."""
+    # TODO(Bangtian): Expose Python binding for iree_gpu.UseGlobalLoadDMAAttr instead of parsing string.
+    return ir.Attribute.parse("#iree_gpu.use_global_load_dma")
+
+
+def get_promotion_types_for_direct_load(num_operands: int) -> list[ir.Attribute]:
+    """Get promotion_types array for direct load (all operands use DMA)."""
+    dma_attr = get_use_global_load_dma_attr()
+    return [dma_attr] * num_operands
+
 
 class ConvolutionStrategy(IntFlag):
     """ROCm convolution lowering strategy for TileAndFuse pipeline."""
@@ -33,6 +63,46 @@ class ConvolutionStrategy(IntFlag):
     direct = 2
 
 
+def filter_use_direct_load(
+    allowed_use_direct_load: list[bool],
+    codegen_pipeline: "iree_codegen.DispatchLoweringPassPipeline",
+    arch: str,
+    conv_strategy: ConvolutionStrategy,
+) -> list[bool]:
+    """Filter use_direct_load options for unsupported configurations.
+
+    Returns filtered list with use_direct_load=True removed if unsupported.
+    Logs warnings explaining why filtering occurred.
+    """
+    from iree.compiler.dialects import iree_codegen  # type: ignore
+
+    if not any(opt is True for opt in allowed_use_direct_load):
+        return allowed_use_direct_load
+
+    if codegen_pipeline != iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse:
+        tune_logger.warning(
+            f"use_direct_load is only supported with TileAndFuse pipeline. "
+            f"Current pipeline: {codegen_pipeline}. Disabling use_direct_load."
+        )
+        return [False]
+
+    if not supports_global_load_dma(arch):
+        tune_logger.warning(
+            f"use_direct_load is only supported on gfx950+ architectures. "
+            f"Current architecture: {arch}. Disabling use_direct_load."
+        )
+        return [False]
+
+    if conv_strategy == ConvolutionStrategy.direct:
+        tune_logger.warning(
+            "use_direct_load is not supported for direct convolution strategy. "
+            "Disabling use_direct_load."
+        )
+        return [False]
+
+    return allowed_use_direct_load
+
+
 @dataclass
 class ConvToIgemmInfo:
     """

amdsharktuner/amdsharktuner/rocm/rocm_constraint_generators.py

Lines changed: 42 additions & 5 deletions
@@ -118,6 +118,17 @@ def generate_solutions(
         gpu_target_info: iree_gpu.TargetInfo,
         **pipeline_constraint_options,
     ) -> Iterator[list[common.TuningConfiguration]]:
+        # Filter use_direct_load for unsupported configurations.
+        codegen_pipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse
+        pipeline_constraint_options[
+            "allowed_use_direct_load"
+        ] = rocm_common.filter_use_direct_load(
+            pipeline_constraint_options.get("allowed_use_direct_load", [False]),
+            codegen_pipeline,
+            gpu_target_info.arch,
+            rocm_common.ConvolutionStrategy.igemm,  # Contraction uses IGEMM-like path.
+        )
+
         return rocm_solutions.generate_generic_contraction_solutions(
             tuner_ctx=tuner_context,
             gpu_target_info=gpu_target_info,
@@ -128,7 +139,7 @@ def generate_solutions(
             res_type=self.op_info.res_type,
             dispatch_kind=common.DispatchKind.contraction,
             indexing_maps=self.op_info.indexing_maps,
-            codegen_pipeline=iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
+            codegen_pipeline=codegen_pipeline,
             **pipeline_constraint_options,
         )
@@ -164,11 +175,25 @@ def generate_solutions(
             self.op_info.convolution_dims is not None
         ), "convolution_dims must be set for convolution operations"
 
+        codegen_pipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse
+
         # Generate IGEMM candidates.
         if conv_strategy & rocm_common.ConvolutionStrategy.igemm:
             tuner_context.logger.info(
                 "Generating convolution candidates using IGEMM strategy"
             )
+
+            # Filter use_direct_load for IGEMM strategy.
+            igemm_options = pipeline_constraint_options.copy()
+            igemm_options[
+                "allowed_use_direct_load"
+            ] = rocm_common.filter_use_direct_load(
+                igemm_options.get("allowed_use_direct_load", [False]),
+                codegen_pipeline,
+                gpu_target_info.arch,
+                rocm_common.ConvolutionStrategy.igemm,
+            )
+
             yield from rocm_solutions.generate_generic_contraction_solutions(
                 tuner_ctx=tuner_context,
                 gpu_target_info=gpu_target_info,
@@ -179,11 +204,11 @@ def generate_solutions(
                 res_type=self.op_info.res_type,
                 dispatch_kind=common.DispatchKind.conv,
                 indexing_maps=self.op_info.indexing_maps,
-                codegen_pipeline=iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
+                codegen_pipeline=codegen_pipeline,
                 igemm_details=self.op_info.igemm_details,
                 conv_to_igemm_info=self.op_info.conv_to_igemm_info,
                 convolution_dims=self.op_info.convolution_dims,
-                **pipeline_constraint_options,
+                **igemm_options,
             )
 
         # Generate direct convolution candidates if supported.
@@ -192,6 +217,18 @@ def generate_solutions(
             tuner_context.logger.info(
                 "Generating convolution candidates using direct strategy"
             )
+
+            # Filter use_direct_load for direct conv strategy.
+            direct_options = pipeline_constraint_options.copy()
+            direct_options[
+                "allowed_use_direct_load"
+            ] = rocm_common.filter_use_direct_load(
+                direct_options.get("allowed_use_direct_load", [False]),
+                codegen_pipeline,
+                gpu_target_info.arch,
+                rocm_common.ConvolutionStrategy.direct,
+            )
+
             direct_dims, direct_sizes = self._compute_direct_conv_dimensions()
             # Pass filter loop info so solution generator can add them with tile size 1.
             direct_conv_info: rocm_solutions.DirectConvInfo = {
@@ -210,11 +247,11 @@ def generate_solutions(
                 res_type=self.op_info.res_type,
                 dispatch_kind=common.DispatchKind.conv,
                 indexing_maps=self.op_info.indexing_maps,
-                codegen_pipeline=iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
+                codegen_pipeline=codegen_pipeline,
                 igemm_details=None,
                 conv_to_igemm_info=None,
                 direct_conv_info=direct_conv_info,
-                **pipeline_constraint_options,
+                **direct_options,
             )
 
     def _supports_direct_convolution(self, tuner_context: common.TunerContext) -> bool:
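Note the .copy() before each filter call: the IGEMM and direct paths share one pipeline_constraint_options dict, so filtering in place for the direct strategy would also disable DMA for the IGEMM strategy. A minimal sketch of the pattern (filter_stub stands in for rocm_common.filter_use_direct_load):

    def filter_stub(allowed: list[bool], is_direct_conv: bool) -> list[bool]:
        # Direct convolution never gets use_direct_load=True.
        return [False] if is_direct_conv else allowed

    shared_options = {"allowed_use_direct_load": [False, True]}

    igemm_options = shared_options.copy()
    igemm_options["allowed_use_direct_load"] = filter_stub(
        igemm_options["allowed_use_direct_load"], is_direct_conv=False
    )

    direct_options = shared_options.copy()
    direct_options["allowed_use_direct_load"] = filter_stub(
        direct_options["allowed_use_direct_load"], is_direct_conv=True
    )

    print(igemm_options["allowed_use_direct_load"])   # [False, True]
    print(direct_options["allowed_use_direct_load"])  # [False]
    print(shared_options["allowed_use_direct_load"])  # [False, True] -- untouched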

amdsharktuner/amdsharktuner/rocm/rocm_dispatch_constraints.py

Lines changed: 47 additions & 25 deletions
@@ -757,34 +757,56 @@ def generate_tile_and_fuse_compilation_infos(
     padding: Optional[list[int]] = None,
     padding_conv: Optional[list[int]] = None,
     allowed_denorm_flushing: list[bool] = [False],
+    allowed_use_direct_load: list[bool] = [False],
 ) -> list[iree_codegen.CompilationInfoAttr]:
     """Generate compilation infos for LLVMGPUTileAndFuse pipeline."""
-    lowering_config_args = {
-        "workgroup": workgroup_tile_sizes,
-        "reduction": reduction_tile_sizes,
-        "subgroup": subgroup_tile_sizes,
-        "promote_operands": promote_operands,
-    }
-
-    if mma_attr is not None:
-        lowering_config_args["mma_kind"] = mma_attr
-
-    if padding is not None:
-        lowering_config_args["padding"] = padding
-
-    if padding_conv is not None:
-        lowering_config_args["padding_conv"] = padding_conv
+    all_compilation_infos: list[iree_codegen.CompilationInfoAttr] = []
+
+    for use_direct_load in allowed_use_direct_load:
+        lowering_config_args = {
+            "workgroup": workgroup_tile_sizes,
+            "reduction": reduction_tile_sizes,
+            "subgroup": subgroup_tile_sizes,
+            "promote_operands": promote_operands,
+        }
+
+        # Add promotion_types when use_direct_load is enabled.
+        if use_direct_load:
+            # Defensive check: direct convolution should not reach here with use_direct_load=True.
+            is_direct_conv = (
+                pipeline_options_search_space.use_igemm_convolution is not None
+                and pipeline_options_search_space.use_igemm_convolution == [False]
+            )
+            assert not is_direct_conv, (
+                "use_direct_load=True is not supported for direct convolution. "
+                "This should have been filtered in ROCmConvolutionTileAndFuseConstraintGenerator."
+            )
+            lowering_config_args[
+                "promotion_types"
+            ] = rocm_common.get_promotion_types_for_direct_load(len(promote_operands))
+
+        if mma_attr is not None:
+            lowering_config_args["mma_kind"] = mma_attr
+
+        if padding is not None:
+            lowering_config_args["padding"] = padding
+
+        if padding_conv is not None:
+            lowering_config_args["padding_conv"] = padding_conv
+
+        compilation_infos = _build_compilation_infos(
+            tuner_ctx,
+            lowering_config_args,
+            workgroup_sizes,
+            subgroup_size,
+            iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
+            pipeline_options_search_space,
+            allowed_waves_per_eu,
+            allowed_denorm_flushing,
+        )
+        all_compilation_infos.extend(compilation_infos)
 
-    return _build_compilation_infos(
-        tuner_ctx,
-        lowering_config_args,
-        workgroup_sizes,
-        subgroup_size,
-        iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
-        pipeline_options_search_space,
-        allowed_waves_per_eu,
-        allowed_denorm_flushing,
-    )
+    return all_compilation_infos
 
 
 def generate_vector_distribute_compilation_infos(
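The new loop emits one batch of compilation infos per use_direct_load mode, and only the DMA-enabled batch carries promotion_types. A condensed sketch of the per-iteration argument assembly (plain strings stand in for the MLIR attributes built by rocm_common.get_promotion_types_for_direct_load):

    def build_lowering_args(use_direct_load: bool, promote_operands: list[int]) -> dict:
        args: dict = {"promote_operands": promote_operands}
        if use_direct_load:
            # One DMA promotion type per promoted operand.
            args["promotion_types"] = ["#iree_gpu.use_global_load_dma"] * len(promote_operands)
        return args

    for mode in [False, True]:
        print(mode, build_lowering_args(mode, promote_operands=[0, 1]))
    # False {'promote_operands': [0, 1]}
    # True  {'promote_operands': [0, 1], 'promotion_types': [...]}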

amdsharktuner/amdsharktuner/rocm/rocm_solutions.py

Lines changed: 2 additions & 0 deletions
@@ -125,6 +125,7 @@ def generate_generic_contraction_solutions(
     num_subgroups: int = 4,
     allowed_waves_per_eu: list[int] = [2],
     allowed_denorm_flushing: list[bool] = [False],
+    allowed_use_direct_load: list[bool] = [False],
     pipeline_options_search_space: rocm_dispatch_constraints.PipelineOptionsSearchSpace = rocm_dispatch_constraints.PipelineOptionsSearchSpace(),
     igemm_details: Optional[iree_codegen.IGEMMGenericConvDetails] = None,
     conv_to_igemm_info: Optional[rocm_common.ConvToIgemmInfo] = None,
@@ -378,6 +379,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
                     padding=padding,
                     padding_conv=padding_conv,
                     allowed_denorm_flushing=allowed_denorm_flushing,
+                    allowed_use_direct_load=allowed_use_direct_load,
                 )
             )
         case iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute:
