Skip to content

Commit f9815bf

Browse files
authored
Scatter 0D index for gather, 0D index and 0D updates for scatter. (#48452)
1 parent a3ae080 commit f9815bf

File tree

12 files changed

+377
-147
lines changed

12 files changed

+377
-147
lines changed

paddle/phi/infermeta/binary.cc

+55-23
Original file line numberDiff line numberDiff line change
@@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
12681268
index_dims[1]));
12691269
} else {
12701270
PADDLE_ENFORCE_EQ(
1271-
index_dims.size(),
1272-
1,
1271+
index_dims.size() == 1 || index_dims.size() == 0,
1272+
true,
12731273
phi::errors::InvalidArgument(
1274-
"The index should be 1D, when it is not 2D, but we get %d",
1274+
"The index should be 0D or 1D, when it is not 2D, but we get %d",
12751275
index_dims.size()));
12761276
}
12771277

12781278
auto input_dim = x.dims();
12791279
auto axis_v = axis.to<int>();
1280-
if (axis.FromTensor() || axis_v == 0) {
1281-
// if axis.FromTensor(), we can not obtain correct shape of output
1282-
int batch_size = index_dims[0];
1283-
phi::DDim output_dims(input_dim);
1284-
output_dims[0] = batch_size;
1285-
out->set_dims(output_dims);
1286-
out->set_dtype(x.dtype());
1287-
out->share_lod(x);
1288-
} else {
1289-
int index_size = index_dims[0];
1290-
std::vector<int> out_dim_vec;
1291-
for (int i = 0; i < axis_v; i++) {
1292-
out_dim_vec.push_back(input_dim[i]);
1280+
if (index_dims.size() == 0) {
1281+
// 0D index will decrease the dimension
1282+
if (input_dim.size() == 1) {
1283+
// the index is a 0D tensor and the x is a 1D tensor
1284+
out->set_dims(phi::DDim(phi::Dim<0>()));
1285+
} else {
1286+
if (axis.FromTensor() || axis_v == 0) {
1287+
// decrease the output dimension
1288+
std::vector<int> out_dim_vec;
1289+
for (int i = 1; i < input_dim.size(); ++i) {
1290+
out_dim_vec.emplace_back(input_dim[i]);
1291+
}
1292+
auto output_dims = phi::make_ddim(out_dim_vec);
1293+
out->set_dims(output_dims);
1294+
out->set_dtype(x.dtype());
1295+
out->share_lod(x);
1296+
} else {
1297+
std::vector<int> out_dim_vec;
1298+
for (int i = 0; i < axis_v; i++) {
1299+
out_dim_vec.push_back(input_dim[i]);
1300+
}
1301+
for (int i = axis_v + 1; i < input_dim.size(); i++) {
1302+
out_dim_vec.push_back(input_dim[i]);
1303+
}
1304+
auto output_dims = phi::make_ddim(out_dim_vec);
1305+
out->set_dims(output_dims);
1306+
out->set_dtype(x.dtype());
1307+
out->share_lod(x);
1308+
}
12931309
}
1294-
out_dim_vec.push_back(index_size);
1295-
for (int i = axis_v + 1; i < input_dim.size(); i++) {
1296-
out_dim_vec.push_back(input_dim[i]);
1310+
} else {
1311+
if (axis.FromTensor() || axis_v == 0) {
1312+
// if axis.FromTensor(), we can not obtain correct shape of output
1313+
int batch_size = index_dims[0];
1314+
phi::DDim output_dims(input_dim);
1315+
output_dims[0] = batch_size;
1316+
out->set_dims(output_dims);
1317+
out->set_dtype(x.dtype());
1318+
out->share_lod(x);
1319+
} else {
1320+
int index_size = index_dims[0];
1321+
std::vector<int> out_dim_vec;
1322+
for (int i = 0; i < axis_v; i++) {
1323+
out_dim_vec.push_back(input_dim[i]);
1324+
}
1325+
out_dim_vec.push_back(index_size);
1326+
for (int i = axis_v + 1; i < input_dim.size(); i++) {
1327+
out_dim_vec.push_back(input_dim[i]);
1328+
}
1329+
auto output_dims = phi::make_ddim(out_dim_vec);
1330+
out->set_dims(output_dims);
1331+
out->set_dtype(x.dtype());
1332+
out->share_lod(x);
12971333
}
1298-
auto output_dims = phi::make_ddim(out_dim_vec);
1299-
out->set_dims(output_dims);
1300-
out->set_dtype(x.dtype());
1301-
out->share_lod(x);
13021334
}
13031335
}
13041336

paddle/phi/infermeta/ternary.cc

+26-23
Original file line numberDiff line numberDiff line change
@@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x,
995995
"index is a 2D tensor, but we get %d.",
996996
index_dims[1]));
997997
} else {
998+
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
999+
true,
1000+
phi::errors::InvalidArgument(
1001+
"The index should be a 0D or 1D tensor when the "
1002+
"index is not a 2D tensor, but we get %d.",
1003+
index_dims.size()));
1004+
}
1005+
if (index_dims.size() != 0) {
9981006
PADDLE_ENFORCE_EQ(
999-
index_dims.size(),
1000-
1,
1001-
phi::errors::InvalidArgument("The index should be a 1D tensor when the "
1002-
"index is not a 2D tensor, but we get %d.",
1003-
index_dims.size()));
1007+
(ref_dims.size() == updates_dims.size()),
1008+
true,
1009+
phi::errors::InvalidArgument(
1010+
"When the Input(Updates) is not a 0D tensor, the "
1011+
"Input(X) and Input(Updates) should have the same shape size, "
1012+
"but received the size of Input(x)'s shape is %d, the size of "
1013+
"Input(Updates)'s shape is %d.",
1014+
ref_dims.size(),
1015+
updates_dims.size()));
1016+
PADDLE_ENFORCE_EQ(
1017+
updates_dims[0],
1018+
index_dims[0],
1019+
phi::errors::InvalidArgument(
1020+
"Input(Updates) and Input(Ids) should have same batch-size, but"
1021+
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
1022+
"batch-size is %d.",
1023+
updates_dims[0],
1024+
index_dims[0]));
10041025
}
1005-
PADDLE_ENFORCE_EQ(
1006-
ref_dims.size(),
1007-
updates_dims.size(),
1008-
phi::errors::InvalidArgument(
1009-
"Input(X) and Input(Updates) should have the same shape size, "
1010-
"but received the size of Input(x)'s shape is %d, the size of "
1011-
"Input(Updates)'s shape is %d.",
1012-
ref_dims.size(),
1013-
updates_dims.size()));
1014-
PADDLE_ENFORCE_EQ(
1015-
updates_dims[0],
1016-
index_dims[0],
1017-
phi::errors::InvalidArgument(
1018-
"Input(Updates) and Input(Ids) should have same batch-size, but"
1019-
" received Input(Updates)'s batch-size is %d, Input(Ids)'s "
1020-
"batch-size is %d.",
1021-
updates_dims[0],
1022-
index_dims[0]));
10231026
out->set_dims(ref_dims);
10241027
out->share_lod(x);
10251028
out->set_dtype(x.dtype());

