Skip to content

[HLSL] Analyze updateCounter usage #135669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

V-FEXrt
Copy link
Contributor

@V-FEXrt V-FEXrt commented Apr 14, 2025

Fixes #135667

Analyze and annotate ResourceInfo with the derived direction of calls to updateCounter (if any).

This change only sets the value. Any diagnostics that should be raised must be done somewhere else.

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-llvm-analysis

Author: Ashley Coleman (V-FEXrt)

Changes

Fixes #135667

Analyze and annotate ResourceInfo with the derived direction of calls to updateCounter (if any).

This change only sets the value. Any diagnostics that should be raised must be done somewhere else.


Patch is 22.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135669.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+1-1)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+45-6)
  • (modified) llvm/test/Analysis/DXILResource/buffer-frombinding.ll (+4-2)
  • (modified) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+190-41)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 96e90e563e230..a8124caf64420 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -463,7 +463,7 @@ class DXILResourceMap {
   /// ambiguous so multiple creation instructions may be returned. The resulting
   /// ResourceInfo can be used to depuplicate unique handles that
   /// reference the same resource
-  SmallVector<dxil::ResourceInfo> findByUse(const Value *Key) const;
+  SmallVector<dxil::ResourceInfo *> findByUse(const Value *Key);
 
   const_iterator find(const CallInst *Key) const {
     auto Pos = CallMap.find(Key);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 1c4348321c1d0..d2392ff929611 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -697,6 +697,9 @@ bool DXILResourceTypeMap::invalidate(Module &M, const PreservedAnalyses &PA,
 }
 
 //===----------------------------------------------------------------------===//
+static bool isUpdateCounterIntrinsic(Function &F) {
+  return F.getIntrinsicID() == Intrinsic::dx_resource_updatecounter;
+}
 
 void DXILResourceMap::populate(Module &M, DXILResourceTypeMap &DRTM) {
   SmallVector<std::tuple<CallInst *, ResourceInfo, ResourceTypeInfo>> CIToInfos;
@@ -775,6 +778,42 @@ void DXILResourceMap::populate(Module &M, DXILResourceTypeMap &DRTM) {
     // Adjust the resource binding to use the next ID.
     RI.setBindingID(NextID++);
   }
+
+  for (Function &F : M.functions()) {
+    if (!isUpdateCounterIntrinsic(F))
+      continue;
+
+    LLVM_DEBUG(dbgs() << "Update Counter Function: " << F.getName() << "\n");
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      assert(CI && "Users of dx_resource_updateCounter must be call instrs");
+
+      // Determine if the use is an increment or decrement
+      Value *CountArg = CI->getArgOperand(1);
+      ConstantInt *CountValue = cast<ConstantInt>(CountArg);
+      int64_t CountLiteral = CountValue->getSExtValue();
+
+      // 0 is an unknown direction and shouldn't result in an insert
+      if (CountLiteral == 0)
+        continue;
+
+      ResourceCounterDirection Direction = ResourceCounterDirection::Decrement;
+      if (CountLiteral > 0)
+        Direction = ResourceCounterDirection::Increment;
+
+      // Collect all potential creation points for the handle arg
+      Value *HandleArg = CI->getArgOperand(0);
+      SmallVector<ResourceInfo *> RBInfos = findByUse(HandleArg);
+      for (ResourceInfo *RBInfo : RBInfos) {
+        if (RBInfo->CounterDirection == ResourceCounterDirection::Unknown ||
+            RBInfo->CounterDirection == Direction)
+          RBInfo->CounterDirection = Direction;
+        else
+          RBInfo->CounterDirection = ResourceCounterDirection::Invalid;
+      }
+    }
+  }
 }
 
 void DXILResourceMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
@@ -793,10 +832,9 @@ void DXILResourceMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
   }
 }
 
