diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 89793c30f3710..528c8bd648332 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -9738,8 +9738,9 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL, // register class, find it. unsigned AssignedReg; const TargetRegisterClass *RC; + const MVT RefValueVT = RefOpInfo.ConstraintVT; std::tie(AssignedReg, RC) = TLI.getRegForInlineAsmConstraint( - &TRI, RefOpInfo.ConstraintCode, RefOpInfo.ConstraintVT); + &TRI, RefOpInfo.ConstraintCode, RefValueVT); // RC is unset only on failure. Return immediately. if (!RC) return std::nullopt; @@ -9747,7 +9748,17 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL, // Get the actual register value type. This is important, because the user // may have asked for (e.g.) the AX register in i32 type. We need to // remember that AX is actually i16 to get the right extension. - const MVT RegVT = *TRI.legalclasstypes_begin(*RC); + MVT RegVT = *TRI.legalclasstypes_begin(*RC); + + // If the reference value type is legal and belongs to the register class, + // use it instead of the first legal value type. This avoids generating + // inaccurate load/store instructions or unnecessary type extensions and + // truncations. + if (TLI.isTypeLegal(RefValueVT) && + llvm::is_contained(llvm::make_range(TRI.legalclasstypes_begin(*RC), + TRI.legalclasstypes_end(*RC)), + RefValueVT.SimpleTy)) + RegVT = RefValueVT.SimpleTy; if (OpInfo.ConstraintVT != MVT::Other && RegVT != MVT::Untyped) { // If this is an FP operand in an integer register (or visa versa), or more diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll index 9be54a746cacd..4db68f28b3c0a 100644 --- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll +++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll @@ -81,6 +81,23 @@ define bfloat @test_fadd(bfloat %0, bfloat %1) { ret bfloat %3 } +define bfloat @test_fadd_inlineasm(bfloat %0, bfloat %1) nounwind { +; SM90-LABEL: test_fadd_inlineasm( +; SM90: { +; SM90-NEXT: .reg .b16 %rs<4>; +; SM90-EMPTY: +; SM90-NEXT: // %bb.0: +; SM90-NEXT: ld.param.b16 %rs2, [test_fadd_inlineasm_param_0]; +; SM90-NEXT: ld.param.b16 %rs3, [test_fadd_inlineasm_param_1]; +; SM90-NEXT: // begin inline asm +; SM90-NEXT: add.rn.bf16 %rs1, %rs2, %rs3 +; SM90-NEXT: // end inline asm +; SM90-NEXT: st.param.b16 [func_retval0], %rs1; +; SM90-NEXT: ret; + %3 = tail call bfloat asm "add.rn.bf16 $0, $1, $2", "=h,h,h"(bfloat %0, bfloat %1) nounwind + ret bfloat %3 +} + define bfloat @test_fsub(bfloat %0, bfloat %1) { ; SM70-LABEL: test_fsub( ; SM70: {