paddle/phi/kernels/funcs/gather.cu.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx,
9494
}
9595

9696
// index size
97-
int64_t index_size = index.dims()[0];
98-
if (index_size == 0) return;
97+
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
9998

10099
auto src_dims = src.dims();
101-
phi::DDim output_dims(src_dims);
102-
output_dims[0] = index_size;
103100

104101
// slice size
105102
int64_t slice_size = 1;
@@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input,
246243
inner_dim_size *= input_dim[i];
247244
out_dim_vec.push_back(input_dim[i]);
248245
}
249-
out_dim_vec.push_back(index_size);
246+
if (index->dims().size() != 0) {
247+
out_dim_vec.push_back(index_size);
248+
}
250249
for (int i = axis_index + 1; i < input_dim.size(); i++) {
251250
outer_dim_size *= input_dim[i];
252251
out_dim_vec.push_back(input_dim[i]);

paddle/phi/kernels/funcs/gather.h

+18-10
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx,
3838
const DenseTensor& src,
3939
const DenseTensor& index,
4040
DenseTensor* output) {
41-
// check index of shape 1-D
4241
if (index.dims().size() == 2) {
4342
PADDLE_ENFORCE_EQ(
4443
index.dims()[1],
@@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx,
4847
"in gather_op, but received value is [%d].",
4948
index.dims()[1]));
5049
} else {
51-
PADDLE_ENFORCE_EQ(index.dims().size(),
52-
1,
53-
phi::errors::InvalidArgument(
54-
"index.dims().size() should be 1 or 2 in gather_op,"
55-
"but received shape's size is [%d].",
56-
index.dims().size()));
50+
PADDLE_ENFORCE_EQ(
51+
index.dims().size() == 1 || index.dims().size() == 0,
52+
true,
53+
phi::errors::InvalidArgument(
54+
"The index should be 0D or 1D, when it is not 2D, but we get %d",
55+
index.dims().size()));
5756
}
58-
int64_t index_size = index.dims()[0];
57+
58+
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
5959

