Optimize Indexing Backward Kernel with Sub-group Aggregation and Tailored Stride Dispatch #2749
Description
This PR optimizes the `indexing_backward` kernel on XPU by implementing a specialized aggregation strategy for sorted indices. The primary goal is to minimize global memory contention during gradient accumulation.
Key Optimizations
Duplicate Aggregation & Lookahead: Instead of performing one atomic update per element, the kernel identifies contiguous identical indices using an optimized lookahead mechanism (`SKIP_SORTED_INDICES`). This collapses multiple redundant updates into a single localized accumulation, significantly reducing atomic contention on `grad_weight`.
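The aggregation idea can be sketched on the CPU as follows. This is a simplified host-side model, not the device kernel: the function name and shapes here are illustrative, and the single write per distinct index stands in for the single atomic update the kernel would issue.

```cpp
#include <cstddef>
#include <vector>

// CPU sketch of duplicate aggregation over sorted indices: look ahead over
// each run of identical indices, accumulate the corresponding gradient
// values locally, then write once per distinct destination index (on the
// device this write would be the lone atomic update for that run).
void aggregate_sorted(const std::vector<long>& sorted_indices,
                      const std::vector<float>& grad,     // one value per index
                      std::vector<float>& grad_weight) {  // accumulation target
    std::size_t i = 0;
    while (i < sorted_indices.size()) {
        const long dst = sorted_indices[i];
        float local_sum = 0.0f;
        // Lookahead: skip over the run of identical indices, summing locally.
        while (i < sorted_indices.size() && sorted_indices[i] == dst) {
            local_sum += grad[i];
            ++i;
        }
        // One write per distinct index instead of one atomic per duplicate.
        grad_weight[static_cast<std::size_t>(dst)] += local_sum;
    }
}
```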
Sub-group Parallel Reduction: For clusters with high duplicate counts, the kernel utilizes sub-group shuffle primitives (`shift_group_left`) to perform parallel reductions. This ensures that large index blocks are processed across all lanes within a sub-group simultaneously, maximizing compute throughput.
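A host-side model of the shuffle-based reduction, assuming a power-of-two sub-group width: each "lane" holds a partial sum, and at each step a lane adds the value held by the lane `delta` positions to its left, halving `delta` until lane 0 holds the total. The sequential loop below mimics what `sycl::shift_group_left` achieves in lockstep on the device.

```cpp
#include <array>
#include <cstddef>

// CPU model of a sub-group tree reduction built on shift_group_left.
// At step delta, lane i adds the (pre-step) value of lane i + delta;
// after log2(W) steps lane 0 holds the sum of all W partial values.
template <std::size_t W>  // sub-group width, assumed a power of two
float subgroup_reduce(std::array<float, W> lanes) {
    for (std::size_t delta = W / 2; delta > 0; delta /= 2) {
        // Increasing-lane order reads only not-yet-updated values this step,
        // so it faithfully simulates the simultaneous device shuffle.
        for (std::size_t lane = 0; lane + delta < W; ++lane) {
            lanes[lane] += lanes[lane + delta];  // ~ shift_group_left(sg, v, delta)
        }
    }
    return lanes[0];  // lane 0 performs the single aggregated write
}
```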
Tiled Stride Optimization: Three specialized kernel variants are introduced to handle different data layouts:
`stride_1`: Optimized for scalar-like indexing with maximum throughput.
`small_stride`: Parallelizes across the feature dimension using local work-items.
`generic_stride`: Handles high-dimensional feature vectors with optimized memory tiling.
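The dispatch between the three variants can be sketched as below. The variant names follow the PR description; the small-stride cutoff is an illustrative assumption, not the tuned value used by the kernel.

```cpp
#include <cstdint>

// Hypothetical host-side dispatch over the three kernel variants,
// keyed on the feature stride of the gradient tensor.
enum class IndexingVariant { Stride1, SmallStride, GenericStride };

IndexingVariant pick_variant(std::int64_t stride) {
    if (stride == 1) {
        return IndexingVariant::Stride1;           // scalar-like indexing path
    }
    constexpr std::int64_t kSmallStrideMax = 64;   // assumed cutoff, not the tuned value
    if (stride <= kSmallStrideMax) {
        return IndexingVariant::SmallStride;       // parallelize over the feature dim
    }
    return IndexingVariant::GenericStride;         // tiled path for large feature vectors
}
```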
SLM-backed Duplicate Cache: A Shared Local Memory (SLM) cache (`smem_dups_cache`) coordinates duplicate counts within a sub-group, reducing redundant global memory fetches for index metadata.
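A minimal CPU sketch of the cache idea, with a `std::vector` standing in for the SLM buffer: each lane computes the run length of its sorted index once and stores it locally, so subsequent lanes consult the cache instead of re-scanning the global index array. The function name and parameters are illustrative.

```cpp
#include <cstddef>
#include <vector>

// Host-side model of the smem_dups_cache pattern: per-lane duplicate counts
// for a window of sorted indices, computed once and held in local memory.
std::vector<int> build_dup_cache(const std::vector<long>& sorted_indices,
                                 std::size_t start,   // window start (sub-group base)
                                 std::size_t width) { // sub-group width
    std::vector<int> cache(width, 0);  // stands in for the SLM buffer
    for (std::size_t lane = 0; lane < width; ++lane) {
        const std::size_t pos = start + lane;
        if (pos >= sorted_indices.size()) break;
        int dups = 1;
        // Lookahead over the run of identical indices starting at pos.
        while (pos + dups < sorted_indices.size() &&
               sorted_indices[pos + dups] == sorted_indices[pos]) {
            ++dups;
        }
        cache[lane] = dups;  // shared locally; avoids repeated global scans
    }
    return cache;
}
```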