Skip to content

Commit 248d176

Browse files
committed
Fix out-of-bounds access in GraphSendUERecvCUDAKernel
This fix adds boundary checks to prevent out-of-bounds memory access when src/dst indices exceed the valid node range or when broadcast offsets exceed the feature dimensions.

Root cause:
- When src_indices contain values >= num_nodes, the kernel accesses memory beyond the allocated buffer.
- When broadcast offsets exceed x_len/e_len, out-of-bounds access occurs.

Fix:
- Add a num_nodes parameter to the kernel for boundary validation.
- Check that src/dst indices are within [0, num_nodes) before access.
- Check that x_add < x_len and e_add < e_len for broadcast offsets.
1 parent 7a22b38 commit 248d176

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ __global__ void GraphSendUERecvCUDAKernel(const T* x_data,
135135
int64_t x_len,
136136
int64_t e_len,
137137
int64_t out_len,
138+
int64_t num_nodes,
138139
bool use_bcast,
139140
ComputeFunctor cfunctor,
140141
ReduceFunctor rfunctor) {
@@ -147,15 +148,22 @@ __global__ void GraphSendUERecvCUDAKernel(const T* x_data,
147148
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
148149
int64_t stride_x = blockDim.x * static_cast<int64_t>(gridDim.x);
149150

150-
const T* x_off = x_data + src * x_len;
151-
const T* e_off = e_data + ty * e_len;
152-
T* out_off = output + dst * out_len;
153-
while (tx < out_len) {
154-
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
155-
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
156-
T val = cfunctor(x_off[x_add], e_off[e_add]);
157-
rfunctor(out_off + tx, val);
158-
tx += stride_x;
151+
// Add boundary check for src/dst indices to prevent out-of-bounds access
152+
// src and dst must be within valid range: src < num_nodes, dst < num_nodes
153+
if (src >= 0 && src < num_nodes && dst >= 0 && dst < num_nodes) {
154+
const T* x_off = x_data + src * x_len;
155+
const T* e_off = e_data + ty * e_len;
156+
T* out_off = output + dst * out_len;
157+
while (tx < out_len) {
158+
int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
159+
int64_t e_add = use_bcast ? ebcast_off[tx] : tx;
160+
// Add boundary check to prevent out-of-bounds access for bcast offsets
161+
if (x_add < x_len && e_add < e_len) {
162+
T val = cfunctor(x_off[x_add], e_off[e_add]);
163+
rfunctor(out_off + tx, val);
164+
}
165+
tx += stride_x;
166+
}
159167
}
160168
ty += stride_y;
161169
}

paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ void CalculateXGrad(const Context& dev_ctx,
185185
const dim3 block_(ntx, nty);
186186
funcs::MultiplyFunctor<T> mul_functor;
187187
GraphSendUERecvSumCUDAFunctor<T> sum_functor;
188+
int64_t num_nodes = x_dims[0];
188189
if (!reduce) {
189190
GraphSendUERecvCUDAKernel<T,
190191
IndexT,
@@ -202,6 +203,7 @@ void CalculateXGrad(const Context& dev_ctx,
202203
bcast_info.l_len,
203204
bcast_info.r_len,
204205
out_len,
206+
num_nodes,
205207
bcast_info.use_bcast,
206208
mul_functor,
207209
sum_functor);
@@ -225,6 +227,7 @@ void CalculateXGrad(const Context& dev_ctx,
225227
bcast_info.l_len,
226228
bcast_info.r_len,
227229
out_len,
230+
num_nodes,
228231
bcast_info.use_bcast,
229232
mul_functor,
230233
sum_functor);

paddle/phi/kernels/gpu/send_ue_recv_kernel.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
9595
const dim3 grid(nbx, nby);
9696
const dim3 block(ntx, nty);
9797
int64_t input_size = x.dims()[0];
98+
int64_t num_nodes = x.dims()[0];
9899
int block_ = 1024;
99100
if (reduce_op == "SUM" || reduce_op == "MEAN") {
100101
GraphSendUERecvSumCUDAFunctor<T> sum_functor;
@@ -116,6 +117,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
116117
bcast_info.l_len,
117118
bcast_info.r_len,
118119
out_len,
120+
num_nodes,
119121
bcast_info.use_bcast,
120122
add_funtor,
121123
sum_functor);
@@ -137,6 +139,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
137139
bcast_info.l_len,
138140
bcast_info.r_len,
139141
out_len,
142+
num_nodes,
140143
bcast_info.use_bcast,
141144
mul_functor,
142145
sum_functor);
@@ -184,6 +187,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
184187
bcast_info.l_len,
185188
bcast_info.r_len,
186189
out_len,
190+
num_nodes,
187191
bcast_info.use_bcast,
188192
add_funtor,
189193
max_functor);
@@ -205,6 +209,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
205209
bcast_info.l_len,
206210
bcast_info.r_len,
207211
out_len,
212+
num_nodes,
208213
bcast_info.use_bcast,
209214
mul_functor,
210215
max_functor);
@@ -237,6 +242,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
237242
bcast_info.l_len,
238243
bcast_info.r_len,
239244
out_len,
245+
num_nodes,
240246
bcast_info.use_bcast,
241247
add_funtor,
242248
min_functor);
@@ -258,6 +264,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& dev_ctx,
258264
bcast_info.l_len,
259265
bcast_info.r_len,
260266
out_len,
267+
num_nodes,
261268
bcast_info.use_bcast,
262269
mul_functor,
263270
min_functor);

paddle/phi/kernels/gpu/send_uv_grad_kernel.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ void CalculateGrad(const Context& dev_ctx,
145145
funcs::MultiplyFunctor<T> mul_functor;
146146
GraphSendUERecvSumCUDAFunctor<T> sum_functor;
147147
const T* y_data = y.data<T>();
148+
int64_t num_nodes = x_grad_dims[0];
148149
if (!reduce) {
149150
GraphSendUERecvCUDAKernel<T,
150151
IndexT,
@@ -162,6 +163,7 @@ void CalculateGrad(const Context& dev_ctx,
162163
bcast_info.l_len,
163164
bcast_info.r_len,
164165
out_len,
166+
num_nodes,
165167
bcast_info.use_bcast,
166168
mul_functor,
167169
sum_functor);
@@ -189,6 +191,7 @@ void CalculateGrad(const Context& dev_ctx,
189191
bcast_info.l_len,
190192
bcast_info.r_len,
191193
out_len,
194+
num_nodes,
192195
bcast_info.use_bcast,
193196
mul_functor,
194197
sum_functor);

0 commit comments

Comments (0)