
silu_init has incorrect initialization for approx mode #37896

@tt-aho

Description


The code below is for Blackhole (BH).

template <bool is_fp32_dest_acc_en, int ITERATIONS>
inline void calculate_silu() {
#pragma GCC unroll 8
    for (int d = 0; d < ITERATIONS; d++) {
        sfpi::vFloat x = sfpi::dst_reg[0];

        // silu(x) = x * sigmoid(x)
        sfpi::vFloat result = x * _sfpu_sigmoid_<is_fp32_dest_acc_en>(x);

        // Round to bfloat16 if not in fp32 accumulation mode
        if constexpr (!is_fp32_dest_acc_en) {
            result = sfpi::reinterpret<sfpi::vFloat>(sfpi::float_to_fp16b(result, 0));
        }

        sfpi::dst_reg[0] = result;
        sfpi::dst_reg++;
    }
}

template <bool APPROXIMATION_MODE>
inline void silu_init() {
    if constexpr (!APPROXIMATION_MODE) {
        _init_sfpu_reciprocal_<false>();
    } else {
        _init_sfpu_reciprocal_<true>();
    }
}
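To check results on the host, a per-element reference for what calculate_silu computes could look like the sketch below. silu_reference and round_to_bf16 are hypothetical helpers, not tt-metal APIs. The rounding mode is an assumption: the sketch rounds to nearest, ties to even, which may not match the mode that float_to_fp16b(result, 0) selects. It uses std::exp rather than the SFPU sigmoid, so small differences from device output are expected.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Round an fp32 value to bfloat16 precision, keeping the result in a float.
// Assumption: round-to-nearest-even; NaN inputs are not special-cased.
inline float round_to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    std::uint32_t lsb = (bits >> 16) & 1u;  // lowest mantissa bit kept by bf16
    bits += 0x7FFFu + lsb;                  // round to nearest, ties to even
    bits &= 0xFFFF0000u;                    // drop the low 16 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
}

// Hypothetical host reference for one element of calculate_silu:
// silu(x) = x * sigmoid(x), rounded to bf16 when fp32 dest accumulation is off.
template <bool is_fp32_dest_acc_en>
float silu_reference(float x) {
    float result = x * (1.0f / (1.0f + std::exp(-x)));
    if (!is_fp32_dest_acc_en) {
        result = round_to_bf16(result);
    }
    return result;
}
```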

Looking at calculate_silu, we see that it calls _sfpu_sigmoid_, which is the non-approximate sigmoid. This means silu should always be initialized for the non-approximate sigmoid. However, silu_init toggles its behaviour on the APPROXIMATION_MODE flag and calls _init_sfpu_reciprocal_<true>() when approximation mode is enabled. This is incorrect. silu_init also shouldn't call the reciprocal init manually; it should just call sigmoid_init<false>.
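A host-side model of the mismatch and of the proposed fix is sketched below. It is not tt-metal code: SigmoidInit, g_loaded, sigmoid_init_model, calculate_silu_model, silu_init_current_model and silu_init_fixed_model are hypothetical names. The model assumes, as described above, that sigmoid_init<false> performs all the setup (including any reciprocal init) that the non-approximate sigmoid needs. With the current toggle, APPROXIMATION_MODE = true loads the approximate configuration while the compute path still expects the precise one.

```cpp
#include <cassert>

// Hypothetical host-side model; none of these names are tt-metal APIs.
// It only tracks which sigmoid/reciprocal configuration an init call loaded.
enum class SigmoidInit { None, Approx, Precise };

static SigmoidInit g_loaded = SigmoidInit::None;

// Stand-in for sigmoid_init<APPROX>: assumed to do all the setup
// (including any reciprocal init) that the matching sigmoid path needs.
template <bool APPROX>
void sigmoid_init_model() {
    g_loaded = APPROX ? SigmoidInit::Approx : SigmoidInit::Precise;
}

// Stand-in for calculate_silu: like the BH kernel, it always uses the
// non-approximate _sfpu_sigmoid_, so it requires the precise init.
inline void calculate_silu_model() {
    assert(g_loaded == SigmoidInit::Precise && "silu needs non-approx sigmoid init");
}

// Mirrors the current silu_init: the loaded configuration follows
// APPROXIMATION_MODE even though the compute path does not.
template <bool APPROXIMATION_MODE>
void silu_init_current_model() {
    g_loaded = APPROXIMATION_MODE ? SigmoidInit::Approx : SigmoidInit::Precise;
}

// Proposed silu_init: ignore APPROXIMATION_MODE and always initialize
// for the non-approximate sigmoid that calculate_silu actually calls.
template <bool APPROXIMATION_MODE>
void silu_init_fixed_model() {
    sigmoid_init_model<false>();
}
```

In the real kernel, the fixed version would reduce silu_init to a single sigmoid_init<false> call, independent of the template flag.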

I will be posting a fix for BH, but I believe this also affects Wormhole (WH) silu. For WH silu we may need to be cautious of the original issue #35038, since the WH init depends on the fp32 flag while the BH init does not.

Metadata


Type

Projects

Status

✅ Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
