Skip to content

Commit b1d73d0

Browse files
[Codegen][Tuner] Expose XOR shuffle bounds and validation functions in CAPI (#23442)
After #23175, we generate swizzle hint ops for scaled GEMMs, and the parameters of those ops are set during `lowering_config` selection. These parameters (`rowElems` and `accessElems`) can also be chosen by the tuner (although in the future we intend to derive them analytically). This PR exposes two functions, `getXorShuffleBounds` and `isXORShuffleValid`, so that the tuner can constrain its search space to applicable XOR swizzles. Assisted by: composer-1 --------- Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent 2ffd825 commit b1d73d0

10 files changed

Lines changed: 120 additions & 9 deletions

File tree

compiler/bindings/c/iree/compiler/dialects/iree_gpu.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,19 @@ ireeGPUTargetInfoGetMMAIntrinsics(MlirAttribute mmaIntrinsics,
184184
mma_intrinsic_enum_t *mmaIntrinsicVals,
185185
uint8_t *virtualMmaIntrinsicTags);
186186

187+
// Returns the lower and upper bounds for valid XOR shuffle parameters for a
188+
// given MMA intrinsic and operand index. On success, writes to minAccessElems
189+
// and totalTileElems and returns true. Returns false on failure.
190+
MLIR_CAPI_EXPORTED bool ireeGPUGetXorShuffleBounds(MlirAttribute mmaIntrinsic,
191+
int32_t operandIndex,
192+
int64_t *minAccessElems,
193+
int64_t *totalTileElems);
194+
195+
// Returns true if the XOR shuffle is valid for the given parameters.
196+
MLIR_CAPI_EXPORTED bool ireeGPUIsXORShuffleValid(int64_t numRowElems,
197+
int64_t numAccessElems,
198+
int64_t totalTileElems);
199+
187200
#ifdef __cplusplus
188201
}
189202
#endif

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ NB_MODULE(_ireeCompilerDialects, m) {
485485
if (!mlirAttributeIsNull(basisInfo.mappingAttr)) {
486486
mapping = getIntArrayAttrValues(basisInfo.mappingAttr);
487487
}
488-
return std::make_tuple(counts, mapping);
488+
return std::tuple(counts, mapping);
489489
})
490490
.def_property_readonly(
491491
"mma_kind", [](MlirAttribute self) -> std::optional<MlirAttribute> {
@@ -633,14 +633,38 @@ NB_MODULE(_ireeCompilerDialects, m) {
633633
});
634634

635635
iree_gpu_module.def(
636-
"get_single_subgroup_layout",
637-
[](MlirAttribute attr, int fragment) {
638-
return ireeGPUGetSingleSubgroupLayout(attr, fragment);
639-
},
636+
"get_single_subgroup_layout", ireeGPUGetSingleSubgroupLayout,
640637
"Returns the single subgroup layout (element, thread, outer, "
641638
"tstrides) for a given MMA or VirtualMMA intrinsic and fragment. ",
642639
py::arg("attr"), py::arg("fragment"));
643640

641+
//===-------------------------------------------------------------------===//
642+
// Binding to XOR shuffle utility functions
643+
//===-------------------------------------------------------------------===//
644+
645+
iree_gpu_module.def(
646+
"get_xor_shuffle_bounds",
647+
[](MlirAttribute mmaIntrinsic,
648+
int operandIndex) -> std::optional<std::tuple<int64_t, int64_t>> {
649+
int64_t minAccessElems = 0, totalTileElems = 0;
650+
if (ireeGPUGetXorShuffleBounds(mmaIntrinsic, operandIndex,
651+
&minAccessElems, &totalTileElems)) {
652+
return std::tuple(minAccessElems, totalTileElems);
653+
}
654+
return std::nullopt;
655+
},
656+
"Returns the bounds for valid XOR shuffle parameters (min_access_elems, "
657+
"total_tile_elems) for the given MMA intrinsic and operand index. See "
658+
"GPUUtils for sweep semantics. Returns (min_access_elems, "
659+
"total_tile_elems) or None on failure.",
660+
py::arg("mmaIntrinsic"), py::arg("operand_index"));
661+
662+
iree_gpu_module.def(
663+
"is_xor_shuffle_valid", ireeGPUIsXORShuffleValid,
664+
"Returns true if the XOR shuffle is valid for the given parameters.",
665+
py::arg("num_row_elems"), py::arg("num_access_elems"),
666+
py::arg("total_tile_elems"));
667+
644668
//===-------------------------------------------------------------------===//
645669
// Binding to utility function getExecutableVariantOps
646670
//===-------------------------------------------------------------------===//

compiler/bindings/python/test/api/tuner_api_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from iree.compiler import ir
88
from iree.compiler.dialects import iree_codegen
9+
from iree.compiler.dialects import iree_gpu
910
from iree.compiler.dialects import affine
1011
from iree.compiler.ir import AffineMap, AffineDimExpr
1112

@@ -545,3 +546,34 @@ def test_infer_scaled_contraction_dimensions():
545546
assert dims_batched.n == [2], f"Got {dims_batched.n}"
546547
assert dims_batched.k == [3], f"Got {dims_batched.k}"
547548
assert dims_batched.kB == [4], f"Got {dims_batched.kB}"
549+
550+
551+
@run
552+
def test_is_xor_shuffle_valid():
553+
"""Test XOR shuffle validation (pure function, no MLIR attributes)."""
554+
# Valid: row and access divide tile; row >= access; tile >= row.
555+
assert iree_gpu.is_xor_shuffle_valid(256, 32, 512)
556+
assert iree_gpu.is_xor_shuffle_valid(512, 64, 512)
557+
assert iree_gpu.is_xor_shuffle_valid(32, 8, 512)
558+
# Invalid: row exceeds tile.
559+
assert not iree_gpu.is_xor_shuffle_valid(512, 32, 256)
560+
# Invalid: access exceeds row.
561+
assert not iree_gpu.is_xor_shuffle_valid(256, 512, 512)
562+
# Invalid: row does not evenly divide tile.
563+
assert not iree_gpu.is_xor_shuffle_valid(300, 32, 512)
564+
# Invalid: access does not evenly divide row.
565+
assert not iree_gpu.is_xor_shuffle_valid(256, 33, 512)
566+
567+
568+
@run
569+
def test_get_xor_shuffle_bounds():
570+
"""Test XOR shuffle bounds for an MMA intrinsic (for use by SharkTuner)."""
571+
# Use an MMA intrinsic that supports getXorShuffleBounds (InnerTileDescAttrInterface).
572+
mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)
573+
bounds = iree_gpu.get_xor_shuffle_bounds(mma_attr, operand_index=0)
574+
assert bounds is not None, "get_xor_shuffle_bounds should succeed for MMAAttr"
575+
min_access_elems, total_tile_elems = bounds
576+
assert min_access_elems == 4
577+
assert total_tile_elems == 256
578+
bounds_rhs = iree_gpu.get_xor_shuffle_bounds(mma_attr, operand_index=1)
579+
assert bounds_rhs is not None

compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,29 @@ void ireeGPUTargetInfoGetMMAIntrinsics(MlirAttribute mmaIntrinsics,
530530
assert(false && "Unexpected attribute type in MMA intrinsics array");
531531
}
532532
}
533+
534+
bool ireeGPUGetXorShuffleBounds(MlirAttribute mmaIntrinsic,
535+
int32_t operandIndex, int64_t *minAccessElems,
536+
int64_t *totalTileElems) {
537+
assert(!mlirAttributeIsNull(mmaIntrinsic) && "mmaIntrinsic cannot be null");
538+
auto innerTileDesc = llvm::dyn_cast<
539+
mlir::iree_compiler::IREE::Codegen::InnerTileDescAttrInterface>(
540+
unwrap(mmaIntrinsic));
541+
assert(innerTileDesc && "innerTileDesc cannot be null");
542+
assert(minAccessElems && "minAccessElems cannot be null");
543+
assert(totalTileElems && "totalTileElems cannot be null");
544+
mlir::FailureOr<mlir::iree_compiler::XorShuffleBounds> result =
545+
mlir::iree_compiler::getXorShuffleBounds(innerTileDesc, operandIndex);
546+
if (llvm::failed(result)) {
547+
return false;
548+
}
549+
*minAccessElems = result->minAccessElems;
550+
*totalTileElems = result->totalTileElems;
551+
return true;
552+
}
553+
554+
bool ireeGPUIsXORShuffleValid(int64_t numRowElems, int64_t numAccessElems,
555+
int64_t totalTileElems) {
556+
return mlir::iree_compiler::isXORShuffleValid(numRowElems, numAccessElems,
557+
totalTileElems);
558+
}

