@@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x,
1268
1268
index_dims[1 ]));
1269
1269
} else {
1270
1270
PADDLE_ENFORCE_EQ (
1271
- index_dims.size (),
1272
- 1 ,
1271
+ index_dims.size () == 1 || index_dims. size () == 0 ,
1272
+ true ,
1273
1273
phi::errors::InvalidArgument (
1274
- " The index should be 1D, when it is not 2D, but we get %d" ,
1274
+ " The index should be 0D or 1D, when it is not 2D, but we get %d" ,
1275
1275
index_dims.size ()));
1276
1276
}
1277
1277
1278
1278
auto input_dim = x.dims ();
1279
1279
auto axis_v = axis.to <int >();
1280
- if (axis.FromTensor () || axis_v == 0 ) {
1281
- // if axis.FromTensor(), we can not obtain correct shape of output
1282
- int batch_size = index_dims[0 ];
1283
- phi::DDim output_dims (input_dim);
1284
- output_dims[0 ] = batch_size;
1285
- out->set_dims (output_dims);
1286
- out->set_dtype (x.dtype ());
1287
- out->share_lod (x);
1288
- } else {
1289
- int index_size = index_dims[0 ];
1290
- std::vector<int > out_dim_vec;
1291
- for (int i = 0 ; i < axis_v; i++) {
1292
- out_dim_vec.push_back (input_dim[i]);
1280
+ if (index_dims.size () == 0 ) {
1281
+ // 0D index will decrease the dimension
1282
+ if (input_dim.size () == 1 ) {
1283
+ // the index is a 0D tensor and the x is a 1D tensor
1284
+ out->set_dims (phi::DDim (phi::Dim<0 >()));
1285
+ } else {
1286
+ if (axis.FromTensor () || axis_v == 0 ) {
1287
+ // decrease the output dimension
1288
+ std::vector<int > out_dim_vec;
1289
+ for (int i = 1 ; i < input_dim.size (); ++i) {
1290
+ out_dim_vec.emplace_back (input_dim[i]);
1291
+ }
1292
+ auto output_dims = phi::make_ddim (out_dim_vec);
1293
+ out->set_dims (output_dims);
1294
+ out->set_dtype (x.dtype ());
1295
+ out->share_lod (x);
1296
+ } else {
1297
+ std::vector<int > out_dim_vec;
1298
+ for (int i = 0 ; i < axis_v; i++) {
1299
+ out_dim_vec.push_back (input_dim[i]);
1300
+ }
1301
+ for (int i = axis_v + 1 ; i < input_dim.size (); i++) {
1302
+ out_dim_vec.push_back (input_dim[i]);
1303
+ }
1304
+ auto output_dims = phi::make_ddim (out_dim_vec);
1305
+ out->set_dims (output_dims);
1306
+ out->set_dtype (x.dtype ());
1307
+ out->share_lod (x);
1308
+ }
1293
1309
}
1294
- out_dim_vec.push_back (index_size);
1295
- for (int i = axis_v + 1 ; i < input_dim.size (); i++) {
1296
- out_dim_vec.push_back (input_dim[i]);
1310
+ } else {
1311
+ if (axis.FromTensor () || axis_v == 0 ) {
1312
+ // if axis.FromTensor(), we can not obtain correct shape of output
1313
+ int batch_size = index_dims[0 ];
1314
+ phi::DDim output_dims (input_dim);
1315
+ output_dims[0 ] = batch_size;
1316
+ out->set_dims (output_dims);
1317
+ out->set_dtype (x.dtype ());
1318
+ out->share_lod (x);
1319
+ } else {
1320
+ int index_size = index_dims[0 ];
1321
+ std::vector<int > out_dim_vec;
1322
+ for (int i = 0 ; i < axis_v; i++) {
1323
+ out_dim_vec.push_back (input_dim[i]);
1324
+ }
1325
+ out_dim_vec.push_back (index_size);
1326
+ for (int i = axis_v + 1 ; i < input_dim.size (); i++) {
1327
+ out_dim_vec.push_back (input_dim[i]);
1328
+ }
1329
+ auto output_dims = phi::make_ddim (out_dim_vec);
1330
+ out->set_dims (output_dims);
1331
+ out->set_dtype (x.dtype ());
1332
+ out->share_lod (x);
1297
1333
}
1298
- auto output_dims = phi::make_ddim (out_dim_vec);
1299
- out->set_dims (output_dims);
1300
- out->set_dtype (x.dtype ());
1301
- out->share_lod (x);
1302
1334
}
1303
1335
}
1304
1336
0 commit comments