
@yucai-intel (Contributor)

Description
This PR optimizes the indexing_backward kernel on XPU by implementing a specialized aggregation strategy for sorted indices. The primary goal is to minimize global memory contention during gradient accumulation.
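
To illustrate the aggregation idea, here is a simplified host-side sketch (hypothetical function and variable names, not code from this PR): with sorted indices, each run of identical values can be accumulated locally and written back once, so updates to grad_weight drop from one per element to one per distinct index and feature.

```cpp
#include <cstdint>
#include <vector>

// Simplified host-side sketch of the aggregation idea (hypothetical names).
// Indices are assumed pre-sorted; grad_output is [num_indices * stride] and
// grad_weight is [num_rows * stride], both flat.
void accumulate_sorted(const std::vector<int64_t>& sorted_idx,
                       const std::vector<float>& grad_output,
                       std::vector<float>& grad_weight,
                       int64_t stride) {
  size_t i = 0;
  while (i < sorted_idx.size()) {
    const int64_t row = sorted_idx[i];
    // Lookahead: find how many consecutive entries share the same index.
    size_t run_end = i + 1;
    while (run_end < sorted_idx.size() && sorted_idx[run_end] == row) ++run_end;

    // Accumulate the whole run locally, then write back once per feature.
    for (int64_t f = 0; f < stride; ++f) {
      float acc = 0.f;
      for (size_t j = i; j < run_end; ++j)
        acc += grad_output[j * stride + f];
      grad_weight[row * stride + f] += acc;  // single update per (row, feature)
    }
    i = run_end;
  }
}

int main() {
  std::vector<int64_t> idx = {0, 0, 0, 2, 2, 5};       // pre-sorted indices
  std::vector<float> grad_output(idx.size() * 3, 1.f); // stride = 3
  std::vector<float> grad_weight(6 * 3, 0.f);          // 6 destination rows
  accumulate_sorted(idx, grad_output, grad_weight, 3);
  // grad_weight row 0 now holds 3.0 per feature, row 2 holds 2.0, row 5 holds 1.0.
}
```

On the device, the same run detection is distributed across work-items, and atomics are needed only where independent work-items touch the same destination row.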

Key Optimizations

  • Duplicate Aggregation & Lookahead: Instead of performing individual atomic updates, the kernel identifies contiguous identical indices using an optimized lookahead mechanism (SKIP_SORTED_INDICES). This collapses multiple redundant updates into a single localized accumulation, significantly reducing atomic contention on grad_weight.

  • Sub-group Parallel Reduction: For clusters with high duplicate counts, the kernel uses sub-group shuffle primitives (shift_group_left) to perform parallel reductions, so large index blocks are processed across all lanes of a sub-group simultaneously, maximizing compute throughput (a simplified SYCL sketch follows this list).

  • Tiled Stride Optimization: Three specialized kernel variants are introduced to handle different data layouts (a dispatch sketch also follows this list):
    stride_1: Optimized for scalar-like indexing with maximum throughput.
    small_stride: Parallelizes across the feature dimension using local work-items.
    generic_stride: Handles high-dimensional feature vectors with optimized memory tiling.

  • SLM-backed Duplicate Cache: A Shared Local Memory (SLM) cache (smem_dups_cache) coordinates duplicate counts within a sub-group, reducing redundant global memory fetches for index metadata; the sketch below stages this count in SLM.
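
The sub-group reduction and the SLM duplicate cache can be pictured with the minimal SYCL sketch below. This is a simplified stand-in rather than the kernel from this PR: the work decomposition, the precomputed dup_count, and all names other than shift_group_left and smem_dups_cache are illustrative assumptions. One work-group processes one run of duplicates, each work-item accumulates a strided partial sum, and shuffles combine the partials so only one atomic add per sub-group reaches grad_weight.

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  constexpr size_t kRunLen = 1 << 14;  // elements sharing one destination row (assumption)
  constexpr size_t kLocal  = 128;      // work-group size (assumption)

  sycl::queue q;
  float* grad_output = sycl::malloc_shared<float>(kRunLen, q);
  float* grad_weight = sycl::malloc_shared<float>(1, q);
  int*   dup_count   = sycl::malloc_shared<int>(1, q);
  for (size_t i = 0; i < kRunLen; ++i) grad_output[i] = 1.0f;
  grad_weight[0] = 0.0f;
  dup_count[0] = static_cast<int>(kRunLen);

  q.submit([&](sycl::handler& h) {
    // SLM cache: the duplicate count is fetched from global memory only once.
    sycl::local_accessor<int, 1> smem_dups_cache(sycl::range<1>(1), h);
    h.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(kLocal), sycl::range<1>(kLocal)),
        [=](sycl::nd_item<1> it) {
          if (it.get_local_linear_id() == 0)
            smem_dups_cache[0] = dup_count[0];
          sycl::group_barrier(it.get_group());
          const int n = smem_dups_cache[0];

          // Each work-item sums a strided slice of the run.
          float partial = 0.f;
          for (int i = static_cast<int>(it.get_local_linear_id()); i < n;
               i += static_cast<int>(it.get_local_range(0)))
            partial += grad_output[i];

          // Sub-group tree reduction via shuffles (sub-group size is a power of two).
          auto sg = it.get_sub_group();
          const unsigned lane  = sg.get_local_linear_id();
          const unsigned width = sg.get_local_range()[0];
          for (unsigned d = width / 2; d > 0; d >>= 1) {
            float other = sycl::shift_group_left(sg, partial, d);
            if (lane + d < width) partial += other;
          }

          // One atomic add per sub-group instead of one per element.
          if (lane == 0) {
            sycl::atomic_ref<float, sycl::memory_order::relaxed,
                             sycl::memory_scope::device,
                             sycl::access::address_space::global_space>
                acc(grad_weight[0]);
            acc.fetch_add(partial);
          }
        });
  }).wait();

  std::printf("grad_weight[0] = %.1f (expected %zu)\n", grad_weight[0], kRunLen);
  sycl::free(grad_output, q);
  sycl::free(grad_weight, q);
  sycl::free(dup_count, q);
  return 0;
}
```

In the actual kernel the run boundaries come from the sorted index stream via the SKIP_SORTED_INDICES lookahead rather than a precomputed count, but the contention pattern is the same: one atomic per sub-group instead of one per duplicated element.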
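The three-way stride dispatch can likewise be sketched as a host-side switch on the feature stride; the launcher stubs and the cutoff value below are placeholders, not the heuristics used by this PR.

```cpp
#include <cstdint>
#include <cstdio>

// Placeholder launchers standing in for the three kernel variants.
static void launch_stride_1(int64_t n)                  { std::printf("stride_1: %lld indices\n", (long long)n); }
static void launch_small_stride(int64_t n, int64_t s)   { std::printf("small_stride: %lld x %lld\n", (long long)n, (long long)s); }
static void launch_generic_stride(int64_t n, int64_t s) { std::printf("generic_stride: %lld x %lld\n", (long long)n, (long long)s); }

// Hypothetical dispatch on the feature stride; the cutoff is an assumption,
// not the heuristic used by the actual kernel.
void dispatch_indexing_backward(int64_t num_indices, int64_t stride) {
  constexpr int64_t kSmallStrideLimit = 64;  // assumed cutoff
  if (stride == 1) {
    launch_stride_1(num_indices);                 // scalar-like indexing
  } else if (stride <= kSmallStrideLimit) {
    launch_small_stride(num_indices, stride);     // feature dim covered by local work-items
  } else {
    launch_generic_stride(num_indices, stride);   // tile the feature dimension
  }
}

int main() {
  dispatch_indexing_backward(1024, 1);
  dispatch_indexing_backward(1024, 32);
  dispatch_indexing_backward(1024, 4096);
}
```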

@yucai-intel changed the title from "Yucai/index/put" to "Optimize Indexing Backward Kernel with Sub-group Aggregation and Tailored Stride Dispatch" on Jan 19, 2026
