diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 74c9260bb7..43923d1843 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -5663,9 +5663,15 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI, if (SPIRV::FPConvertToEncodingMap::find(DemangledName)) { FPConversionDesc FPDesc = SPIRV::FPConvertToEncodingMap::map(DemangledName); - Value *Src = CI->getOperand(0); + + // Handle both direct return and sret cases + // sret: void @func(ptr sret(...) %result, i8 %source) + // direct: half @func(i8 %source) + bool HasSRet = CI->hasStructRetAttr(); + unsigned SrcOperandIdx = HasSRet ? 1 : 0; + Value *Src = CI->getOperand(SrcOperandIdx); Type *LLVMSrcTy = Src->getType(); - Type *LLVMDstTy = CI->getType(); + Type *LLVMDstTy = HasSRet ? CI->getParamStructRetType(0) : CI->getType(); SPIRVType *SrcTy = nullptr; SPIRVType *DstTy = nullptr; @@ -5734,7 +5740,8 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI, const auto OC = static_cast(FPDesc.ConvOpCode); // Translate operands for stochastic roundings. - for (size_t I = 1; I != CI->arg_size(); ++I) + // Skip sret parameter (if present) and value operand. + for (size_t I = SrcOperandIdx + 1; I != CI->arg_size(); ++I) Ops.push_back(transValue(CI->getOperand(I), BB)); SPIRVValue *Conv = BM->addInstTemplate(OC, BM->getIds(Ops), BB, DstTy); @@ -5750,7 +5757,7 @@ SPIRVValue *LLVMToSPIRVBase::transDirectCallInst(CallInst *CI, return Conv; // Need to adjust types: create bitcast for FP8 and packed Int4. SPIRVValue *BitCast = - BM->addUnaryInst(OpBitcast, transType(CI->getType()), Conv, BB); + BM->addUnaryInst(OpBitcast, transType(LLVMDstTy), Conv, BB); return BitCast; } } diff --git a/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll b/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll index b0f08b92cd..83dc6eabf5 100644 --- a/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll +++ b/test/extensions/EXT/SPV_EXT_float8/conversions_scalar_vector.ll @@ -16,6 +16,7 @@ ; CHECK-SPIRV-DAG: Extension "SPV_EXT_float8" ; CHECK-SPIRV-DAG: Name [[#e4m3_hf16_scalar:]] "e4m3_hf16_scalar" +; CHECK-SPIRV-DAG: Name [[#e4m3_sycl_half_scalar:]] "e4m3_sycl_half_scalar" ; CHECK-SPIRV-DAG: Name [[#e4m3_hf16_vector:]] "e4m3_hf16_vector" ; CHECK-SPIRV-DAG: Name [[#e5m2_hf16_scalar:]] "e5m2_hf16_scalar" ; CHECK-SPIRV-DAG: Name [[#e5m2_hf16_vector:]] "e5m2_hf16_vector" @@ -45,6 +46,7 @@ ; CHECK-SPIRV-DAG: TypeFloat [[#HFloat16Ty:]] 16 {{$}} ; CHECK-SPIRV-DAG: TypeVector [[#HFloat16VecTy:]] [[#HFloat16Ty]] 8 +; CHECK-SPIRV-DAG: TypeStruct [[#S_HALF:]] [[#HFloat16Ty]] {{$}} ; CHECK-SPIRV-DAG: Constant [[#HFloat16Ty]] [[#HalfConst:]] 15360 ; CHECK-SPIRV-DAG: ConstantComposite [[#HFloat16VecTy]] [[#HalfVecConst:]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] [[#HalfConst]] @@ -73,6 +75,23 @@ entry: declare dso_local spir_func half @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTc(i8) +; CHECK-SPIRV: Function [[#]] [[#e4m3_sycl_half_scalar]] +; CHECK-SPIRV: Bitcast [[#E4M3Ty]] [[#Cast1:]] [[#Int8Const]] +; CHECK-SPIRV: FConvert [[#S_HALF]] [[#Conv:]] [[#Cast1]] + +; CHECK-LLVM-LABEL: e4m3_sycl_half_scalar +; CHECK-LLVM: %[[#Call:]] = call %"class.sycl::_V1::detail::half_impl::half" @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTc(i8 1) + +%"class.sycl::_V1::detail::half_impl::half" = type { half } +define spir_func void @e4m3_sycl_half_scalar() { +entry: + %hi = alloca %"class.sycl::_V1::detail::half_impl::half", align 2 + call spir_func void @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTh(ptr sret(%"class.sycl::_V1::detail::half_impl::half") align 2 %hi, i8 1) #5 + ret void +} + +declare dso_local spir_func void @_Z36__builtin_spirv_ConvertE4M3ToFP16EXTh(ptr sret(%"class.sycl::_V1::detail::half_impl::half") align 2, i8) + ; CHECK-SPIRV: Function [[#]] [[#e4m3_hf16_vector]] [[#]] ; CHECK-SPIRV: Bitcast [[#E4M3VecTy]] [[#Cast1:]] [[#Int8VecConst]] ; CHECK-SPIRV: FConvert [[#HFloat16VecTy]] [[#Conv:]] [[#Cast1]]