Skip to content

Commit 6bf2009

Browse files
authored
Prevent copy-elision pass for CudaDeviceExport (#8941)
The copy-elision optimization pass was incorrectly transforming struct parameters in functions marked with [CudaDeviceExport] from pass-by-value to pass-by-pointer, breaking the externally visible function signatures that clients depend on. This fix adds CudaDeviceExportDecoration to the exclusion list in shouldProcessFunction(), ensuring that exported CUDA device functions preserve their original signatures. Fixes: #8874
1 parent 286bd05 commit 6bf2009

2 files changed

Lines changed: 32 additions & 0 deletions

File tree

source/slang/slang-ir-transform-params-to-constref.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ struct TransformParamsToConstRefContext
222222
if (as<IREntryPointDecoration>(decoration) || as<IRCudaKernelDecoration>(decoration) ||
223223
as<IRAutoPyBindCudaDecoration>(decoration))
224224
return false;
225+
226+
// Skip functions with CudaDeviceExport decoration.
227+
// These functions have externally visible signatures that should not be changed.
228+
if (func->findDecorationImpl(kIROp_CudaDeviceExportDecoration))
229+
return false;
225230
}
226231

227232
// Skip functions with `kIROp_GenericAsm` since
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
2+
3+
// Verify that struct parameters in [CudaDeviceExport] functions are passed by value,
4+
// not transformed to pointers by the copy-elision pass.
5+
// This is a regression test for issue #8874.
6+
7+
struct CommonParameters
8+
{
9+
int id;
10+
float value;
11+
}
12+
13+
struct Parameters
14+
{
15+
float3 position;
16+
float scale;
17+
}
18+
19+
// CUDA: __device__ Parameters_[[#]] processParameters(uint idx_[[#]], CommonParameters_[[#]] commonParams_[[#]])
20+
[CudaDeviceExport]
21+
Parameters processParameters(uint idx, CommonParameters commonParams)
22+
{
23+
Parameters result;
24+
result.position = float3(commonParams.value, commonParams.id, 0.0f);
25+
result.scale = commonParams.value + float(idx);
26+
return result;
27+
}

0 commit comments

Comments
 (0)