-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[HLSL][DXIL] Implement refract
intrinsic
#136026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
a1ccf10
to
dff5181
Compare
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: None (raoanag) Changes
Resolves #99153 Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff 12 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}
+def SPIRVRefract : Builtin {
+ let Spellings = ["__builtin_spirv_refract"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ Value *I = EmitScalarExpr(E->getArg(0));
+ Value *N = EmitScalarExpr(E->getArg(1));
+ Value *eta = EmitScalarExpr(E->getArg(2));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasFloatingRepresentation() &&
+ E->getArg(2)->getType()->hasFloatingRepresentation() &&
+ "refract operands must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ E->getArg(1)->getType()->isVectorType() &&
+ "refract I and N operands must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+ ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+ }
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+ T k = 1 - eta * eta * (1 - (N * I * N *I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+ return __builtin_spirv_refract(I, N, eta);
+#else
+ vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+ __detail::is_same<half, T>::value,
+ T> refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+ __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+ refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+ __detail::HLSL_FIXED_VECTOR<half, L> I,
+ __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+ __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ auto *VTyA = ArgTyA->getAs<VectorType>();
+ if (VTyA == nullptr) {
+ SemaRef.Diag(A.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyA
+ << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyB = B.get()->getType();
+ auto *VTyB = ArgTyB->getAs<VectorType>();
+ if (VTyB == nullptr) {
+ SemaRef.Diag(B.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyB
+ << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult C = TheCall->getArg(2);
+ QualType ArgTyC = C.get()->getType();
+ if (!ArgTyC->hasFloatingRepresentation()) {
+ SemaRef.Diag(C.get()->getBeginLoc(),
+ diag::err_builtin_invalid_arg_type)
+ << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+ << ArgTyC;
+ return true;
+ }
+
+ QualType RetTy = ArgTyA;
+ TheCall->setType(RetTy);
+ assert(RetTy == ArgTyA);
+ break;
+ }
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
+ SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT: [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT: [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT: [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT: [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT: [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK: if.else.i: ; preds = %entry
+// SPVCHECK-NEXT: [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT: [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT: ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]
|
@llvm/pr-subscribers-clang-codegen Author: None (raoanag) Changes
Resolves #99153 Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff 12 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}
+def SPIRVRefract : Builtin {
+ let Spellings = ["__builtin_spirv_refract"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ Value *I = EmitScalarExpr(E->getArg(0));
+ Value *N = EmitScalarExpr(E->getArg(1));
+ Value *eta = EmitScalarExpr(E->getArg(2));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasFloatingRepresentation() &&
+ E->getArg(2)->getType()->hasFloatingRepresentation() &&
+ "refract operands must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ E->getArg(1)->getType()->isVectorType() &&
+ "refract I and N operands must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+ ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+ }
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+ T k = 1 - eta * eta * (1 - (N * I * N *I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+ return __builtin_spirv_refract(I, N, eta);
+#else
+ vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+ __detail::is_same<half, T>::value,
+ T> refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+ __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+ refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+ __detail::HLSL_FIXED_VECTOR<half, L> I,
+ __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+ __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ auto *VTyA = ArgTyA->getAs<VectorType>();
+ if (VTyA == nullptr) {
+ SemaRef.Diag(A.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyA
+ << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyB = B.get()->getType();
+ auto *VTyB = ArgTyB->getAs<VectorType>();
+ if (VTyB == nullptr) {
+ SemaRef.Diag(B.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyB
+ << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult C = TheCall->getArg(2);
+ QualType ArgTyC = C.get()->getType();
+ if (!ArgTyC->hasFloatingRepresentation()) {
+ SemaRef.Diag(C.get()->getBeginLoc(),
+ diag::err_builtin_invalid_arg_type)
+ << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+ << ArgTyC;
+ return true;
+ }
+
+ QualType RetTy = ArgTyA;
+ TheCall->setType(RetTy);
+ assert(RetTy == ArgTyA);
+ break;
+ }
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
+ SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT: [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT: [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT: [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT: [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT: [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK: if.else.i: ; preds = %entry
+// SPVCHECK-NEXT: [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT: [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT: ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]
|
refract
intrinsic
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 | ||
// RUN: %clang_cc1 -finclude-default-header -triple \ | ||
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ | ||
// RUN: -emit-llvm -O1 -o - | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't want to use O1
for these tests. Also don't use utils/update_cc_test_checks.py
your checking for stuff that isn't relevant to the alrgorithm.
// | ||
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh( | ||
// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] | ||
// SPVCHECK-NEXT: [[ENTRY:.*:]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the SPIRV and CHECK lines are the same we can combine these by adding a third check that covers the cases where the runlines are the same. There should be examples of this.
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) { | |||
#endif | |||
} | |||
|
|||
template <typename T> constexpr T refract_impl(T I, T N, T eta) { | |||
T k = 1 - eta * eta * (1 - (N * I * N *I)); | |||
if(k < 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like we can simplify these conditionals to use a select instead.
|
||
ExprResult C = TheCall->getArg(2); | ||
QualType ArgTyC = C.get()->getType(); | ||
if (!ArgTyC->hasFloatingRepresentation()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hasFloatingRepresentation
by itself is wrong here because its going to include float\half vectors.
if (!ArgTyC->hasFloatingRepresentation()) { | ||
SemaRef.Diag(C.get()->getBeginLoc(), | ||
diag::err_builtin_invalid_arg_type) | ||
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this wouldn't be scalar or vector needs to just be scalar float\half.
if (SemaRef.checkArgCount(TheCall, 3)) | ||
return true; | ||
|
||
ExprResult A = TheCall->getArg(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could do for (unsigned i = 0; i < TheCall->getNumArgs()-1; ++i) {
so you don't have to repeat the code for ExprResult A and B.
|
||
QualType RetTy = ArgTyA; | ||
TheCall->setType(RetTy); | ||
assert(RetTy == ArgTyA); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this assert doesn't do anything you are setting RetTy to ArgTyA two lines above. delete this.
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, | |||
QualType ArgTyB = B.get()->getType(); | |||
auto *VTyB = ArgTyB->getAs<VectorType>(); | |||
if (VTyB == nullptr) { | |||
SemaRef.Diag(A.get()->getBeginLoc(), | |||
SemaRef.Diag(B.get()->getBeginLoc(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch
return true; | ||
} | ||
|
||
ExprResult B = TheCall->getArg(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this should be a helper function since reflect is doing the same thing.
@@ -0,0 +1,34 @@ | |||
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 | |||
|
|||
// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove 01
Value *eta = EmitScalarExpr(E->getArg(2)); | ||
assert(E->getArg(0)->getType()->hasFloatingRepresentation() && | ||
E->getArg(1)->getType()->hasFloatingRepresentation() && | ||
E->getArg(2)->getType()->hasFloatingRepresentation() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should just check E->getArg(2)->getType()
is a float or half scalar might be worth putting this in its own assert so you can have a different message.
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,c,h -- clang/test/CodeGenSPIRV/Builtins/refract.c clang/test/SemaSPIRV/BuiltIns/refract-errors.c clang/lib/CodeGen/TargetBuiltins/SPIR.cpp clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h clang/lib/Headers/hlsl/hlsl_intrinsics.h clang/lib/Sema/SemaSPIRV.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp View the diff from clang-format here.diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index e9d376b7f..e076f4ded 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -72,8 +72,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
}
template <typename T> constexpr T refract_impl(T I, T N, T eta) {
- T k = 1 - eta * eta * (1 - (N * I * N *I));
- if(k < 0)
+ T k = 1 - eta * eta * (1 - (N * I * N * I));
+ if (k < 0)
return 0;
else
return (eta * I - (eta * N * I + sqrt(k)) * N);
@@ -85,7 +85,7 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
return __builtin_spirv_refract(I, N, eta);
#else
vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
- if(k < 0)
+ if (k < 0)
return 0;
else
return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 54fae9f3b..ccfd0b75a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -445,8 +445,8 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
-/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
-/// off a surface with the normal \a N.
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
@@ -473,7 +473,7 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
- refract(T I, T N, T eta) {
+refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index b10c6b767..0f8c8c660 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -117,8 +117,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->hasFloatingRepresentation()) {
- SemaRef.Diag(C.get()->getBeginLoc(),
- diag::err_builtin_invalid_arg_type)
+ SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
<< ArgTyC;
return true;
|
run |
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_( | ||
// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] { | ||
// SPVCHECK-NEXT: [[ENTRY:.*:]] | ||
// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you are having this extractelement because you are treated the third argument as the same type as the first two args but it is always a scalar.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests need to be redone. Try not to use any automation tooling to generate your tests. its causing you to miss sublte but important implementation details.
Resolves #99153