Skip to content

Commit 90e3fc8

Browse files
jlebar authored and liuyunqi20 committed
[BACKEND] Fix bugs in load/storeDShared. (#4181)
Fix bugs in load/storeDShared. Unfortunately, we can't test this directly today, but all of these bugs were found by a WIP PR running the existing unit tests.
1 parent 5331dd7 commit 90e3fc8

File tree

3 files changed

+26
-25
lines changed

3 files changed

+26
-25
lines changed

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -321,7 +321,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
321321
// 4, split it into multiple ops.
322322
if (vec > 4) {
323323
// TODO(jlebar): Implement this once we can write a testcase.
324-
assert(false && "not yet implemented");
324+
assert(false && "vec > 4 not yet implemented");
325325
}
326326

327327
// Get pointer to remote shared memory if needed.
@@ -335,19 +335,18 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
335335
.o("shared", !ctaId.has_value())
336336
.b(bitwidth)
337337
.v(vec, /*predicate=*/vec > 1);
338-
339-
PTXBuilder::Operand *valOpr;
340338
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
341339

342-
std::string elemConstraint = getConstraintForBitwidth(bitwidth);
340+
PTXBuilder::Operand *valOpr;
341+
std::string constraint = getConstraintForBitwidth(bitwidth);
343342
if (vecTy) {
344-
SmallVector<Value> vecVals;
343+
SmallVector<std::pair<Value, std::string>> vecVals;
345344
for (int i = 0; i < vec; i++) {
346-
vecVals.push_back(extract_element(val, i32_val(i)));
345+
vecVals.push_back({extract_element(val, i32_val(i)), constraint});
347346
}
348-
valOpr = builder.newListOperand(vec, elemConstraint);
347+
valOpr = builder.newListOperand(vecVals);
349348
} else {
350-
valOpr = builder.newOperand(val, elemConstraint);
349+
valOpr = builder.newOperand(val, constraint);
351350
}
352351
st(ptrOpr, valOpr).predicate(pred, "b");
353352
builder.launch(rewriter, loc, void_ty(ctx));
@@ -377,7 +376,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
377376
// 4, split it into multiple ops.
378377
if (vec > 4) {
379378
// TODO(jlebar): Implement this once we can write a testcase.
380-
assert(false && "not yet implemented");
379+
assert(false && "vec > 4 not yet implemented");
381380
}
382381

383382
// Get pointer to remote shared memory if needed.
@@ -389,36 +388,38 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
389388
auto ld = builder.create<>("ld")
390389
->o("shared::cta", ctaId.has_value())
391390
.o("shared", !ctaId.has_value())
392-
.b(bitwidth)
393-
.v(vec, /*predicate=*/vec > 1);
391+
.v(vec, /*predicate=*/vec > 1)
392+
.b(bitwidth);
394393

395394
std::string elemConstraint = "=" + getConstraintForBitwidth(bitwidth);
396395
auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
397396
: builder.newListOperand(vec, elemConstraint);
398397
ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
399398

400-
Type resultTy;
399+
Type resultTy =
400+
vec == 1 ? Type(int_ty(bitwidth))
401+
: Type(struct_ty(SmallVector<Type>(vec, int_ty(bitwidth))));
402+
Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
403+
404+
SmallVector<Value> resultVals;
401405
if (vec == 1) {
402-
resultTy = int_ty(bitwidth);
406+
resultVals.push_back(load);
403407
} else {
404-
resultTy = struct_ty(SmallVector<Type>(vec, int_ty(bitwidth)));
408+
for (int i = 0; i < vec; i++) {
409+
resultVals.push_back(extract_val(load, i));
410+
}
405411
}
406-
Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
407412

408413
if (vecTy) {
409-
// Unpack the struct returned by the inline asm into a vector.
410-
SmallVector<Value> vals;
411-
for (int i = 0; i < vec; i++) {
412-
auto elem = extract_val(int_ty(bitwidth), load, i);
413-
vals.push_back(bitcast(elem, vecTy.getElementType()));
414-
}
415414
Value ret = undef(loadTy);
416415
for (int i = 0; i < vec; i++) {
417-
ret = insert_element(ret, i32_val(i), vals[i]);
416+
ret = insert_element(ret, bitcast(resultVals[i], vecTy.getElementType()),
417+
i32_val(i));
418418
}
419419
return ret;
420420
} else {
421-
return bitcast(load, loadTy);
421+
assert(vec == 1);
422+
return bitcast(resultVals[0], loadTy);
422423
}
423424
}
424425

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ Value permute(Location loc, RewriterBase &rewriter, Value a, Value b,
100100
}
101101

102102
/// Create a predicate with just single active thread.
103-
Value createElectPredicate(Location loc, PatternRewriter &rewriter) {
103+
Value createElectPredicate(Location loc, RewriterBase &rewriter) {
104104
PTXBuilder ptxBuilder;
105105
auto &elect = *ptxBuilder.create<>("elect.sync _|$0, 0xffffffff;");
106106
elect({ptxBuilder.newOperand("=b")}, /*onlyAttachMLIRArgs=*/true);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
4646
int axis);
4747

4848
/// Create a predicate with just single active thread.
49-
Value createElectPredicate(Location loc, PatternRewriter &rewriter);
49+
Value createElectPredicate(Location loc, RewriterBase &rewriter);
5050

5151
} // namespace NVIDIA
5252
} // namespace LLVM

0 commit comments

Comments
 (0)