Skip to content

[GPU] scatter_nd_update_kernel_ref optimization #30637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ void scatterNdUpdate(
const ov::op::v15::ScatterNDUpdate::Reduction reduction_type = ov::op::v15::ScatterNDUpdate::Reduction::NONE) {
const auto update_chunk_shape = span(dataShape).drop_front(indicesShape.back());
const auto update_el_number = shape_size(update_chunk_shape);
// if (update_chunk_shape.size() == 1) {
// printf("update_chunk_shape: %ld\n", update_chunk_shape[0]);
// } else if (update_chunk_shape.size() == 2) {
// printf("update_chunk_shape: %ld, %ld\n", update_chunk_shape[0], update_chunk_shape[1]);
// } else if (update_chunk_shape.size() == 3) {
// printf("update_chunk_shape: %ld, %ld, %ld\n", update_chunk_shape[0], update_chunk_shape[1], update_chunk_shape[2]);
// } else if (update_chunk_shape.size() == 4) {
// printf("update_chunk_shape: %ld, %ld, %ld, %ld\n", update_chunk_shape[0], update_chunk_shape[1], update_chunk_shape[2], update_chunk_shape[3]);
// }

std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));

Expand All @@ -89,6 +98,7 @@ void scatterNdUpdate(
const auto reduction = scatter_nd_update::reduction_functor_for<dataType>(reduction_type);
std::vector<indicesType> indicesCopy(indices, indices + shape_size(indicesShape));
const auto num_of_updates = shape_size(span(indicesShape).drop_back(1));
std::cout << "num_of_updates: " << num_of_updates << ", update_el_number: " << update_el_number << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove debug prints.

for (size_t i = 0; i != num_of_updates; ++i) {
const auto indices_coord = indicesCopy.data() + i * indicesShape.back();
const auto coord = span(indices_coord, indicesShape.back());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,64 @@

#include "include/batch_headers/fetch_data.cl"

#define GET_UPDATES_INDEX(prefix, idx_order) CAT(prefix, _GET_INDEX)(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
#define GET_INPUT_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_UPDATES_INDEX(idx_order) INPUT2_GET_INDEX(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)

#if OUTPUT_DIMS == 4
#define ORDER b,f,y,x
#define TARGET_COORD_ORDER target_coord[0],target_coord[1],target_coord[2],target_coord[3]
#elif OUTPUT_DIMS == 5
#define ORDER b,f,z,y,x
#define TARGET_COORD_ORDER target_coord[0],target_coord[1],target_coord[2],target_coord[3],target_coord[4]
#elif OUTPUT_DIMS == 6
#define ORDER b,f,w,z,y,x
#define TARGET_COORD_ORDER target_coord[0],target_coord[1],target_coord[2],target_coord[3],target_coord[4],target_coord[5]
#endif

#if INPUT2_DIMS == 4
#define UPD_ORDER upd_b,upd_f,upd_y,upd_x
#define INPUT2_ORDER b,f,y,x
#elif INPUT2_DIMS == 5
#define UPD_ORDER upd_b,upd_f,upd_z,upd_y,upd_x
#define INPUT2_ORDER b,f,z,y,x
#elif INPUT2_DIMS == 6
#define UPD_ORDER upd_b,upd_f,upd_w,upd_z,upd_y,upd_x
#define INPUT2_ORDER b,f,w,z,y,x
#endif

#if INPUT1_DIMS == 4
#define IDX_ORDER idx_b,idx_f,idx_y,idx_x
#elif INPUT1_DIMS == 5
#define IDX_ORDER idx_b,idx_f,idx_z,idx_y,idx_x
#elif INPUT1_DIMS == 6
#define IDX_ORDER idx_b,idx_f,idx_w,idx_z,idx_y,idx_x
#define INDICES_MAX_DIM 6


#if INDICES_RANK == 1
#define IND_ORDER i,0,0,0
#elif INDICES_RANK == 2
#if INPUT1_DIMS == 4
#define IND_ORDER b,i,0,0
#elif INPUT1_DIMS == 5
#define IND_ORDER b,i,0,0,0
#elif INPUT1_DIMS == 6
#define IND_ORDER b,i,0,0,0,0
#endif
#elif INDICES_RANK == 3
#define IND_ORDER b,f,0,i
#elif INDICES_RANK == 4
#if INPUT1_DIMS == 4
#if INPUT2_DIMS == 4
#define IND_ORDER b,f,y,i
#elif INPUT2_DIMS == 5
#define IND_ORDER b,f,z,i
#elif INPUT2_DIMS == 6
#define IND_ORDER b,f,w,i
#endif
#elif INPUT1_DIMS == 5
#define IND_ORDER b,f,y,i,0
#endif
// #elif INDICES_RANK == 5
// target_coord[i] = indices[INPUT1_GET_INDEX(b, f, z, y, i)];
// #elif INDICES_RANK == 6
// target_coord[i] = indices[INPUT1_GET_INDEX(b, f, w, z, y, i)];
#endif

#define INDICES_MAX_DIM 6


KERNEL(scatter_nd_update_ref)(OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* data,
Expand All @@ -50,15 +80,15 @@ KERNEL(scatter_nd_update_ref)(OPTIONAL_SHAPE_INFO_ARG
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);

#ifndef IS_SECOND_ITER // First kernel
#ifdef IS_FIRST_ITER
const uint x = dim0 % OUTPUT_SIZE_X;
const uint y = dim0 / OUTPUT_SIZE_X;
const uint z = dim1 % OUTPUT_SIZE_Z;
const uint w = dim1 / OUTPUT_SIZE_Z;
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;

const uint input_idx = GET_UPDATES_INDEX(INPUT0, ORDER);
const uint input_idx = GET_INPUT_INDEX(ORDER);
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
INPUT0_TYPE val = data[input_idx];
#if HAS_FUSED_OPS
Expand All @@ -68,149 +98,68 @@ KERNEL(scatter_nd_update_ref)(OPTIONAL_SHAPE_INFO_ARG
output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif

#else // Second kernel

const uint dataND[] = {INPUT0_BLOCK_ND};
const uint updatesND[] = {INPUT2_BLOCK_ND};
const uint indicesND[] = {INPUT1_BLOCK_ND};
const uint size_to_update = dataND[INDICES_LAST_DIM];
#else // IS_SECOND_ITER

#if INPUT1_DIMS == 4
const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 5
const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#elif INPUT1_DIMS == 6
const uint indices_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif
#if INPUT2_DIMS == 4
const uint x = dim0;
const uint y = dim1;
const uint f = dim2 % INPUT2_FEATURE_NUM;
const uint b = dim2 / INPUT2_FEATURE_NUM;
#elif INPUT2_DIMS == 5
const uint x = dim0;
const uint y = dim1 % INPUT2_SIZE_Y;
const uint z = dim1 / INPUT2_SIZE_Y;
const uint f = dim2 % INPUT2_FEATURE_NUM;
const uint b = dim2 / INPUT2_FEATURE_NUM;
#elif INPUT2_DIMS == 6
const uint x = dim0 % INPUT2_SIZE_X;
const uint y = dim0 / INPUT2_SIZE_X;
const uint z = dim1 % INPUT2_SIZE_Z;
const uint w = dim1 / INPUT2_SIZE_Z;
const uint f = dim2 % INPUT2_FEATURE_NUM;
const uint b = dim2 / INPUT2_FEATURE_NUM;
#endif

#if INPUT0_DIMS == 4
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 5
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 6
const uint data_dim[INPUT0_DIMS] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#endif
INPUT1_TYPE target_coord[INDICES_MAX_DIM];
INPUT1_TYPE g_coord[INDICES_MAX_DIM] = { INPUT2_ORDER };

// Get indices index
uint idx[INDICES_MAX_DIM] = {0};
uint rmd_idx = dim2;
for (int i = 0; i < INDICES_RANK - 1; ++i) {
idx[i] = rmd_idx / indicesND[1 + i];
rmd_idx %= indicesND[1 + i];
#if INPUT1_LENGTH == 1 && INDICES_RANK == 1
for (uint i = 0; i < OUTPUT_DIMS; ++i) {
target_coord[i] = g_coord[i];
}

uint out[INDICES_MAX_DIM] = {0};
for (int i = 0; i < indices_dim[INDICES_RANK - 1]; ++i) {
idx[INDICES_RANK - 1] = i;
const uint idx_b = idx[0];
const uint idx_f = idx[1];
#if INPUT1_DIMS == 4
const uint idx_y = idx[2];
const uint idx_x = idx[3];
#elif INPUT1_DIMS == 5
const uint idx_z = idx[2];
const uint idx_y = idx[3];
const uint idx_x = idx[4];
#elif INPUT1_DIMS == 6
const uint idx_w = idx[2];
const uint idx_z = idx[3];
const uint idx_y = idx[4];
const uint idx_x = idx[5];
#endif
uint index = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
out[i] = indices[index];

// Check if tensor size is valid
// ex) when data format = bfyx and data shape = { 3, 3, 4, 1 }, indices shape is { 2, 1 } with rank = 2, indices values are { 1.0, 4.0 },
// the second indices value is invalid as data shape has 'b' of size 3, and therefore 4 cannot be a correct index of data
// If indices value is invalid, saturate value to max valid value (ex. 4.0 -> 2.0)
if(out[i] >= data_dim[i])
out[i] = data_dim[i] - 1;
#else
for (uint i = 0; i < INDICES_LAST_DIM; ++i) {
target_coord[i] = indices[GET_INDICES_INDEX(IND_ORDER)];
}

for (int i = 0; i < size_to_update; ++i) {
// Define updates index
uint upd[INDICES_MAX_DIM] = {0};
for (int j = 0; j < INDICES_RANK - 1; ++j) {
upd[j] = idx[j];
}
uint data_rmd = i, updates_rmd = i;
for (int j = indices_dim[INDICES_RANK - 1]; j < INPUT0_DIMS; ++j) {
out[j] = data_rmd / dataND[j + 1];
data_rmd %= dataND[j + 1];
}
for (int k = INDICES_RANK - 1; k < INPUT2_DIMS; ++k) {
upd[k] = updates_rmd / updatesND[k + 1];
updates_rmd %= updatesND[k + 1];
}
// Get update index
const uint upd_b = upd[0];
const uint upd_f = upd[1];
#if INPUT2_DIMS == 4
const uint upd_y = upd[2];
const uint upd_x = upd[3];
#elif INPUT2_DIMS == 5
const uint upd_z = upd[2];
const uint upd_y = upd[3];
const uint upd_x = upd[4];
#elif INPUT2_DIMS == 6
const uint upd_w = upd[2];
const uint upd_z = upd[3];
const uint upd_y = upd[4];
const uint upd_x = upd[5];
#endif
uint upd_idx = GET_UPDATES_INDEX(INPUT2, UPD_ORDER);

// Get output index
const uint b = out[0];
const uint f = out[1];
#if INPUT0_DIMS == 4
const uint y = out[2];
const uint x = out[3];
#elif INPUT0_DIMS == 5
const uint z = out[2];
const uint y = out[3];
const uint x = out[4];
#elif INPUT0_DIMS == 6
const uint w = out[2];
const uint z = out[3];
const uint y = out[4];
const uint x = out[5];
#endif
uint out_idx = GET_OUTPUT_INDEX(ORDER);
INPUT2_TYPE val = updates[upd_idx];

#if HAS_FUSED_OPS
FUSED_OPS_SECOND_KERNEL;
output[out_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
#else
output[out_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
for (uint i = INDICES_LAST_DIM; i < OUTPUT_DIMS; ++i) {
target_coord[i] = g_coord[INDICES_RANK - 1 - INDICES_LAST_DIM + i];
}
#endif

}
const uint output_idx = GET_OUTPUT_INDEX(TARGET_COORD_ORDER);
const uint updates_idx = GET_UPDATES_INDEX(INPUT2_ORDER);

#ifdef GET_UPDATES_INDEX
#undef GET_UPDATES_INDEX
#endif
INPUT2_TYPE val = updates[updates_idx];

#ifdef GET_OUTPUT_INDEX
#undef GET_OUTPUT_INDEX
#endif
// printf("g_coord[%2d,%2d,%2d,%2d,%2d] target_coord[%2d,%2d,%2d,%2d,%2d] output_id(%2d) updates_idx(%2d) val(%.2f) INDICES_LAST_DIM(%d) OUTPUT_DIMS(%d)\n",
// g_coord[0], g_coord[1], g_coord[2], g_coord[3], g_coord[4],
// target_coord[0], target_coord[1], target_coord[2], target_coord[3], target_coord[4],
// output_idx, updates_idx, val, INDICES_LAST_DIM, OUTPUT_DIMS);

#if HAS_FUSED_OPS
FUSED_OPS_SECOND_KERNEL;
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SECOND_KERNEL);
#else
output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif
#endif // IS_SECOND_ITER
}

#ifdef ORDER
#undef ORDER
#endif

#ifdef UPD_ORDER
#undef UPD_ORDER
#endif

#ifdef IDX_ORDER
#undef IDX_ORDER
#endif

#ifdef INDICES_MAX_DIM
#undef INDICES_MAX_DIM
#endif
Loading
Loading