[TritonGPU] Preserve memdesc_reshape encoding on type propagation #9973
Sibylau wants to merge 1 commit into triton-lang:main
Conversation
@@ -1,4 +1,5 @@
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --nvws-insert-aref | FileCheck %s
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --nvws-insert-aref --mlir-print-local-scope | FileCheck %s --check-prefix=LOCAL
Use `--mlir-print-local-scope` to inline the encoding attributes so we can check them after `CHECK-LABEL` anchors.
ttg::MemDescType oldType = reshape.getType();
ttg::MemDescType srcType = cast<ttg::MemDescType>(val.getType());
// Use inferReturnTypes to compute the correct allocShape from the new
// source, but preserve the original reshape's encoding rather than
// re-inferring it (which can change e.g. nvmma_shared to shared_linear).
ttg::MemDescType inferredType;
LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes(
    builder.getContext(), reshape.getLoc(), srcType, oldType.getShape(),
    inferredType);
assert(succeeded(result) && "failed to infer reshape return type");
Type newDstType = ttg::MemDescType::get(
    inferredType.getShape(), inferredType.getElementType(),
    oldType.getEncoding(), inferredType.getMemorySpace(),
    inferredType.getMutableMemory(), inferredType.getAllocShape());
newVal = ttg::MemDescReshapeOp::create(builder, reshape.getLoc(),
                                       newDstType, val);
so this assumes the source layout is unchanged?
Yes, only `allocShape` and mutability change, not the source encoding. This is the same assumption the `MemDescIndexOp` and `MemDescSubsliceOp` branches make, which also preserve the original encoding. I just revised it to use `oldType` for shape, element type, encoding, and memory space; `inferredType` is now only used for `allocShape`, which is the only thing that actually differs.
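As a toy illustration of the rebuild described above (hypothetical struct and field names; the real `ttg::MemDescType` is an MLIR attribute-backed type, not a plain struct), the new destination type keeps the old encoding but takes the allocShape and mutability from the freshly inferred type:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for ttg::MemDescType; models only the fields discussed here.
struct MemDesc {
  std::vector<int> shape;
  std::string encoding; // e.g. "nvmma_shared" or "shared_linear"
  bool mutableMemory;
  std::vector<int> allocShape;
};

// Sketch of the fix: keep the original destination encoding, but take the
// allocShape (which gains the aref buffer prefix dimension) and mutability
// from the freshly inferred type.
MemDesc rebuildDstType(const MemDesc &oldType, const MemDesc &inferredType) {
  return {oldType.shape, oldType.encoding, inferredType.mutableMemory,
          inferredType.allocShape};
}
```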
There seems to be some mix-up in this function; not all the cases make the same assumptions. This would probably need a better clean-up.
Why is getting a linear layout causing problems?
Yeah, it can cause issues if the op using the result of `ttg.memdesc_reshape` only accepts `NVMMASharedEncodingAttr`. I modified the lit test as an example: `ttng.async_tma_copy_local_to_global` has a verifier that requires `NVMMASharedEncodingAttr`, and it fails without this fix.
This could only be done if the previous layout is equivalent to the new layout, so at the very least it should be guarded by that, and if this is an expected invariant, it should be asserted.
Thanks. I added an assertion that the preserved destination linear layout is equivalent to the inferred linear layout; it checks the same invariant that the reshape op verifier already enforces (https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/IR/Ops.cpp#L580-L584).
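A toy model of the invariant that assertion checks (hypothetical types; the real check is `ttg::areLayoutsEquivalent` comparing LinearLayouts symbolically): two encodings are equivalent when they map every logical element index to the same shared-memory offset, even if they are spelled as different attributes:

```cpp
#include <cassert>
#include <functional>

// A "layout" here is just a map from logical element index to memory offset.
using Layout = std::function<int(int)>;

// Brute-force equivalence: compare the two maps on every index. The real
// areLayoutsEquivalent compares the layouts' linear bases instead of
// enumerating elements.
bool layoutsEquivalent(int numElems, const Layout &a, const Layout &b) {
  for (int i = 0; i < numElems; ++i)
    if (a(i) != b(i))
      return false;
  return true;
}
```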
Force-pushed 2bf3b20 to 7c9249b.
Actually, is this needed after #9931?
Force-pushed 7c9249b to d7410a7.
The lit test I added still ran into an error after #9931 without this fix, as the TMA op verifier only accepts `NVMMASharedEncodingAttr`, not `SharedLinearEncodingAttr` (https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp#L349-L353). This fix is needed for the case exemplified in the lit test to pass.
assert(succeeded(result) && "failed to infer reshape return type");
assert(ttg::areLayoutsEquivalent(
           oldType.getShape(),
           cast<ttg::LayoutEncodingTrait>(oldType.getEncoding()),
           cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) &&
       "preserved encoding is not equivalent to inferred encoding");
Type newDstType = ttg::MemDescType::get(
    oldType.getShape(), oldType.getElementType(), oldType.getEncoding(),
    oldType.getMemorySpace(), srcType.getMutableMemory(),
    inferredType.getAllocShape());
if this is the case, are we ever using a new type? do we ever change the mutability or the alloc shape?
Yes. For each argument of `MemDescType::get(shape, elementType, encoding, memorySpace, mutableMemory, allocShape)`:

- `oldType.getShape()`: same as `inferredType.getShape()`, since we pass `oldType.getShape()` as the dstShape argument to `inferReturnTypes`.
- `oldType.getElementType()`: same as `srcType.getElementType()` and `inferredType.getElementType()`, since reshape requires matching element types.
- `oldType.getEncoding()`: must come from `oldType`. This is the whole point of the fix.
- `oldType.getMemorySpace()`: same as `srcType.getMemorySpace()` and `inferredType.getMemorySpace()`. Always `#smem`.
- `srcType.getMutableMemory()`: same as `inferredType.getMutableMemory()`, since `inferReturnTypes` copies mutability from its source argument (`srcType`). Must not use `oldType`, because mutability changes after aref insertion.
- `inferredType.getAllocShape()`: must come from `inferredType`. This has the aref buffer prefix dimension.

So we can simplify to "take everything from `inferredType`, but swap in `oldType`'s encoding", as amended in this PR:

Type newDstType = ttg::MemDescType::get(
    inferredType.getShape(), inferredType.getElementType(),
    oldType.getEncoding(), inferredType.getMemorySpace(),
    inferredType.getMutableMemory(), inferredType.getAllocShape());
When `replaceUsesAndPropagateType` recreates a `MemDescReshapeOp` (e.g. during aref insertion), the old code re-inferred the encoding from the source. This silently changed encodings like `nvmma_shared` to `shared_linear`, since `inferMemDescReshapeOpEncoding` always produces `shared_linear` when the source has `shared_linear` encoding, even though the two may have equivalent LinearLayouts. Call `inferReturnTypes` to compute the correct allocShape from the new source, then preserve the original reshape's encoding.

Co-authored-by: Claude <noreply@anthropic.com>
Force-pushed d7410a7 to e0cc909.
// shared_linear).
ttg::MemDescType inferredType;
LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes(
    builder.getContext(), reshape.getLoc(),
    cast<ttg::MemDescType>(val.getType()),
    reshape.getType().getShape(), inferredType);
assert(succeeded(result) && "failed to infer reshape return type");
assert(ttg::areLayoutsEquivalent(
           inferredType.getShape(),
           cast<ttg::LayoutEncodingTrait>(reshape.getType().getEncoding()),
           cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) &&
       "preserved encoding is not equivalent to inferred encoding");
All this is unnecessary. You can assert that the initial types are the same and that's it.
We need `inferReturnTypes` to re-compute the allocShape field, which changes with aref insertion, while still keeping the `#ttg.nvmma_shared` encoding. The issue is that, prior to this fix, the shape-only overload `MemDescReshapeOp::create(builder, loc, val, shape)` calls `inferReturnTypes`, which falls back to the `#ttg.shared_linear` encoding. Downstream ops such as TMA ops reject that, because they check for `NVMMASharedEncoding`, even though the `MemDescReshapeOp` verifier already checks that the inferred `#ttg.shared_linear` and the original `#ttg.nvmma_shared` are layout-equivalent.
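To make the overload distinction concrete, here is a toy sketch (hypothetical names and a plain struct, not the real MLIR builder API) of why the shape-only path loses the encoding while the explicit-type path keeps it:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy destination type: just a shape plus an encoding name.
struct Desc {
  std::vector<int> shape;
  std::string encoding;
};

// Models the shape-only create overload: the encoding is re-inferred from
// the source, and (as with inferMemDescReshapeOpEncoding on a shared_linear
// source) comes back as "shared_linear" regardless of what the original
// reshape result used.
Desc createReshapeShapeOnly(const Desc &src, std::vector<int> dstShape) {
  (void)src; // the source only drives inference in this toy model
  return {std::move(dstShape), "shared_linear"};
}

// Models the explicit-type overload: the caller supplies the full
// destination type, so the original encoding survives.
Desc createReshapeExplicit(const Desc &dstType) { return dstType; }
```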
Type newDstType = ttg::MemDescType::get(
    inferredType.getShape(), inferredType.getElementType(),
    reshape.getType().getEncoding(), inferredType.getMemorySpace(),
    inferredType.getMutableMemory(), inferredType.getAllocShape());
also unnecessary as you are creating the same type we already had?
We want to use the explicit-type overload `MemDescReshapeOp::create(builder, loc, newDstType, val)` to keep the original `#ttg.nvmma_shared` encoding together with the newly inferred allocShape.
When `replaceUsesAndPropagateType` recreates a `MemDescReshapeOp` (e.g. during aref insertion), the old code re-inferred the encoding from the source. This silently changed encodings like `#ttg.nvmma_shared` to `#ttg.shared_linear`, since `inferMemDescReshapeOpEncoding` always produces `#ttg.shared_linear` when the source has `#ttg.shared_linear` encoding, even though the two may have equivalent LinearLayouts.

Fix: call `inferReturnTypes` to compute the correct allocShape from the new source, then preserve the original reshape's encoding.

New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following:
  - `/test` for lit tests
  - `/unittest` for C++ tests
  - `/python/test` for end-to-end tests
  - FILL THIS IN.
- Select one of the following:
  - lit tests
  - The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)