Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,31 @@ KERNEL(softmax)(
max_value = max(max_value, in);
data[cls*TMP_CLASS_PITCH] = in;
}
// Handle IEEE-754 case when max_value is INF
if (isinf((float)max_value)) {
for (cls = 0; cls < class_num; ++cls) {
ACCUMULATOR_TYPE v = data[cls * TMP_CLASS_PITCH];
if (v == max_value)
data[cls * TMP_CLASS_PITCH] = (ACCUMULATOR_TYPE)NAN;
else
data[cls * TMP_CLASS_PITCH] = (ACCUMULATOR_TYPE)0.0f;
}

// Write results and exit
for (cls = 0; cls < class_num; ++cls) {
#if INPUT0_SIMPLE == 1
const uint output_idx = out_depth_offset + cls*OUTPUT_CLASS_PITCH;
#else
#if INPUT0_DIMS == 5
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
#else
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
#endif
#endif
output[output_idx] = data[cls * TMP_CLASS_PITCH];
}
return;
}
Comment on lines +109 to +133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to apply fused ops for the issued case too.
Please check my suggestion below.

Suggested change
// Handle IEEE-754 case when max_value is INF
if (isinf((float)max_value)) {
for (cls = 0; cls < class_num; ++cls) {
ACCUMULATOR_TYPE v = data[cls * TMP_CLASS_PITCH];
if (v == max_value)
data[cls * TMP_CLASS_PITCH] = (ACCUMULATOR_TYPE)NAN;
else
data[cls * TMP_CLASS_PITCH] = (ACCUMULATOR_TYPE)0.0f;
}
// Write results and exit
for (cls = 0; cls < class_num; ++cls) {
#if INPUT0_SIMPLE == 1
const uint output_idx = out_depth_offset + cls*OUTPUT_CLASS_PITCH;
#else
#if INPUT0_DIMS == 5
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, z + *z_offset, y + *y_offset, x + *x_offset);
#else
const uint output_idx = OUTPUT_GET_INDEX(b + *b_offset, f + *f_offset, y + *y_offset, x + *x_offset);
#endif
#endif
output[output_idx] = data[cls * TMP_CLASS_PITCH];
}
return;
}
for (cls = 0; cls < class_num; ++cls) {
// Handle IEEE-754 case when max_value is INF
if (isinf((float)max_value)) {
if (data[cls*TMP_CLASS_PITCH] == max_value)
data[cls*TMP_CLASS_PITCH] = TO_ACCUMULATOR_TYPE(NAN);
else
data[cls*TMP_CLASS_PITCH] = TO_ACCUMULATOR_TYPE(0.0f);
} else {
ACCUMULATOR_TYPE t = native_exp(data[cls*TMP_CLASS_PITCH] - max_value);
denominator += t;
data[cls*TMP_CLASS_PITCH] = t;
}
}
....
for (cls = 0; cls < class_num; ++cls) {
ACCUMULATOR_TYPE res = data[cls*TMP_CLASS_PITCH];
if (!isinf((float)max_value)) {
res = res / denominator;
}


// TODO: currently we calculate on float32 because it's lot of "add" operation and it stuck on the value "8192.0f"
ACCUMULATOR_TYPE denominator = 0.0;
Expand Down
Loading