
Update StableHLO Sort Op Conversion#6989

Open
sdjukicTT wants to merge 1 commit into main from sdjukic/update-sort-conversion


@sdjukicTT
Contributor

Ticket

/

Problem description

The StableHLO Sort Op conversion sometimes produces TTIR Gather Ops, but we want to remove the generic TTIR Gather Op that mimics StableHLO Gather (#6579).

What's changed

Changed the kKeyValue SortType case of StableHLOToTTIRSortOpConversionPattern: removed the TTIR Gather Ops and the transformations needed for their inputs. A TTIR Embedding Op is now emitted directly.

In detail

For every tensor that needs to be reordered according to indices:

  1. ReshapeOp value to [total, 1] (EmbeddingOp requires 2D weights)
  2. EmbeddingOp with transformed indices
  3. ReshapeOp output back to original shape
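
The three steps above can be sketched in plain Python. This is only an illustration of the data flow, not the conversion pattern itself: tensors are modeled as flat row-major lists, and `embedding` and `reorder_with_embedding` are hypothetical helper names that stand in for the TTIR ops.

```python
def embedding(indices_2d, weights):
    """Row lookup on a 2D weight table: out[r][c] = weights[indices_2d[r][c]]."""
    return [[weights[i] for i in row] for row in indices_2d]


def reorder_with_embedding(flat_values, flat_indices_2d):
    """Reorder a tensor (given as a flat row-major list) by flat indices."""
    # Step 1: reshape the value to [total, 1], so every element is its own row.
    weights = [[v] for v in flat_values]
    # Step 2: embedding lookup; each index selects a one-element row.
    looked_up = embedding(flat_indices_2d, weights)
    # Step 3: reshape back; here that means flattening in row-major order,
    # which matches the original layout when reinterpreted with the old shape.
    return [cell[0] for row in looked_up for cell in row]


# A [2, 3] tensor sorted along its last dimension.
values = [30, 10, 20, 5, 7, 6]
flat_indices = [[1, 2, 0], [3, 5, 4]]  # already converted to flat positions
print(reorder_with_embedding(values, flat_indices))  # [10, 20, 30, 5, 6, 7]
```

The lookup only works once the per-dimension sort indices have been turned into positions in the flattened tensor, which is what the three cases below compute.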

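Indices are transformed differently depending on whether the sort dimension is first, last, or in the middle. Here d_sort is the size of the sort dimension, pre and post are the products of the dimension sizes before and after it, and total = pre * d_sort * post.

Case 1: sortDim = last (most common, stride = 1, no multiply needed)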

  • Reshape indices to [pre, d_sort] (2D, already EmbeddingOp-compatible)
  • ArangeOp(start=0, end=total, step=d_sort, dim=0) → batch offsets [pre, d_sort]
  • flat_indices = AddOp(offsets, indices_2d)
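
A minimal Python sketch of Case 1, assuming the indices have already been reshaped to [pre, d_sort]. The function name `flat_indices_last_dim` is hypothetical; the comments map each line to the op listed above.

```python
def flat_indices_last_dim(indices_2d):
    """Case 1: sort dimension is last, so the stride of a sort index is 1."""
    pre = len(indices_2d)
    d_sort = len(indices_2d[0])
    total = pre * d_sort
    # ArangeOp(start=0, end=total, step=d_sort, dim=0): one start offset per
    # row, repeated across the row to give shape [pre, d_sort].
    offsets = [[start] * d_sort for start in range(0, total, d_sort)]
    # AddOp(offsets, indices_2d)
    return [
        [off + idx for off, idx in zip(off_row, idx_row)]
        for off_row, idx_row in zip(offsets, indices_2d)
    ]


print(flat_indices_last_dim([[1, 2, 0], [0, 2, 1]]))  # [[1, 2, 0], [3, 5, 4]]
```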

Case 2: sortDim = first (stride = post)

  • Reshape indices to [d_sort, post] (2D, already EmbeddingOp-compatible)
  • ConstantOp([1]) → scalar stride tensor with value post
  • scaled_sort = MultiplyOp(indices_2d, stride_scalar) → broadcasted multiply,
    sort contribution
  • ArangeOp(start=0, end=post, step=1, dim=1) → post offsets [d_sort, post]
  • flat_indices = AddOp(scaled_sort, post_offsets)
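
Case 2 can be sketched the same way, assuming indices of shape [d_sort, post]. The name `flat_indices_first_dim` is hypothetical, and the broadcasted multiply is written as an explicit loop.

```python
def flat_indices_first_dim(indices_2d):
    """Case 2: sort dimension is first, so a sort index has stride post."""
    post = len(indices_2d[0])
    # ConstantOp holding the stride value post, applied by a broadcasted
    # MultiplyOp to every sort index.
    scaled_sort = [[idx * post for idx in row] for row in indices_2d]
    # ArangeOp(start=0, end=post, step=1, dim=1): the column position q,
    # repeated for every row to give shape [d_sort, post].
    post_offsets = [list(range(post)) for _ in indices_2d]
    # AddOp(scaled_sort, post_offsets)
    return [
        [s + q for s, q in zip(s_row, q_row)]
        for s_row, q_row in zip(scaled_sort, post_offsets)
    ]


# A [3, 2] tensor [[30, 5], [10, 7], [20, 6]] sorted along dim 0.
print(flat_indices_first_dim([[1, 0], [2, 2], [0, 1]]))  # [[2, 1], [4, 5], [0, 3]]
```

Reading the flattened values [30, 5, 10, 7, 20, 6] at these positions gives [[10, 5], [20, 6], [30, 7]], which is the tensor sorted along its first dimension.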

Case 3: sortDim in middle (stride = post, need 3D intermediate)

  • Reshape indices to [pre, d_sort, post] (3D)
  • ConstantOp → precomputed offset tensor [pre, d_sort, post] where
    offsets[p, j, q] = p * d_sort * post + q (fully known at compile time — no
    runtime compute on accelerator)
  • ConstantOp([1]) → scalar stride tensor with value post
  • scaled_sort = MultiplyOp(indices_3d, stride_scalar) → broadcasted multiply,
    sort contribution
  • flat_indices_3d = AddOp(offsets, scaled_sort)
  • Reshape flat_indices_3d to 2D for EmbeddingOp (required: input rank <= 2)
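
A sketch of Case 3 under the same conventions. The name `flat_indices_middle_dim` is hypothetical, and the final 2D shape [pre, d_sort * post] is an assumption made for illustration, since the description only says the result is reshaped to 2D.

```python
def flat_indices_middle_dim(indices_3d):
    """Case 3: sort dimension in the middle; indices have shape [pre, d_sort, post]."""
    pre = len(indices_3d)
    d_sort = len(indices_3d[0])
    post = len(indices_3d[0][0])
    # ConstantOp: offsets[p][j][q] = p * d_sort * post + q, computable ahead of time.
    offsets = [
        [[p * d_sort * post + q for q in range(post)] for _ in range(d_sort)]
        for p in range(pre)
    ]
    # MultiplyOp(indices_3d, stride) with stride = post, then AddOp with offsets.
    flat_3d = [
        [
            [offsets[p][j][q] + indices_3d[p][j][q] * post for q in range(post)]
            for j in range(d_sort)
        ]
        for p in range(pre)
    ]
    # Reshape to 2D for the embedding lookup (assumed shape: [pre, d_sort * post]).
    return [[x for row in block for x in row] for block in flat_3d]


indices = [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]  # shape [2, 2, 2]
print(flat_indices_middle_dim(indices))  # [[2, 1, 0, 3], [4, 7, 6, 5]]
```

Each entry equals p * d_sort * post + index * post + q, the row-major position of element (p, index, q) in the flattened value tensor.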

Checklist

  • New/Existing tests provide coverage for changes
  • changed test to match new decomposition
  • checked results manually for different cases

@codecov-commenter

Codecov Report

❌ Patch coverage is 46.98795% with 44 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.05%. Comparing base (8371504) to head (ad63dc4).
⚠️ Report is 1 commit behind head on main.
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ersion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 46.98% | 44 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6989      +/-   ##
==========================================
- Coverage   69.14%   69.05%   -0.09%     
==========================================
  Files         381      381              
  Lines       67044    67096      +52     
==========================================
- Hits        46357    46333      -24     
- Misses      20687    20763      +76     

