Skip to content

[HLSL][DXIL] Implement refract intrinsic #136026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

raoanag
Copy link

@raoanag raoanag commented Apr 16, 2025

  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@raoanag raoanag force-pushed the user/raoanag/refract branch from a1ccf10 to dff5181 Compare April 29, 2025 00:16
@raoanag raoanag marked this pull request as ready for review April 29, 2025 00:21
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Apr 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-spir-v

@llvm/pr-subscribers-clang

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h

  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td

  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp

  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp

  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl

  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c

  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl

  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c

  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td

  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.

  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll

  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRV.td (+6)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+21)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+44-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+356)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+34)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+37)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
   let Prototype = "void(...)";
 }
 
+def SPIRVRefract : Builtin {
+  let Spellings = ["__builtin_spirv_refract"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def SPIRVSmoothStep : Builtin {
   let Spellings = ["__builtin_spirv_smoothstep"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->hasFloatingRepresentation() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+  T k = 1 - eta * eta * (1 - (N * I * N *I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  return __builtin_spirv_refract(I, N, eta);
+#else
+  vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+    refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->hasFloatingRepresentation()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    QualType RetTy = ArgTyA;
+    TheCall->setType(RetTy);
+    assert(RetTy == ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:  if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-clang-codegen

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h

  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td

  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp

  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp

  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl

  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c

  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl

  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c

  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td

  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.

  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll

  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRV.td (+6)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+21)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+44-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+356)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+34)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+37)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
   let Prototype = "void(...)";
 }
 
+def SPIRVRefract : Builtin {
+  let Spellings = ["__builtin_spirv_refract"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def SPIRVSmoothStep : Builtin {
   let Spellings = ["__builtin_spirv_smoothstep"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->hasFloatingRepresentation() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+  T k = 1 - eta * eta * (1 - (N * I * N *I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  return __builtin_spirv_refract(I, N, eta);
+#else
+  vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+    refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->hasFloatingRepresentation()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    QualType RetTy = ArgTyA;
+    TheCall->setType(RetTy);
+    assert(RetTy == ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:  if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]

@raoanag raoanag changed the title User/raoanag/refract [HLSL][DXIL] Implement refract intrinsic Apr 29, 2025
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -O1 -o - | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to use O1 for these tests. Also don't use utils/update_cc_test_checks.py your checking for stuff that isn't relevant to the alrgorithm.

//
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
// SPVCHECK-NEXT: [[ENTRY:.*:]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the SPIRV and CHECK lines are the same we can combine these by adding a third check that covers the cases where the runlines are the same. There should be examples of this.

@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T eta) {
T k = 1 - eta * eta * (1 - (N * I * N *I));
if(k < 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we can simplify these conditionals to use a select instead.


ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->hasFloatingRepresentation()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasFloatingRepresentation by itself is wrong here because its going to include float\half vectors.

if (!ArgTyC->hasFloatingRepresentation()) {
SemaRef.Diag(C.get()->getBeginLoc(),
diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
Copy link
Member

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this wouldn't be scalar or vector needs to just be scalar float\half.

if (SemaRef.checkArgCount(TheCall, 3))
return true;

ExprResult A = TheCall->getArg(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could do for (unsigned i = 0; i < TheCall->getNumArgs()-1; ++i) { so you don't have to repeat the code for ExprResult A and B.


QualType RetTy = ArgTyA;
TheCall->setType(RetTy);
assert(RetTy == ArgTyA);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assert doesn't do anything you are setting RetTy to ArgTyA two lines above. delete this.

@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
SemaRef.Diag(B.get()->getBeginLoc(),
Copy link
Member

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch

return true;
}

ExprResult B = TheCall->getArg(1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a helper function since reflect is doing the same thing.

@@ -0,0 +1,34 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5

// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove 01

Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->hasFloatingRepresentation() &&
Copy link
Member

@farzonl farzonl Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should just check E->getArg(2)->getType() is a float or half scalar might be worth putting this in its own assert so you can have a different message.

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,c,h -- clang/test/CodeGenSPIRV/Builtins/refract.c clang/test/SemaSPIRV/BuiltIns/refract-errors.c clang/lib/CodeGen/TargetBuiltins/SPIR.cpp clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h clang/lib/Headers/hlsl/hlsl_intrinsics.h clang/lib/Sema/SemaSPIRV.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
View the diff from clang-format here.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index e9d376b7f..e076f4ded 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -72,8 +72,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 }
 
 template <typename T> constexpr T refract_impl(T I, T N, T eta) {
-  T k = 1 - eta * eta * (1 - (N * I * N *I));
-  if(k < 0)
+  T k = 1 - eta * eta * (1 - (N * I * N * I));
+  if (k < 0)
     return 0;
   else
     return (eta * I - (eta * N * I + sqrt(k)) * N);
@@ -85,7 +85,7 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
   return __builtin_spirv_refract(I, N, eta);
 #else
   vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
-  if(k < 0)
+  if (k < 0)
     return 0;
   else
     return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 54fae9f3b..ccfd0b75a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -445,8 +445,8 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
 /// \param eta The refraction index.
 ///
 /// The return value is a floating-point vector that represents the refraction
-/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
-/// off a surface with the normal \a N.
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
 ///
 /// This function calculates the refraction vector using the following formulas:
 /// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
@@ -473,7 +473,7 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
 template <typename T>
 const inline __detail::enable_if_t<
     __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
-    refract(T I, T N, T eta) {
+refract(T I, T N, T eta) {
   return __detail::refract_impl(I, N, eta);
 }
 
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index b10c6b767..0f8c8c660 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -117,8 +117,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->hasFloatingRepresentation()) {
-      SemaRef.Diag(C.get()->getBeginLoc(),
-                   diag::err_builtin_invalid_arg_type)
+      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;

@farzonl
Copy link
Member

farzonl commented Apr 29, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
View the diff from clang-format here.

run git clang-format <git branch or hash before your commits>

// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are having this extractelement because you are treated the third argument as the same type as the first two args but it is always a scalar.

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests need to be redone. Try not to use any automation tooling to generate your tests. its causing you to miss sublte but important implementation details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement the refract HLSL Function
3 participants