6060
auto src_dims = src.dims();
6161

@@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx,
188188
inner_dim_size *= input_dim[i];
189189
out_dim_vec.push_back(input_dim[i]);
190190
}
191-
out_dim_vec.push_back(index_size);
191+
if (index->dims().size() != 0) {
192+
out_dim_vec.push_back(index_size);
193+
}
192194
for (int i = axis_index + 1; i < input_dim.size(); i++) {
193195
outer_dim_size *= input_dim[i];
194196
out_dim_vec.push_back(input_dim[i]);
@@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,
224226

225227
if (input->numel() == 0) return;
226228
int axis_index = axis;
227-
int64_t input_index_dim_size = input_dim[axis_index];
229+
int64_t input_index_dim_size;
230+
if (input_dim.size() == out->dims().size()) {
231+
input_index_dim_size = input_dim[axis_index];
232+
} else {
233+
// 0d index
234+
input_index_dim_size = 1;
235+
}
228236

229237
int64_t inner_dim_size = 1;
230238
int64_t outer_dim_size = 1;

paddle/phi/kernels/funcs/scatter.cu.h

+16-10
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
122122
const DenseTensor& index,
123123
DenseTensor* output,
124124
bool overwrite = true) {
125-
// check index of shape 1-D
126125
if (index.dims().size() == 2) {
127126
PADDLE_ENFORCE_EQ(
128127
index.dims()[1],
@@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
132131
"But received value is [%d]",
133132
index.dims()[1]));
134133
} else {
135-
PADDLE_ENFORCE_EQ(index.dims().size(),
136-
1,
137-
phi::errors::InvalidArgument(
138-
"index.dims().size() should be 1 or 2 in scatter_op."
139-
"But received value is [%d]",
140-
index.dims().size()));
134+
PADDLE_ENFORCE_EQ(
135+
index.dims().size() == 1 || index.dims().size() == 0,
136+
true,
137+
phi::errors::InvalidArgument(
138+
"index.dims().size() should be 0, 1 or 2 in scatter_op."
139+
"But received value is [%d]",
140+
index.dims().size()));
141141
}
142-
int64_t index_size = index.dims()[0];
142+
143+
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
143144

144145
auto src_dims = src.dims();
145146
phi::DDim output_dims(src_dims);
146147
output_dims[0] = index_size;
147148

148149
// slice size
149-
int64_t slice_size = 1;
150-
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
150+
size_t slice_size = 1;
151+
if (index.dims().size() != 0) {
152+
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
153+
} else {
154+
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
155+
}
151156

152157
const T* p_src = src.data<T>();
153158
const IndexT* p_index = index.data<IndexT>();
154159
T* p_output = output->data<T>();
160+
155161
const size_t& slice_bytes = slice_size * sizeof(T);
156162

157163
// set block and grid num

0 commit comments

Comments
 (0)