[WIP] ArgCompare Benchmark Results and Performance Analysis #23609
Replies: 3 comments
Thanks for putting this together. Can you add units and element types to the table headers? I assume these times are in microseconds? What stands out to me is that the numbers in the … It’s also worth checking what the concrete problem sizes and data types of argmax ops in real-world models are, and basing the benchmarks on those.
Tuning Update: Forced SG/PR Sweep
SG controls the number of subgroups per workgroup (
Tuning Update: Forced SG/PR/ST Sweep
With split reduction enabled for K-NN and BERT-Base, we also sweep ST, the split-reduction tile size (values: 256, 512, 1024, 2048).
Summary: After the tuning sweep, VD/VD_Split wins 17 of 18 cases. The only remaining gap is ResNet-50 (32×1024), which requires further investigation.
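A forced sweep like this is a plain grid search: fix one (SG, PR, ST) combination, time it, and keep the fastest. The sketch below assumes a hypothetical `measure(sg, pr, st)` callback that forces the config and returns a kernel time; the SG/PR candidate lists are placeholders because their value ranges are not given above, and only the ST values come from this post.

```python
import itertools
from typing import Callable, Dict, Iterable, Tuple

Config = Tuple[int, int, int]  # (SG, PR, ST)

def sweep(
    measure: Callable[[int, int, int], float],
    sg_values: Iterable[int],
    pr_values: Iterable[int],
    st_values: Iterable[int] = (256, 512, 1024, 2048),
) -> Tuple[Config, Dict[Config, float]]:
    """Time every (SG, PR, ST) combination and return the fastest one.

    `measure` is a hypothetical hook that forces the given config and
    returns its kernel time; lower is better.
    """
    timings: Dict[Config, float] = {}
    for sg, pr, st in itertools.product(sg_values, pr_values, st_values):
        timings[(sg, pr, st)] = measure(sg, pr, st)
    best = min(timings, key=timings.__getitem__)
    return best, timings
```

Returning the full `timings` table, not just the winner, makes it easy to see how flat the landscape is around the best point.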
DPP+Ballot vs Shuffle-Only Comparison
Context: The open question is whether a shuffle-only approach (without DPP + ballot) is sufficient; for memory-bound workloads, the performance difference may be minimal. This comparison evaluates both approaches to guide which implementation we use to support ArgCompare in the VectorDistribute pipeline.
Standard (No Split Reduction)
Without split: DPP matters only for 1D reductions of 32K-131K elements (1.7-3.2x faster). For 262K and all batched/2D workloads, the two approaches are essentially identical (~1.0x).
Split Reduction
(imagenet_1k and vq_codebook are excluded because their reduction dimension is smaller than the 32768-element split threshold.)
With split: DPP is never slower and is up to 1.6x faster (1.0-1.6x). The speedup is uniform across sizes because split reduction normalizes the per-tile work.
Summary: Based on these results, we should upstream a gpu.ballot operation to the GPU dialect, which can then be lowered to target-specific implementations (rocdl.ballot for AMD GPUs, and equivalent ops for SPIR-V and NVVM).
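To make the difference between the two strategies concrete, here is a minimal Python simulation of one 64-lane wavefront, under our reading of the approaches (an assumption, not the generated code). The shuffle-only variant carries (value, index) pairs through every xor-shuffle step. The ballot variant reduces values only (the part DPP can accelerate; `max()` stands in for it here), then builds a ballot mask of lanes holding the max and takes the lowest set bit. It assumes one element per lane with index equal to lane id, so the lowest lane is also the first occurrence. Function names are hypothetical.

```python
from typing import List, Tuple

WAVE = 64  # wave64, as on gfx942

def shuffle_only_argmax(vals: List[float]) -> Tuple[float, int]:
    """Butterfly (xor) shuffle reduction on (value, index) pairs."""
    assert len(vals) == WAVE
    v = list(vals)
    i = list(range(WAVE))  # assumption: lane l holds element l
    offset = WAVE // 2
    while offset:
        # Every lane reads its partner's pair, as a shuffle-xor would.
        pv = [v[l ^ offset] for l in range(WAVE)]
        pi = [i[l ^ offset] for l in range(WAVE)]
        for l in range(WAVE):
            # Larger value wins; on ties the smaller index wins.
            if pv[l] > v[l] or (pv[l] == v[l] and pi[l] < i[l]):
                v[l], i[l] = pv[l], pi[l]
        offset //= 2
    return v[0], i[0]

def ballot_argmax(vals: List[float]) -> Tuple[float, int]:
    """Reduce values only, then ballot the lanes that hold the max."""
    assert len(vals) == WAVE
    vmax = max(vals)  # stands in for a DPP/shuffle max-reduction of values
    mask = 0
    for lane, x in enumerate(vals):
        if x == vmax:
            mask |= 1 << lane  # ballot: one bit per voting lane
    winner = (mask & -mask).bit_length() - 1  # lowest set bit
    return vmax, winner
```

The ballot variant never moves indices across lanes, which is one plausible reason it only pays off when the cross-lane phase is a visible fraction of runtime.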
Hardware: AMD MI300X (gfx942)
Refer to https://gist.github.com/bangtianliu/256ad601139f50300fd8c6aa2125eb20 for our previous synthetic ArgCompare benchmarks.
ArgMax/ArgMin Operations Analysis in real-world AI models
The VD pipeline does not yet support masking for arbitrary reduction sizes, so we pad the reduction dimension to the next power of 2 as a valid workaround to evaluate the performance of our VD implementation of ArgCompare.
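Padding is only valid if the padded slots can never win. A minimal sketch, assuming argmax with first-occurrence tie-breaking: pad with `-inf` (for argmin, pad with `+inf` instead). Because padding is appended after the real elements, even a row that already contains `-inf` still returns an index inside the original range. NaN handling is out of scope here, and the helper names are hypothetical.

```python
import math
from typing import List

def next_pow2(n: int) -> int:
    """Smallest power of two >= n (assumes n >= 1)."""
    return 1 << (n - 1).bit_length()

def pad_for_argmax(row: List[float]) -> List[float]:
    """Append -inf so that padded slots never beat a real element."""
    return row + [-math.inf] * (next_pow2(len(row)) - len(row))

def argmax_first(row: List[float]) -> int:
    """First-occurrence argmax (Python's max keeps the first maximal item)."""
    return max(range(len(row)), key=row.__getitem__)

# e.g. a 1000-class logits row is padded to 1024 elements
padded = pad_for_argmax([0.1 * k for k in range(1000)])
```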
Refer to Real-World ArgMax/ArgMin Operations Analysis for detailed info.
Implementations Compared
- `LLVMGPUVectorDistribute` pipeline for ArgCompare. Uses subgroup shuffles and DPP (Data-Parallel Primitives) for efficient cross-lane reductions on AMD GPUs. A single workgroup processes the entire reduction, with ROCDL ballot optimizations.
- `DeviceReduce::ArgMax`/`ArgMin` for single-batch and `DeviceSegmentedReduce` for batched reductions. Highly optimized, with adaptive algorithms (1-pass vs. 2-pass depending on reduction size).
- `LLVMGPUDefault` pipeline with `linalg.generic` lowered to pre-compiled bitcode (`iree_uk_amdgpu_argmax_f32i64.gfx942.bc`). Compiled with `--iree-rocm-enable-ukernels=all`. Uses `workgroup_size=[64,1,1]` with a reduction tile size of 64.
- `torch.argmax`/`torch.argmin` via the ROCm backend. Uses PyTorch's TensorIterator-based GPU reduction kernels (not hipCUB). On ROCm, CUDA kernels are auto-converted to HIP via HIPification.
- `DeviceReduceMultiBlock` kernel is used.

Benchmark Results
Methodology: rocprof kernel timing
Warm-up: 100 runs, Benchmark: 500 runs averaged
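The warm-up/average rule can be written as a small post-processing step. This sketch assumes the per-launch kernel durations have already been extracted from the rocprof trace into a list, in launch order (the parsing step and the microsecond unit are assumptions; the tables above do not state units). `summarize_kernel_times` is a hypothetical helper.

```python
from statistics import mean
from typing import Sequence

def summarize_kernel_times(
    durations_us: Sequence[float], warmup: int = 100, runs: int = 500
) -> float:
    """Drop the warm-up launches, then average the next `runs` durations."""
    if len(durations_us) < warmup + runs:
        raise ValueError(f"need {warmup + runs} samples, got {len(durations_us)}")
    measured = durations_us[warmup:warmup + runs]
    return mean(measured)
```

Slicing explicitly to `warmup + runs` keeps extra trailing launches (for example, a verification run) out of the average.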
VD_Split is enabled for reductions of ≥ 32K elements via the compiler flags `--iree-dispatch-creation-enable-split-reduction` and `--iree-preprocessing-pass-pipeline='builtin.module(iree-dispatch-creation-set-split-reduction-sizes{split-reduction-target-size=TILE})'`. Split reduction tiles the reduction into smaller chunks (e.g., 512, 1024, or 2048 elements) that are processed in parallel before a final merge.

Summary:
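The two-phase structure of split reduction for argmax can be sketched on the CPU as follows. This mirrors the data flow only, not IREE's actual dispatches: each tile (size `tile`, playing the role of TILE) produces a partial (value, global index) pair, and a final merge reduces the partials. The merge must keep global indices and a consistent tie-break for the result to match a single-pass argmax. `split_argmax` and `combine` are hypothetical names.

```python
from typing import List, Sequence, Tuple

Pair = Tuple[float, int]  # (value, global index)

def combine(a: Pair, b: Pair) -> Pair:
    """Keep the larger value; on ties keep the smaller index (first occurrence)."""
    if b[0] > a[0] or (b[0] == a[0] and b[1] < a[1]):
        return b
    return a

def split_argmax(xs: Sequence[float], tile: int = 1024) -> Pair:
    # Phase 1: reduce each tile independently (on the GPU these tiles would
    # run in parallel); indices stay global, not tile-local.
    partials: List[Pair] = []
    for start in range(0, len(xs), tile):
        best = (xs[start], start)
        for k in range(start + 1, min(start + tile, len(xs))):
            best = combine(best, (xs[k], k))
        partials.append(best)
    # Phase 2: final merge over the per-tile partial results.
    result = partials[0]
    for p in partials[1:]:
        result = combine(result, p)
    return result
```

Because `combine` is associative and commutative, the answer does not depend on the tile size, which is what allows ST to be tuned purely for speed.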
VectorDistribute Lowering Config
The following table shows the compiler's lowering configuration for each model when using the `LLVMGPUVectorDistribute` pipeline.

Column Definitions
[WIP]: Add in-depth performance analysis and evaluate additional tools (e.g., MIOpen benchmarks).