Skip to content

Commit ea84f7b

Browse files
authored
CUDA texture sampler support (#625)
* Add support for specifying sampler settings for CUDA textures and texture views
* Enable sampler tests for CUDA
* Add validation
1 parent 41dc75f commit ea84f7b

File tree

6 files changed

+107
-93
lines changed

6 files changed

+107
-93
lines changed

include/slang-rhi.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@ struct SubresourceLayout
974974
Size rowCount;
975975
};
976976

977+
class ISampler;
978+
977979
static const uint32_t kAllLayers = 0xffffffff;
978980
static const uint32_t kAllMips = 0xffffffff;
979981
static const SubresourceRange kAllSubresources = {0, kAllLayers, 0, kAllMips};
@@ -1008,6 +1010,10 @@ struct TextureDesc
10081010

10091011
const ClearValue* optimalClearValue = nullptr;
10101012

1013+
/// Default sampler settings to use for the texture (CUDA only).
1014+
/// If not specified, tri-linear filtering and wrap addressing mode will be used.
1015+
ISampler* sampler = nullptr;
1016+
10111017
/// The name of the texture for debugging purposes.
10121018
const char* label = nullptr;
10131019

@@ -1027,6 +1033,11 @@ struct TextureViewDesc
10271033
Format format = Format::Undefined;
10281034
TextureAspect aspect = TextureAspect::All;
10291035
SubresourceRange subresourceRange = kEntireTexture;
1036+
1037+
/// Sampler settings to use for the texture view (CUDA only).
1038+
/// If not specified, the default sampler settings from the texture will be used.
1039+
ISampler* sampler = nullptr;
1040+
10301041
const char* label = nullptr;
10311042
};
10321043

src/cuda/cuda-device.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ Result DeviceImpl::initialize(const DeviceDesc& desc)
345345
// Supports surface/swapchain (implemented in Vulkan).
346346
addFeature(Feature::Surface);
347347
#endif
348+
addFeature(Feature::CustomBorderColor);
348349
addFeature(Feature::CombinedTextureSampler);
349350
addFeature(Feature::TimestampQuery);
350351
addFeature(Feature::RealtimeClock);

src/cuda/cuda-texture.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,19 @@ Result DeviceImpl::createTexture(const TextureDesc& desc_, const SubresourceData
270270

271271
auto& samplerSettings = tex->m_defaultSamplerSettings;
272272
samplerSettings = {};
273-
samplerSettings.addressMode[0] = CU_TR_ADDRESS_MODE_WRAP;
274-
samplerSettings.addressMode[1] = CU_TR_ADDRESS_MODE_WRAP;
275-
samplerSettings.addressMode[2] = CU_TR_ADDRESS_MODE_WRAP;
276-
samplerSettings.filterMode = CU_TR_FILTER_MODE_LINEAR;
277-
samplerSettings.maxAnisotropy = 1;
278-
samplerSettings.mipmapFilterMode = CU_TR_FILTER_MODE_LINEAR;
273+
if (desc.sampler)
274+
{
275+
samplerSettings = checked_cast<SamplerImpl*>(desc.sampler)->m_samplerSettings;
276+
}
277+
else
278+
{
279+
samplerSettings.addressMode[0] = CU_TR_ADDRESS_MODE_WRAP;
280+
samplerSettings.addressMode[1] = CU_TR_ADDRESS_MODE_WRAP;
281+
samplerSettings.addressMode[2] = CU_TR_ADDRESS_MODE_WRAP;
282+
samplerSettings.filterMode = CU_TR_FILTER_MODE_LINEAR;
283+
samplerSettings.maxAnisotropy = 1;
284+
samplerSettings.mipmapFilterMode = CU_TR_FILTER_MODE_LINEAR;
285+
}
279286

280287
// The size of the element/texel in bytes
281288
const FormatInfo& formatInfo = getFormatInfo(desc.format);
@@ -647,6 +654,8 @@ Result DeviceImpl::createTextureView(ITexture* texture, const TextureViewDesc& d
647654
if (view->m_desc.format == Format::Undefined)
648655
view->m_desc.format = view->m_texture->m_desc.format;
649656
view->m_desc.subresourceRange = view->m_texture->resolveSubresourceRange(desc.subresourceRange);
657+
view->m_samplerSettings = desc.sampler ? checked_cast<SamplerImpl*>(desc.sampler)->m_samplerSettings
658+
: view->m_texture->m_defaultSamplerSettings;
650659
returnComPtr(outView, view);
651660
return SLANG_OK;
652661
}

src/cuda/cuda-texture.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ class TextureViewImpl : public TextureView
104104
CUtexObject getTexObject()
105105
{
106106
if (!m_cudaTexObj)
107-
m_cudaTexObj =
108-
m_texture->getTexObject(m_desc.format, m_texture->m_defaultSamplerSettings, m_desc.subresourceRange);
107+
m_cudaTexObj = m_texture->getTexObject(m_desc.format, m_samplerSettings, m_desc.subresourceRange);
109108
return m_cudaTexObj;
110109
}
111110

@@ -122,6 +121,7 @@ class TextureViewImpl : public TextureView
122121
}
123122

