[TritonGPU] Preserve memdesc_reshape encoding on type propagation#9973

Open
Sibylau wants to merge 1 commit into triton-lang:main from Sibylau:jieeliu/preserve-reshape-encoding

Conversation


@Sibylau Sibylau commented Apr 9, 2026

When replaceUsesAndPropagateType recreates a MemDescReshapeOp (e.g. during aref insertion), the old code re-inferred the encoding from the source. This silently changed encodings like #ttg.nvmma_shared to #ttg.shared_linear, since inferMemDescReshapeOpEncoding always produces #ttg.shared_linear when the source has #ttg.shared_linear encoding — even though the two may have equivalent LinearLayouts.

Fix: call inferReturnTypes to compute the correct allocShape from the new source, then preserve the original reshape's encoding.
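As a rough illustration of the fix pattern (toy C++ types; `MemDesc`, `inferReshapeType`, and `rebuildReshapeType` are hypothetical stand-ins for `ttg::MemDescType` and the real `inferReturnTypes`, not Triton's API): re-run inference to pick up the shape-dependent fields, then substitute the original encoding back in.

```cpp
#include <string>
#include <vector>

// Toy stand-in for ttg::MemDescType; the real type also carries an
// element type and a memory space.
struct MemDesc {
  std::vector<int> shape;
  std::string encoding; // e.g. "nvmma_shared" or "shared_linear"
  bool mutableMemory;
  std::vector<int> allocShape;
};

// Toy stand-in for MemDescReshapeOp::inferReturnTypes: in this sketch it
// derives allocShape and mutability from the new source, but always falls
// back to a generic "shared_linear" encoding for the result, mirroring the
// fallback described above.
MemDesc inferReshapeType(const MemDesc &src, std::vector<int> dstShape) {
  // Keep the alloc-shape prefix (e.g. the aref buffer dimension), replace
  // the trailing dims with the destination shape.
  std::vector<int> alloc(src.allocShape.begin(),
                         src.allocShape.end() - src.shape.size());
  alloc.insert(alloc.end(), dstShape.begin(), dstShape.end());
  return {dstShape, "shared_linear", src.mutableMemory, alloc};
}

// The fix: take shape, mutability, and allocShape from the inferred type,
// but preserve the original reshape's encoding.
MemDesc rebuildReshapeType(const MemDesc &oldType, const MemDesc &newSrc) {
  MemDesc inferred = inferReshapeType(newSrc, oldType.shape);
  inferred.encoding = oldType.encoding; // preserve e.g. nvmma_shared
  return inferred;
}
```

For example, after aref insertion gives the source a buffer prefix dimension and makes it mutable, the rebuilt destination type keeps the original encoding while picking up the new allocShape and mutability.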

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@Sibylau Sibylau requested a review from ptillet as a code owner April 9, 2026 00:25
Comment thread: test/NVWS/insert_aref.mlir (Outdated)
@@ -1,4 +1,5 @@
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --nvws-insert-aref | FileCheck %s
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --nvws-insert-aref --mlir-print-local-scope | FileCheck %s --check-prefix=LOCAL
Author

Use --mlir-print-local-scope to inline the encoding attributes so we can check them after CHECK-LABEL anchors.

Comment on lines +1456 to +1471
ttg::MemDescType oldType = reshape.getType();
ttg::MemDescType srcType = cast<ttg::MemDescType>(val.getType());
// Use inferReturnTypes to compute the correct allocShape from the new
// source, but preserve the original reshape's encoding rather than
// re-inferring it (which can change e.g. nvmma_shared to shared_linear).
ttg::MemDescType inferredType;
LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes(
    builder.getContext(), reshape.getLoc(), srcType, oldType.getShape(),
    inferredType);
assert(succeeded(result) && "failed to infer reshape return type");
Type newDstType = ttg::MemDescType::get(
    inferredType.getShape(), inferredType.getElementType(),
    oldType.getEncoding(), inferredType.getMemorySpace(),
    inferredType.getMutableMemory(), inferredType.getAllocShape());
newVal = ttg::MemDescReshapeOp::create(builder, reshape.getLoc(),
                                       newDstType, val);
Collaborator

so this assumes the source layout is unchanged?

Author

Yes, only allocShape and mutability change, not the source encoding. This is the same assumption the MemDescIndexOp and MemDescSubsliceOp branches make, which also preserve the original encoding. I have revised it to use oldType for the shape, element type, encoding, and memory space; inferredType is now only used for allocShape, which is the only field that actually differs.

Collaborator

There seems to be some mix-up in this function; not all the cases make the same assumptions. This would probably need a better cleanup.

Why is getting a linear layout causing problems?

Author

Yeah, it can cause issues if the user op of ttg.memdesc_reshape only accepts NVMMASharedEncodingAttr. I modified the lit test as an example: ttng.async_tma_copy_local_to_global has a verifier that requires NVMMASharedEncodingAttr, and it fails without this fix.

Contributor

This could only be done if the previous layout is equivalent to the new layout, so at the very least it should be guarded by that, and if this is an expected invariant, it should be asserted.

Author

Thanks. I added an assertion checking that the preserved destination linear layout is equivalent to the inferred linear layout; this is the same invariant the reshape op verifier already checks (https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/IR/Ops.cpp#L580-L584).
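The invariant here, two syntactically different encodings describing the same linear layout, can be sketched with a toy equivalence check. This `Encoding` and `areLayoutsEquivalent` are simplified stand-ins, not Triton's `ttg::areLayoutsEquivalent`:

```cpp
#include <cstdint>
#include <functional>

// Toy model: an encoding is a map from a logical element index to a
// shared-memory offset. Two encodings are "equivalent" on a given number
// of elements if they produce the same offset for every element.
using Encoding = std::function<int64_t(int64_t)>;

bool areLayoutsEquivalent(int64_t numElems, const Encoding &a,
                          const Encoding &b) {
  for (int64_t i = 0; i < numElems; ++i)
    if (a(i) != b(i))
      return false;
  return true;
}
```

The real check compares LinearLayouts for the given shape rather than enumerating offsets, but the contract is the same: preserving the old encoding is sound only when it is layout-equivalent to the re-inferred one.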

@Sibylau Sibylau force-pushed the jieeliu/preserve-reshape-encoding branch 3 times, most recently from 2bf3b20 to 7c9249b Compare April 10, 2026 01:16
Contributor

lezcano commented Apr 10, 2026

Actually, is this needed after #9931?

@Sibylau Sibylau force-pushed the jieeliu/preserve-reshape-encoding branch from 7c9249b to d7410a7 Compare April 10, 2026 22:02
Author

Sibylau commented Apr 10, 2026

Actually, is this needed after #9931?

The lit test I added still ran into an error after #9931 without this fix, as the TMA op verifier only accepts NVMMASharedEncodingAttr, not SharedLinearEncodingAttr (https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp#L349-L353).

test/NVWS/insert_aref.mlir:806:7: error: 'ttng.async_tma_copy_local_to_global' op TMA descriptor must have NVMMA shared layout
      ttng.async_tma_copy_local_to_global %desc[%c0, %c0] %reshaped {ttg.partition = array<i32: 1>} : !tt.tensordesc<128x128xbf16, #nvmma>, !ttg.memdesc<128x128xbf16, #nvmma, #smem>
      ^

This fix is needed for the case exemplified in the lit test to pass.
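The failure mode is a verifier that dispatches on the concrete encoding class rather than on layout equivalence. A toy sketch (hypothetical `EncodingAttr` hierarchy and `verifyTMACopy`, not the real ttng verifier):

```cpp
// Toy encoding hierarchy standing in for ttg::NVMMASharedEncodingAttr
// and ttg::SharedLinearEncodingAttr.
struct EncodingAttr {
  virtual ~EncodingAttr() = default;
};
struct NVMMASharedEncoding : EncodingAttr {};
struct SharedLinearEncoding : EncodingAttr {};

// Sketch of the TMA verifier check: it requires the NVMMA shared layout
// specifically. A shared_linear encoding is rejected even when it
// describes an equivalent layout, which is why re-inferring the encoding
// during type propagation broke the lit test.
bool verifyTMACopy(const EncodingAttr &enc) {
  return dynamic_cast<const NVMMASharedEncoding *>(&enc) != nullptr;
}
```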

Comment on lines +1468 to +1477
assert(succeeded(result) && "failed to infer reshape return type");
assert(ttg::areLayoutsEquivalent(
           oldType.getShape(),
           cast<ttg::LayoutEncodingTrait>(oldType.getEncoding()),
           cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) &&
       "preserved encoding is not equivalent to inferred encoding");
Type newDstType = ttg::MemDescType::get(
    oldType.getShape(), oldType.getElementType(), oldType.getEncoding(),
    oldType.getMemorySpace(), srcType.getMutableMemory(),
    inferredType.getAllocShape());
Contributor

if this is the case, are we ever using a new type? do we ever change the mutability or the alloc shape?

Author

Yes, for each argument of MemDescType::get(shape, elementType, encoding, memorySpace, mutableMemory, allocShape):

  1. oldType.getShape() — same as inferredType.getShape(), since we pass oldType.getShape() as the dstShape arg to inferReturnTypes.
  2. oldType.getElementType() — same as srcType.getElementType() and inferredType.getElementType(), since reshape requires matching element types.
  3. oldType.getEncoding() — must be from oldType. This is the whole point of the fix.
  4. oldType.getMemorySpace() — same as srcType.getMemorySpace() and inferredType.getMemorySpace(). Always #smem.
  5. srcType.getMutableMemory() — same as inferredType.getMutableMemory(), since inferReturnTypes copies mutability from its source arg (srcType). Must not use oldType — mutability changes after aref insertion.
  6. inferredType.getAllocShape() — must be from inferredType. This has the aref buffer prefix dimension.

So we can simplify to "take everything from inferredType, but swap in oldType's encoding", as amended in this PR:

  Type newDstType = ttg::MemDescType::get(
      inferredType.getShape(), inferredType.getElementType(),
      oldType.getEncoding(), inferredType.getMemorySpace(),
      inferredType.getMutableMemory(), inferredType.getAllocShape());

When replaceUsesAndPropagateType recreates a MemDescReshapeOp
(e.g. during aref insertion), the old code re-inferred the
encoding from the source. This silently changed encodings like
nvmma_shared to shared_linear, since inferMemDescReshapeOpEncoding
always produces shared_linear when the source has shared_linear
encoding — even though the two may have equivalent LinearLayouts.

Call inferReturnTypes to compute the correct allocShape from the
new source, then preserve the original reshape's encoding.

Co-authored-by: Claude <noreply@anthropic.com>
@Sibylau Sibylau force-pushed the jieeliu/preserve-reshape-encoding branch from d7410a7 to e0cc909 Compare April 14, 2026 00:28
Comment on lines +1462 to +1473
// shared_linear).
ttg::MemDescType inferredType;
LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes(
    builder.getContext(), reshape.getLoc(),
    cast<ttg::MemDescType>(val.getType()),
    reshape.getType().getShape(), inferredType);
assert(succeeded(result) && "failed to infer reshape return type");
assert(ttg::areLayoutsEquivalent(
           inferredType.getShape(),
           cast<ttg::LayoutEncodingTrait>(reshape.getType().getEncoding()),
           cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) &&
       "preserved encoding is not equivalent to inferred encoding");
Contributor

All this is unnecessary. You can assert that the initial types are the same and that's it.

Author

@Sibylau Sibylau Apr 14, 2026

We need inferReturnTypes to re-compute the allocShape field, which changes with aref insertion, while still keeping the #ttg.nvmma_shared encoding. The issue is that prior to this fix, the shape-only overload MemDescReshapeOp::create(builder, loc, val, shape) calls inferReturnTypes, which falls back to the #ttg.shared_linear encoding. Downstream ops such as TMA ops reject that encoding because they check for NVMMASharedEncoding, even though the MemDescReshapeOp verifier already checks that the inferred #ttg.shared_linear and the original #ttg.nvmma_shared are layout-equivalent.

Comment on lines +1474 to +1477
Type newDstType = ttg::MemDescType::get(
    inferredType.getShape(), inferredType.getElementType(),
    reshape.getType().getEncoding(), inferredType.getMemorySpace(),
    inferredType.getMutableMemory(), inferredType.getAllocShape());
Contributor

also unnecessary as you are creating the same type we already had?

Author

We want to use the explicit-type overload MemDescReshapeOp::create(builder, loc, newDstType, val) to keep the original #ttg.nvmma_shared encoding together with the newly inferred allocShape.