compiler/src/iree/compiler/API/api_exports.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ extern void ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID();
2828
extern void ireeCodegenDispatchLoweringPassPipelineAttrGetValue();
2929
extern void ireeCodegenGetExecutableVariantOps();
3030
extern void ireeGPUGetSingleSubgroupLayout();
31+
extern void ireeGPUGetXorShuffleBounds();
32+
extern void ireeGPUIsXORShuffleValid();
3133
extern void ireeCodegenGetTunerRootOps();
3234
extern void ireeCodegenGetAttentionOpDetail();
3335
extern void ireeCodegenInferScaledContractionDimensions();
@@ -949,6 +951,8 @@ uintptr_t __iree_compiler_hidden_force_extern() {
949951
x += (uintptr_t)&ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
950952
x += (uintptr_t)&ireeCodegenGetExecutableVariantOps;
951953
x += (uintptr_t)&ireeGPUGetSingleSubgroupLayout;
954+
x += (uintptr_t)&ireeGPUGetXorShuffleBounds;
955+
x += (uintptr_t)&ireeGPUIsXORShuffleValid;
952956
x += (uintptr_t)&ireeCodegenGetTunerRootOps;
953957
x += (uintptr_t)&ireeCodegenGetAttentionOpDetail;
954958
x += (uintptr_t)&ireeCodegenInferScaledContractionDimensions;

compiler/src/iree/compiler/API/api_exports.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ EXPORTS
1818
ireeCodegenDispatchLoweringPassPipelineAttrGetValue
1919
ireeCodegenGetExecutableVariantOps
2020
ireeGPUGetSingleSubgroupLayout
21+
ireeGPUGetXorShuffleBounds
22+
ireeGPUIsXORShuffleValid
2123
ireeCodegenGetTunerRootOps
2224
ireeCodegenGetAttentionOpDetail
2325
ireeCodegenInferScaledContractionDimensions

compiler/src/iree/compiler/API/api_exports.ld

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ VER_0 {
1919
ireeCodegenDispatchLoweringPassPipelineAttrGetValue;
2020
ireeCodegenGetExecutableVariantOps;
2121
ireeGPUGetSingleSubgroupLayout;
22+
ireeGPUGetXorShuffleBounds;
23+
ireeGPUIsXORShuffleValid;
2224
ireeCodegenGetTunerRootOps;
2325
ireeCodegenGetAttentionOpDetail;
2426
ireeCodegenInferScaledContractionDimensions;

compiler/src/iree/compiler/API/api_exports.macos.lst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ _ireeCodegenDispatchLoweringPassPipelineAttrGetTypeID
1717
_ireeCodegenDispatchLoweringPassPipelineAttrGetValue
1818
_ireeCodegenGetExecutableVariantOps
1919
_ireeGPUGetSingleSubgroupLayout
20+
_ireeGPUGetXorShuffleBounds
21+
_ireeGPUIsXORShuffleValid
2022
_ireeCodegenGetTunerRootOps
2123
_ireeCodegenGetAttentionOpDetail
2224
_ireeCodegenInferScaledContractionDimensions

compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ validateXorShuffle(FailureOr<XorShuffleParams> swizzle,
840840
// Disabling clang-tidy for the following functions, as they are externally
841841
// linked from the CAPI.
842842
// NOLINTBEGIN(misc-use-internal-linkage)
843-
FailureOr<XorShuffleParams>
843+
FailureOr<XorShuffleBounds>
844844
getXorShuffleBounds(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
845845
int operandIndex) {
846846
FailureOr<int64_t> maybeMinimumAccessElems =
@@ -850,8 +850,7 @@ getXorShuffleBounds(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
850850
if (failed(maybeMinimumAccessElems) || failed(maybeTotalTileElems)) {
851851
return failure();
852852
}
853-
return XorShuffleParams({/*rowElems=*/*maybeMinimumAccessElems,
854-
/*accessElems=*/*maybeTotalTileElems});
853+
return XorShuffleBounds{*maybeMinimumAccessElems, *maybeTotalTileElems};
855854
}
856855

857856
bool isXORShuffleValid(int64_t numRowElems, int64_t numAccessElems,

compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,13 @@ struct XorShuffleParams {
211211
int64_t accessElems;
212212
};
213213

214+
/// Bounds for valid XOR shuffle parameters (min access elements per thread,
215+
/// total elements in the tile). Used when sweeping over valid configs.
216+
struct XorShuffleBounds {
217+
int64_t minAccessElems;
218+
int64_t totalTileElems;
219+
};
220+
214221
/// For a given MMA intrinsic and operand, returns the lower bound and upper
215222
/// bound for valid values of XOR shuffle attribute parameters, access width and
216223
/// row width. For both parameters, the elements ingested per thread at a time
@@ -221,7 +228,7 @@ struct XorShuffleParams {
221228
/// the upper bound.
222229
/// - sweep row elements over all multiples of the access elements, respecting
223230
/// the upper bound.
224-
FailureOr<XorShuffleParams>
231+
FailureOr<XorShuffleBounds>
225232
getXorShuffleBounds(IREE::Codegen::InnerTileDescAttrInterface intrinsic,
226233
int operandIndex);
227234

0 commit comments

Comments
 (0)