|
11 | 11 | #include <cassert> |
12 | 12 | #include <cmath> |
13 | 13 | #include <cstddef> |
14 | | -#include <limits> |
15 | 14 | #include <map> |
16 | 15 | #include <memory> |
17 | 16 | #include <oneapi/dnnl/dnnl.hpp> |
|
46 | 45 | #include "openvino/core/node.hpp" |
47 | 46 | #include "openvino/core/shape.hpp" |
48 | 47 | #include "openvino/core/type.hpp" |
49 | | -#include "openvino/core/type/bfloat16.hpp" |
50 | 48 | #include "openvino/core/type/element_type.hpp" |
51 | 49 | #include "openvino/op/abs.hpp" |
52 | 50 | #include "openvino/op/add.hpp" |
@@ -546,20 +544,25 @@ bool Eltwise::isWithBroadcast() { |
546 | 544 | } |
547 | 545 |
|
548 | 546 | void Eltwise::init() { |
549 | | - // Bf16 saturation handling for gamma parameter when input precision is bf16 to make sure it stays within the valid |
550 | | - // range for bfloat16. |
| 547 | + // Bf16 saturation handling for PowerStatic parameters |
| 548 | + // to make sure they stay within the valid range for bfloat16. |
551 | 549 | if (m_attrs.data.algo == Algorithm::EltwisePowerStatic && getOriginalInputPrecisionAtPort(0) == ov::element::bf16) { |
552 | | - const float lowest = static_cast<float>(std::numeric_limits<ov::bfloat16>::lowest()); |
553 | | - const float max = static_cast<float>(std::numeric_limits<ov::bfloat16>::max()); |
554 | | - auto& gamma = m_attrs.data.gamma; |
555 | | - |
556 | | - if (gamma < lowest) { |
557 | | - gamma = lowest; |
558 | | - } |
| 550 | + // Use the actual float values corresponding to bfloat16 limits |
| 551 | + // 0xFF7F = -65504.0F (lowest), 0x7F7F = 65504.0F (max) |
| 552 | + static constexpr float bf16_lowest = -65504.0F; |
| 553 | + static constexpr float bf16_max = 65504.0F; |
| 554 | + |
| 555 | + // Helper lambda to clamp parameter values within bf16 range |
| 556 | + auto clampBf16Parameter = [&](auto& param) { |
| 557 | + if (std::isfinite(param)) { |
| 558 | + param = std::clamp(static_cast<float>(param), bf16_lowest, bf16_max); |
| 559 | + } |
| 560 | + }; |
559 | 561 |
|
560 | | - if (gamma > max) { |
561 | | - gamma = max; |
562 | | - } |
| 562 | + // Clamp all PowerStatic parameters |
| 563 | + clampBf16Parameter(m_attrs.data.alpha); |
| 564 | + clampBf16Parameter(m_attrs.data.beta); |
| 565 | + clampBf16Parameter(m_attrs.data.gamma); |
563 | 566 | } |
564 | 567 | } |
565 | 568 |
|
|
0 commit comments