124123
BreakableReference<TextureImpl> m_texture;
124+
SamplerSettings m_samplerSettings;
125125
CUtexObject m_cudaTexObj = 0;
126126
CUsurfObject m_cudaSurfObj = 0;
127127
};

src/debug-layer/debug-device.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ Result DebugDevice::createTexture(const TextureDesc& desc, const SubresourceData
9999
{
100100
SLANG_RHI_API_FUNC;
101101

102-
DeviceType deviceType = getDeviceType();
103-
104102
if (uint32_t(desc.type) > uint32_t(TextureType::TextureCubeArray))
105103
{
106104
RHI_VALIDATION_ERROR("Invalid texture type");
@@ -165,12 +163,12 @@ Result DebugDevice::createTexture(const TextureDesc& desc, const SubresourceData
165163
RHI_VALIDATION_ERROR("Texture with multisample type cannot have mip levels");
166164
return SLANG_E_INVALID_ARG;
167165
}
168-
if (deviceType == DeviceType::WGPU && desc.sampleCount != 4)
166+
if (ctx->deviceType == DeviceType::WGPU && desc.sampleCount != 4)
169167
{
170168
RHI_VALIDATION_ERROR("WebGPU only supports sample count of 4");
171169
return SLANG_E_INVALID_ARG;
172170
}
173-
if (deviceType == DeviceType::WGPU && desc.arrayLength != 1)
171+
if (ctx->deviceType == DeviceType::WGPU && desc.arrayLength != 1)
174172
{
175173
RHI_VALIDATION_ERROR("WebGPU doesn't support multisampled texture arrays");
176174
return SLANG_E_INVALID_ARG;
@@ -220,6 +218,12 @@ Result DebugDevice::createTexture(const TextureDesc& desc, const SubresourceData
220218
break;
221219
}
222220

221+
if (desc.sampler != nullptr && ctx->deviceType != DeviceType::CUDA)
222+
{
223+
RHI_VALIDATION_ERROR("Setting default sampler for texture is only supported on CUDA device");
224+
return SLANG_E_INVALID_ARG;
225+
}
226+
223227
TextureDesc patchedDesc = desc;
224228
std::string label;
225229
if (!patchedDesc.label)
@@ -421,6 +425,12 @@ Result DebugDevice::createTextureView(ITexture* texture, const TextureViewDesc&
421425
{
422426
SLANG_RHI_API_FUNC;
423427

428+
if (desc.sampler != nullptr && ctx->deviceType != DeviceType::CUDA)
429+
{
430+
RHI_VALIDATION_ERROR("Setting sampler for texture view is only supported on CUDA device"); // Names the texture view: this check guards TextureViewDesc::sampler, not the texture-level default sampler.
431+
return SLANG_E_INVALID_ARG;
432+
}
433+
424434
TextureViewDesc patchedDesc = desc;
425435
std::string label;
426436
if (!patchedDesc.label)

tests/test-sampler.cpp

Lines changed: 64 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using namespace rhi;
55
using namespace rhi::testing;
66

7-
static Result createTestTexture(IDevice* device, ITexture** outTexture)
7+
static Result createTestTexture(IDevice* device, ISampler* sampler, ITexture** outTexture)
88
{
99
ComPtr<ITexture> texture;
1010
TextureDesc desc = {};
@@ -14,6 +14,7 @@ static Result createTestTexture(IDevice* device, ITexture** outTexture)
1414
desc.mipCount = 2;
1515
desc.memoryType = MemoryType::DeviceLocal;
1616
desc.usage = TextureUsage::ShaderResource | TextureUsage::CopyDestination | TextureUsage::CopySource;
17+
desc.sampler = sampler;
1718

1819
// mip 0
1920
// ---------------------
@@ -62,95 +63,77 @@ struct TestRecord
6263
float expectedColor[4];
6364
};
6465

65-
struct SamplerTest
66+
static void testSampler(IDevice* device, const SamplerDesc& samplerDesc, span<TestRecord> testRecords)
6667
{
67-
static constexpr size_t kMaxRecords = 32;
68+
ComPtr<ISampler> sampler;
69+
REQUIRE_CALL(device->createSampler(samplerDesc, sampler.writeRef()));
6870

69-
ComPtr<IDevice> device;
7071
ComPtr<ITexture> texture;
72+
// On CUDA, we need to have a sampler associated with the texture for sampling to work.
73+
REQUIRE_CALL(
74+
createTestTexture(device, device->getDeviceType() == DeviceType::CUDA ? sampler : nullptr, texture.writeRef())
75+
);
76+
77+
ComPtr<IShaderProgram> shaderProgram;
78+
REQUIRE_CALL(loadProgram(device, "test-sampler", "sampleTexture", shaderProgram.writeRef()));
79+
80+
ComputePipelineDesc pipelineDesc = {};
81+
pipelineDesc.program = shaderProgram.get();
82+
ComPtr<IComputePipeline> pipeline;
83+
REQUIRE_CALL(device->createComputePipeline(pipelineDesc, pipeline.writeRef()));
84+
85+
BufferDesc bufferDesc = {};
86+
bufferDesc.size = testRecords.size() * sizeof(TestInput);
87+
bufferDesc.elementSize = sizeof(TestInput);
88+
bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::CopyDestination;
7189
ComPtr<IBuffer> inputBuffer;
90+
REQUIRE_CALL(device->createBuffer(bufferDesc, nullptr, inputBuffer.writeRef()));
91+
92+
bufferDesc.size = testRecords.size() * sizeof(TestOutput);
93+
bufferDesc.elementSize = sizeof(TestOutput);
94+
bufferDesc.usage = BufferUsage::UnorderedAccess | BufferUsage::CopySource;
7295
ComPtr<IBuffer> resultBuffer;
73-
ComPtr<IComputePipeline> pipeline;
96+
REQUIRE_CALL(device->createBuffer(bufferDesc, nullptr, resultBuffer.writeRef()));
7497

75-
void init(IDevice* device_)
98+
std::vector<TestInput> inputData;
99+
for (const auto& record : testRecords)
76100
{
77-
this->device = device_;
78-
REQUIRE_CALL(createTestTexture(device, texture.writeRef()));
79-
80-
ComPtr<IShaderProgram> shaderProgram;
81-
REQUIRE_CALL(loadProgram(device, "test-sampler", "sampleTexture", shaderProgram.writeRef()));
82-
83-
ComputePipelineDesc pipelineDesc = {};
84-
pipelineDesc.program = shaderProgram.get();
85-
REQUIRE_CALL(device->createComputePipeline(pipelineDesc, pipeline.writeRef()));
86-
87-
BufferDesc bufferDesc = {};
88-
bufferDesc.size = kMaxRecords * sizeof(TestInput);
89-
bufferDesc.elementSize = sizeof(TestInput);
90-
bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::CopyDestination;
91-
REQUIRE_CALL(device->createBuffer(bufferDesc, nullptr, inputBuffer.writeRef()));
92-
93-
bufferDesc.size = kMaxRecords * sizeof(TestOutput);
94-
bufferDesc.elementSize = sizeof(TestOutput);
95-
bufferDesc.usage = BufferUsage::UnorderedAccess | BufferUsage::CopySource;
96-
REQUIRE_CALL(device->createBuffer(bufferDesc, nullptr, resultBuffer.writeRef()));
101+
inputData.push_back({record.u, record.v, record.level, 0.f});
97102
}
98-
99-
void check(ISampler* sampler, span<TestRecord> testRecords)
103+
ComPtr<ICommandQueue> queue = device->getQueue(QueueType::Graphics);
104+
ComPtr<ICommandEncoder> encoder = queue->createCommandEncoder();
105+
encoder->uploadBufferData(inputBuffer, 0, inputData.size() * sizeof(TestInput), inputData.data());
106+
IComputePassEncoder* passEncoder = encoder->beginComputePass();
107+
IShaderObject* rootObject = passEncoder->bindPipeline(pipeline);
108+
ShaderCursor cursor(rootObject);
109+
cursor["texture"].setBinding(texture);
110+
cursor["sampler"].setBinding(sampler);
111+
cursor["inputs"].setBinding(inputBuffer);
112+
cursor["results"].setBinding(resultBuffer);
113+
cursor["count"].setData((uint32_t)testRecords.size());
114+
passEncoder->dispatchCompute((uint32_t)testRecords.size(), 1, 1);
115+
passEncoder->end();
116+
queue->submit(encoder->finish());
117+
queue->waitOnHost();
118+
119+
ComPtr<ISlangBlob> resultData;
120+
REQUIRE_CALL(device->readBuffer(resultBuffer, 0, testRecords.size() * sizeof(TestOutput), resultData.writeRef()));
121+
const TestOutput* output = (const TestOutput*)resultData->getBufferPointer();
122+
for (size_t i = 0; i < testRecords.size(); i++)
100123
{
101-
REQUIRE(testRecords.size() <= kMaxRecords);
102-
std::vector<TestInput> inputData;
103-
for (const auto& record : testRecords)
124+
const TestRecord& record = testRecords[i];
125+
CAPTURE(record.u);
126+
CAPTURE(record.v);
127+
CAPTURE(record.level);
128+
for (size_t j = 0; j < 4; j++)
104129
{
105-
inputData.push_back({record.u, record.v, record.level, 0.f});
106-
}
107-
ComPtr<ICommandQueue> queue = device->getQueue(QueueType::Graphics);
108-
ComPtr<ICommandEncoder> encoder = queue->createCommandEncoder();
109-
encoder->uploadBufferData(inputBuffer, 0, inputData.size() * sizeof(TestInput), inputData.data());
110-
IComputePassEncoder* passEncoder = encoder->beginComputePass();
111-
IShaderObject* rootObject = passEncoder->bindPipeline(pipeline);
112-
ShaderCursor cursor(rootObject);
113-
cursor["texture"].setBinding(texture);
114-
cursor["sampler"].setBinding(sampler);
115-
cursor["inputs"].setBinding(inputBuffer);
116-
cursor["results"].setBinding(resultBuffer);
117-
cursor["count"].setData((uint32_t)testRecords.size());
118-
passEncoder->dispatchCompute((uint32_t)testRecords.size(), 1, 1);
119-
passEncoder->end();
120-
queue->submit(encoder->finish());
121-
queue->waitOnHost();
122-
123-
ComPtr<ISlangBlob> resultData;
124-
REQUIRE_CALL(
125-
device->readBuffer(resultBuffer, 0, testRecords.size() * sizeof(TestOutput), resultData.writeRef())
126-
);
127-
const TestOutput* output = (const TestOutput*)resultData->getBufferPointer();
128-
for (size_t i = 0; i < testRecords.size(); i++)
129-
{
130-
const TestRecord& record = testRecords[i];
131-
CAPTURE(record.u);
132-
CAPTURE(record.v);
133-
CAPTURE(record.level);
134-
for (size_t j = 0; j < 4; j++)
135-
{
136-
CAPTURE(j);
137-
REQUIRE_EQ(output[i].color[j], testRecords[i].expectedColor[j]);
138-
}
130+
CAPTURE(j);
131+
REQUIRE_EQ(output[i].color[j], testRecords[i].expectedColor[j]);
139132
}
140133
}
141-
};
142-
143-
static void testSampler(IDevice* device, const SamplerDesc& samplerDesc, span<TestRecord> testRecords)
144-
{
145-
ComPtr<ISampler> sampler;
146-
REQUIRE_CALL(device->createSampler(samplerDesc, sampler.writeRef()));
147-
148-
SamplerTest test;
149-
test.init(device);
150-
test.check(sampler, testRecords);
151134
}
152135

153-
GPU_TEST_CASE("sampler-filter-point", D3D11 | D3D12 | Vulkan | Metal | WGPU)
136+
GPU_TEST_CASE("sampler-filter-point", D3D11 | D3D12 | Vulkan | Metal | WGPU | CUDA)
154137
{
155138
SamplerDesc desc = {};
156139
desc.minFilter = TextureFilteringMode::Point;
@@ -178,7 +161,7 @@ GPU_TEST_CASE("sampler-filter-point", D3D11 | D3D12 | Vulkan | Metal | WGPU)
178161
testSampler(device, desc, testRecords);
179162
}
180163

181-
GPU_TEST_CASE("sampler-filter-linear", D3D11 | D3D12 | Vulkan | Metal | WGPU)
164+
GPU_TEST_CASE("sampler-filter-linear", D3D11 | D3D12 | Vulkan | Metal | WGPU | CUDA)
182165
{
183166
SamplerDesc desc = {};
184167
desc.minFilter = TextureFilteringMode::Linear;
@@ -208,7 +191,7 @@ GPU_TEST_CASE("sampler-filter-linear", D3D11 | D3D12 | Vulkan | Metal | WGPU)
208191
testSampler(device, desc, testRecords);
209192
}
210193

211-
GPU_TEST_CASE("sampler-border-black-transparent", D3D11 | D3D12 | Vulkan | Metal)
194+
GPU_TEST_CASE("sampler-border-black-transparent", D3D11 | D3D12 | Vulkan | Metal | CUDA)
212195
{
213196
SamplerDesc desc = {};
214197
desc.addressU = TextureAddressingMode::ClampToBorder;
@@ -223,7 +206,7 @@ GPU_TEST_CASE("sampler-border-black-transparent", D3D11 | D3D12 | Vulkan | Metal
223206
testSampler(device, desc, testRecords);
224207
}
225208

226-
GPU_TEST_CASE("sampler-border-black-opaque", D3D11 | D3D12 | Vulkan | Metal)
209+
GPU_TEST_CASE("sampler-border-black-opaque", D3D11 | D3D12 | Vulkan | Metal | CUDA)
227210
{
228211
SamplerDesc desc = {};
229212
desc.addressU = TextureAddressingMode::ClampToBorder;
@@ -243,7 +226,7 @@ GPU_TEST_CASE("sampler-border-black-opaque", D3D11 | D3D12 | Vulkan | Metal)
243226
testSampler(device, desc, testRecords);
244227
}
245228

246-
GPU_TEST_CASE("sampler-border-white-opaque", D3D11 | D3D12 | Vulkan | Metal)
229+
GPU_TEST_CASE("sampler-border-white-opaque", D3D11 | D3D12 | Vulkan | Metal | CUDA)
247230
{
248231
SamplerDesc desc = {};
249232
desc.addressU = TextureAddressingMode::ClampToBorder;
@@ -262,7 +245,7 @@ GPU_TEST_CASE("sampler-border-white-opaque", D3D11 | D3D12 | Vulkan | Metal)
262245
testSampler(device, desc, testRecords);
263246
}
264247

265-
GPU_TEST_CASE("sampler-border-custom-color", D3D11 | D3D12 | Vulkan | Metal)
248+
GPU_TEST_CASE("sampler-border-custom-color", D3D11 | D3D12 | Vulkan | Metal | CUDA)
266249
{
267250
if (!device->hasFeature(Feature::CustomBorderColor))
268251
SKIP("Custom border color not supported");

0 commit comments

Comments
 (0)