@@ -321,7 +321,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
321321 // 4, split it into multiple ops.
322322 if (vec > 4 ) {
323323 // TODO(jlebar): Implement this once we can write a testcase.
324- assert (false && " not yet implemented" );
324+ assert (false && " vec > 4 not yet implemented" );
325325 }
326326
327327 // Get pointer to remote shared memory if needed.
@@ -335,19 +335,18 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
335335 .o (" shared" , !ctaId.has_value ())
336336 .b (bitwidth)
337337 .v (vec, /* predicate=*/ vec > 1 );
338-
339- PTXBuilder::Operand *valOpr;
340338 auto *ptrOpr = builder.newAddrOperand (ptr, " r" );
341339
342- std::string elemConstraint = getConstraintForBitwidth (bitwidth);
340+ PTXBuilder::Operand *valOpr;
341+ std::string constraint = getConstraintForBitwidth (bitwidth);
343342 if (vecTy) {
344- SmallVector<Value> vecVals;
343+ SmallVector<std::pair< Value, std::string> > vecVals;
345344 for (int i = 0 ; i < vec; i++) {
346- vecVals.push_back (extract_element (val, i32_val (i)));
345+ vecVals.push_back ({ extract_element (val, i32_val (i)), constraint} );
347346 }
348- valOpr = builder.newListOperand (vec, elemConstraint );
347+ valOpr = builder.newListOperand (vecVals );
349348 } else {
350- valOpr = builder.newOperand (val, elemConstraint );
349+ valOpr = builder.newOperand (val, constraint );
351350 }
352351 st (ptrOpr, valOpr).predicate (pred, " b" );
353352 builder.launch (rewriter, loc, void_ty (ctx));
@@ -377,7 +376,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
377376 // 4, split it into multiple ops.
378377 if (vec > 4 ) {
379378 // TODO(jlebar): Implement this once we can write a testcase.
380- assert (false && " not yet implemented" );
379+ assert (false && " vec > 4 not yet implemented" );
381380 }
382381
383382 // Get pointer to remote shared memory if needed.
@@ -389,36 +388,38 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
389388 auto ld = builder.create <>(" ld" )
390389 ->o (" shared::cta" , ctaId.has_value ())
391390 .o (" shared" , !ctaId.has_value ())
392- .b (bitwidth )
393- .v (vec, /* predicate= */ vec > 1 );
391+ .v (vec, /* predicate= */ vec > 1 )
392+ .b (bitwidth );
394393
395394 std::string elemConstraint = " =" + getConstraintForBitwidth (bitwidth);
396395 auto *outOpr = vec == 1 ? builder.newOperand (elemConstraint)
397396 : builder.newListOperand (vec, elemConstraint);
398397 ld (outOpr, builder.newAddrOperand (ptr, " r" )).predicate (pred, " b" );
399398
400- Type resultTy;
399+ Type resultTy =
400+ vec == 1 ? Type (int_ty (bitwidth))
401+ : Type (struct_ty (SmallVector<Type>(vec, int_ty (bitwidth))));
402+ Value load = builder.launch (rewriter, loc, resultTy, /* hasSideEffects=*/ true );
403+
404+ SmallVector<Value> resultVals;
401405 if (vec == 1 ) {
402- resultTy = int_ty (bitwidth );
406+ resultVals. push_back (load );
403407 } else {
404- resultTy = struct_ty (SmallVector<Type>(vec, int_ty (bitwidth)));
408+ for (int i = 0 ; i < vec; i++) {
409+ resultVals.push_back (extract_val (load, i));
410+ }
405411 }
406- Value load = builder.launch (rewriter, loc, resultTy, /* hasSideEffects=*/ true );
407412
408413 if (vecTy) {
409- // Unpack the struct returned by the inline asm into a vector.
410- SmallVector<Value> vals;
411- for (int i = 0 ; i < vec; i++) {
412- auto elem = extract_val (int_ty (bitwidth), load, i);
413- vals.push_back (bitcast (elem, vecTy.getElementType ()));
414- }
415414 Value ret = undef (loadTy);
416415 for (int i = 0 ; i < vec; i++) {
417- ret = insert_element (ret, i32_val (i), vals[i]);
416+ ret = insert_element (ret, bitcast (resultVals[i], vecTy.getElementType ()),
417+ i32_val (i));
418418 }
419419 return ret;
420420 } else {
421- return bitcast (load, loadTy);
421+ assert (vec == 1 );
422+ return bitcast (resultVals[0 ], loadTy);
422423 }
423424}
424425
0 commit comments