-SmallVector<dxil::ResourceInfo>
-DXILResourceMap::findByUse(const Value *Key) const {
+SmallVector<dxil::ResourceInfo *> DXILResourceMap::findByUse(const Value *Key) {
   if (const PHINode *Phi = dyn_cast<PHINode>(Key)) {
-    SmallVector<dxil::ResourceInfo> Children;
+    SmallVector<dxil::ResourceInfo *> Children;
     for (const Value *V : Phi->operands()) {
       Children.append(findByUse(V));
     }
@@ -810,9 +848,10 @@ DXILResourceMap::findByUse(const Value *Key) const {
   switch (CI->getIntrinsicID()) {
   // Found the create, return the binding
   case Intrinsic::dx_resource_handlefrombinding: {
-    const auto *It = find(CI);
+    auto Pos = CallMap.find(CI);
+    ResourceInfo *It = &Infos[Pos->second];
     assert(It != Infos.end() && "HandleFromBinding must be in resource map");
-    return {*It};
+    return {It};
   }
   default:
     break;
@@ -821,7 +860,7 @@ DXILResourceMap::findByUse(const Value *Key) const {
   // Check if any of the parameters are the resource we are following. If so
   // keep searching. If none of them are return an empty list
   const Type *UseType = CI->getType();
-  SmallVector<dxil::ResourceInfo> Children;
+  SmallVector<dxil::ResourceInfo *> Children;
   for (const Value *V : CI->args()) {
     if (V->getType() != UseType)
       continue;
diff --git a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
index 81c8b5530afb6..ea683bb4e5783 100644
--- a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
+++ b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
@@ -69,6 +69,7 @@ define void @test_typedbuffer() {
   %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0(
                   i32 3, i32 5, i32 1, i32 0, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i8 -1)
   ; CHECK: Resource [[UAV1:[0-9]+]]:
   ; CHECK:   Binding:
   ; CHECK:     Record ID: 1
@@ -76,7 +77,7 @@ define void @test_typedbuffer() {
   ; CHECK:     Lower Bound: 5
   ; CHECK:     Size: 1
   ; CHECK:   Globally Coherent: 0
-  ; CHECK:   Counter Direction: Unknown
+  ; CHECK:   Counter Direction: Decrement
   ; CHECK:   Class: UAV
   ; CHECK:   Kind: TypedBuffer
   ; CHECK:   IsROV: 0
@@ -92,6 +93,7 @@ define void @test_typedbuffer() {
   %uav2_2 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0(
                   i32 4, i32 0, i32 10, i32 5, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav2_2, i8 1)
   ; CHECK: Resource [[UAV2:[0-9]+]]:
   ; CHECK:   Binding:
   ; CHECK:     Record ID: 2
@@ -99,7 +101,7 @@ define void @test_typedbuffer() {
   ; CHECK:     Lower Bound: 0
   ; CHECK:     Size: 10
   ; CHECK:   Globally Coherent: 0
-  ; CHECK:   Counter Direction: Unknown
+  ; CHECK:   Counter Direction: Increment
   ; CHECK:   Class: UAV
   ; CHECK:   Kind: TypedBuffer
   ; CHECK:   IsROV: 0
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
index 675a3dc19b912..d1ebfc3b1da41 100644
--- a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -11,6 +11,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
@@ -28,8 +29,9 @@ class UniqueResourceFromUseTest : public testing::Test {
 protected:
   PassBuilder *PB;
   ModuleAnalysisManager *MAM;
-
+  LLVMContext *Context;
   virtual void SetUp() {
+    Context = new LLVMContext();
     MAM = new ModuleAnalysisManager();
     PB = new PassBuilder();
     PB->registerModuleAnalyses(*MAM);
@@ -37,9 +39,17 @@ class UniqueResourceFromUseTest : public testing::Test {
     MAM->registerPass([&] { return DXILResourceAnalysis(); });
   }
 
+  std::unique_ptr<Module> parseAsm(StringRef Asm) {
+    SMDiagnostic Error;
+    std::unique_ptr<Module> M = parseAssemblyString(Asm, Error, *Context);
+    EXPECT_TRUE(M) << "Bad assembly?: " << Error.getMessage();
+    return M;
+  }
+
   virtual void TearDown() {
     delete PB;
     delete MAM;
+    delete Context;
   }
 };
 
@@ -47,22 +57,18 @@ TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
   StringRef Assembly = R"(
 define void @main() {
 entry:
-  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
   call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -77,7 +83,7 @@ declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
       ASSERT_EQ(Bindings.size(), 1u)
           << "Handle should resolve into one resource";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
@@ -94,7 +100,7 @@ declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
   StringRef Assembly = R"(
 define void @foo() {
-  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
   %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
   %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
   %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
@@ -102,17 +108,13 @@ define void @foo() {
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -127,7 +129,7 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 1u)
           << "Handle should resolve into one resource";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
@@ -144,10 +146,10 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
 TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
   StringRef Assembly = R"(
 define void @foo() {
-  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
-  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
-  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
-  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 2, i32 2, i32 2, i32 2, i1 false)
+  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 3, i32 3, i32 3, i32 3, i1 false)
+  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 4, i32 4, i32 4, i32 4, i1 false)
   %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
   %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
   %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
@@ -155,17 +157,13 @@ define void @foo() {
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -180,25 +178,25 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 4u)
           << "Handle should resolve into four resources";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(1u, Binding.LowerBound);
       EXPECT_EQ(1u, Binding.Size);
 
-      Binding = Bindings[1].getBinding();
+      Binding = Bindings[1]->getBinding();
       EXPECT_EQ(1u, Binding.RecordID);
       EXPECT_EQ(2u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
       EXPECT_EQ(2u, Binding.Size);
 
-      Binding = Bindings[2].getBinding();
+      Binding = Bindings[2]->getBinding();
       EXPECT_EQ(2u, Binding.RecordID);
       EXPECT_EQ(3u, Binding.Space);
       EXPECT_EQ(3u, Binding.LowerBound);
       EXPECT_EQ(3u, Binding.Size);
 
-      Binding = Bindings[3].getBinding();
+      Binding = Bindings[3]->getBinding();
       EXPECT_EQ(3u, Binding.RecordID);
       EXPECT_EQ(4u, Binding.Space);
       EXPECT_EQ(4u, Binding.LowerBound);
@@ -216,8 +214,8 @@ TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
   StringRef Assembly = R"(
 define void @foo(i32 %n) {
 entry:
-  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
-  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 4, i32 4, i32 4, i32 4, i1 false)
   %cond = icmp eq i32 %n, 0
   br i1 %cond, label %bb.true, label %bb.false
 
@@ -235,17 +233,13 @@ bb.exit:
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -260,13 +254,13 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 2u)
           << "Handle should resolve into four resources";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(1u, Binding.LowerBound);
       EXPECT_EQ(1u, Binding.Size);
 
-      Binding = Bindings[1].getBinding();
+      Binding = Bindings[1]->getBinding();
       EXPECT_EQ(1u, Binding.RecordID);
       EXPECT_EQ(4u, Binding.Space);
       EXPECT_EQ(4u, Binding.LowerBound);
@@ -280,4 +274,159 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
   }
 }
 
+// Test that several calls to decrement on the same resource don't raise a
+// Diagnositic and resolves to a single decrement entry
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterDecrement) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Decrement);
+    }
+  }
+}
+
+// Test that several calls to increment on the same resource don't raise a
+// Diagnositic and resolves to a single increment entry
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterIncrement) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Increment);
+    }
+  }
+}
+
+// Test that looking up a resource that doesn't have the counter updated
+// resoves to unknown
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterUnknown) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Unknown);
+    }
+  }
+}
+
+// Test that multiple different resources with unique incs/decs aren't
+// marked...
[truncated]

@@ -463,7 +463,7 @@ class DXILResourceMap {
/// ambiguous so multiple creation instructions may be returned. The resulting
/// ResourceInfo can be used to depuplicate unique handles that
/// reference the same resource
SmallVector<dxil::ResourceInfo> findByUse(const Value *Key) const;
SmallVector<dxil::ResourceInfo *> findByUse(const Value *Key);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that setting the value is internal to the analysis, it might make sense to make this an internal/static function? Feels a bit strange to leave the non-const function in the public api.

Thoughts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[HLSL] Analyze and annotate updateCounter direction on a resource
2 participants