Skip to content

Commit d060d7a

Browse files
[GPU] Extend ClampFP16Output pass to support clipping input for RMS (openvinotoolkit#29744)
### Details: - Extend the `ClampFP16Output` pass to add a clamp primitive between the `Add` and `RMS` operations. This targets the language-model part of VLM models, which may have an fp16 overflow on the Add output tensor; the overflow can produce Inf values that affect the result of RMS ![image](https://github.com/user-attachments/assets/70624d20-b9dc-405e-a3ff-993365ec3f0c) ### Tickets: - 164349
1 parent 7970784 commit d060d7a

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed

src/plugins/intel_gpu/src/plugin/transformations/clamp_fp16_output.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include "clamp_fp16_output.hpp"
66

7-
#include "openvino/core/rt_info.hpp"
7+
#include "ov_ops/rms.hpp"
88
#include "openvino/op/clamp.hpp"
99
#include "openvino/op/constant.hpp"
1010
#include "openvino/op/matmul.hpp"
@@ -14,6 +14,7 @@
1414
#include "openvino/op/multiply.hpp"
1515
#include "openvino/op/subtract.hpp"
1616
#include "openvino/op/divide.hpp"
17+
#include "openvino/core/rt_info.hpp"
1718
#include "openvino/pass/pattern/op/pattern.hpp"
1819
#include "openvino/pass/pattern/op/wrap_type.hpp"
1920
#include "openvino/pass/pattern/op/or.hpp"
@@ -24,6 +25,11 @@
2425
namespace ov::intel_gpu {
2526

2627
ClampFP16Output::ClampFP16Output() {
28+
add_matcher<ClampFP16OutputSoftmaxMatcher>();
29+
add_matcher<ClampFP16OutputRMSMatcher>();
30+
}
31+
32+
ClampFP16OutputSoftmaxMatcher::ClampFP16OutputSoftmaxMatcher() {
2733
using namespace ov::op;
2834
using namespace ov::pass::pattern;
2935
using namespace ov::pass::pattern::op;
@@ -58,7 +64,39 @@ ClampFP16Output::ClampFP16Output() {
5864
return true;
5965
};
6066

61-
auto m = std::make_shared<ov::pass::pattern::Matcher>(softmax_m, "ClampFP16Output");
67+
auto m = std::make_shared<ov::pass::pattern::Matcher>(softmax_m, "ClampFP16OutputSoftmaxMatcher");
68+
this->register_matcher(m, callback);
69+
}
70+
71+
ClampFP16OutputRMSMatcher::ClampFP16OutputRMSMatcher() {
72+
using namespace ov::pass::pattern;
73+
74+
auto add_m = wrap_type<ov::op::v1::Add>({any_input(), any_input()}, type_matches(element::f16));
75+
auto rms_post_m = wrap_type<ov::op::internal::RMS>({any_input(), wrap_type<ov::op::v0::Constant>()}, type_matches(element::f16));
76+
auto add_1_m = wrap_type<ov::op::v1::Add>({add_m, rms_post_m}, type_matches(element::f16));
77+
auto rms_m = wrap_type<ov::op::internal::RMS>({add_1_m, wrap_type<ov::op::v0::Constant>()}, type_matches(element::f16));
78+
79+
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
80+
const auto& pattern_map = m.get_pattern_value_map();
81+
auto rms = ov::as_type_ptr<ov::op::internal::RMS>(pattern_map.at(rms_m).get_node_shared_ptr());
82+
if (!rms || transformation_callback(rms)) {
83+
return false;
84+
}
85+
86+
auto add_1 = pattern_map.at(add_1_m).get_node_shared_ptr();
87+
88+
auto min = static_cast<double>(std::numeric_limits<ov::float16>::lowest());
89+
auto max = static_cast<double>(std::numeric_limits<ov::float16>::max());
90+
auto clamp = std::make_shared<ov::op::v0::Clamp>(rms->get_input_source_output(0), min, max);
91+
clamp->set_friendly_name(add_1->get_friendly_name() + "/ClampFP16Output");
92+
ov::copy_runtime_info(add_1, clamp);
93+
94+
rms->input(0).replace_source_output(clamp);
95+
96+
return true;
97+
};
98+
99+
auto m = std::make_shared<ov::pass::pattern::Matcher>(rms_m, "ClampFP16OutputRMSMatcher");
62100
this->register_matcher(m, callback);
63101
}
64102

src/plugins/intel_gpu/src/plugin/transformations/clamp_fp16_output.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,28 @@
99

1010
namespace ov::intel_gpu {
1111

12+
class ClampFP16Output: public ov::pass::GraphRewrite {
13+
public:
14+
OPENVINO_GRAPH_REWRITE_RTTI("ClampFP16Output");
15+
ClampFP16Output();
16+
};
17+
1218
/**
1319
* @brief This transformation adds Clamp primitive between MatMul and Softmax operation
1420
* which is targeting some transformer based models (mainly Stable Diffusion) which may have an fp16 overflow
1521
* on MatMul output tensor which could lead to Inf/Nan values on the model output.
1622
* We assume that Clamp operation handling costs almost nothing from the performance perspective as it's supposed to be fused to MatMul later
1723
*/
18-
class ClampFP16Output: public ov::pass::MatcherPass {
24+
class ClampFP16OutputSoftmaxMatcher: public ov::pass::MatcherPass {
1925
public:
20-
OPENVINO_MATCHER_PASS_RTTI("ov::intel_gpu::ClampFP16Output");
26+
OPENVINO_MATCHER_PASS_RTTI("ClampFP16OutputSoftmaxMatcher");
27+
ClampFP16OutputSoftmaxMatcher();
28+
};
2129

22-
ClampFP16Output();
30+
class ClampFP16OutputRMSMatcher: public ov::pass::MatcherPass {
31+
public:
32+
OPENVINO_MATCHER_PASS_RTTI("ClampFP16OutputRMSMatcher");
33+
ClampFP16OutputRMSMatcher();
2334
};
2435

2536
} // namespace ov::intel_gpu

src/plugins/intel_gpu/tests/unit/transformations/clamp_fp16_output_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "openvino/core/coordinate_diff.hpp"
1414
#include "openvino/core/type/element_type.hpp"
1515
#include <openvino/op/constant.hpp>
16+
#include "ov_ops/rms.hpp"
1617
#include "openvino/op/clamp.hpp"
1718
#include "openvino/op/reshape.hpp"
1819
#include "openvino/op/add.hpp"
@@ -157,3 +158,37 @@ TEST_F(TransformationTestsF, ClampFp16OutputTest6) {
157158
}
158159
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
159160
}
161+
162+
TEST_F(TransformationTestsF, ClampFp16OutputRMS) {
163+
{
164+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
165+
auto input2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
166+
auto add = std::make_shared<ov::op::v1::Add>(input1, input2);
167+
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
168+
auto gamma1 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 2560 }, {1});
169+
auto rms_post = std::make_shared<ov::op::internal::RMS>(data, gamma1, 1e-5f, ov::element::f16);
170+
auto add1 = std::make_shared<ov::op::v1::Add>(add, rms_post);
171+
auto gamma2 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 2560 }, {1});
172+
auto rms = std::make_shared<ov::op::internal::RMS>(add1, gamma2, 1e-5f, ov::element::f16);
173+
174+
model = std::make_shared<ov::Model>(ov::NodeVector{ rms }, ov::ParameterVector{ input1, input2, data });
175+
manager.register_pass<ClampFP16Output>();
176+
}
177+
{
178+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
179+
auto input2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
180+
auto add = std::make_shared<ov::op::v1::Add>(input1, input2);
181+
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 2560 });
182+
auto gamma1 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 2560 }, {1});
183+
auto rms_post = std::make_shared<ov::op::internal::RMS>(data, gamma1, 1e-5f, ov::element::f16);
184+
auto add1 = std::make_shared<ov::op::v1::Add>(add, rms_post);
185+
auto min = static_cast<double>(std::numeric_limits<ov::float16>::lowest());
186+
auto max = static_cast<double>(std::numeric_limits<ov::float16>::max());
187+
auto clamp = std::make_shared<ov::op::v0::Clamp>(add1, min, max);
188+
auto gamma2 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 2560 }, {1});
189+
auto rms = std::make_shared<ov::op::internal::RMS>(clamp, gamma2, 1e-5f, ov::element::f16);
190+
191+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rms }, ov::ParameterVector{ input1, input2, data });
192+
}
193+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
194+
}

0 commit comments

Comments
 (0)