Skip to content

Commit c9733df

Browse files
committed
Add enable_without_gamma flag so the gamma-less RMS pattern is matched only for plugins that opt in
1 parent 8de0bf5 commit c9733df

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

src/common/transformations/include/transformations/common_optimizations/rms_fusion.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace pass {
3030
class RMSFusion : public ov::pass::MatcherPass {
3131
public:
3232
OPENVINO_MATCHER_PASS_RTTI("RMSFusion");
33-
RMSFusion(bool force_tail_convert = true, bool enable_div_x = false);
33+
RMSFusion(bool force_tail_convert = true, bool enable_div_x = false, bool enable_without_gamma = false);
3434
};
3535

3636
} // namespace pass

src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ std::function<bool(ov::Output<ov::Node>)> constant_value(const float target_valu
4242
}
4343
} // namespace
4444

45-
RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
45+
RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_without_gamma) {
4646
// Detect RMS decomposition pattern
4747
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
4848
auto x = pattern::any_input();
@@ -93,13 +93,17 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
9393
auto gamma_convert = pattern::optional<v0::Convert>(gamma);
9494
auto mul_with_gamma = pattern::wrap_type<v1::Multiply>({gamma_convert, mul_or_div});
9595

96-
// Pattern 2: RMS without gamma, but multiplied with dynamic input
97-
// RMS(x) * scale where scale is non-constant (e.g., gate, activation, residual)
98-
// This allows partial fusion: only fuse up to mul_or_div
99-
auto scale = pattern::any_input(pattern::class_other_than<v0::Constant>());
100-
auto mul_with_scale = pattern::wrap_type<v1::Multiply>({mul_or_div, scale});
101-
102-
auto rms_mul = std::make_shared<pattern::op::Or>(OutputVector{mul_with_gamma, mul_with_scale});
96+
std::shared_ptr<ov::Node> rms_mul;
97+
if (enable_without_gamma) {
98+
// Pattern 2: RMS without gamma, but multiplied with dynamic input
99+
// RMS(x) * scale where scale is non-constant (e.g., gate, activation, residual)
100+
// This allows partial fusion: only fuse up to mul_or_div
101+
auto scale = pattern::any_input(pattern::class_other_than<v0::Constant>());
102+
auto mul_with_scale = pattern::wrap_type<v1::Multiply>({mul_or_div, scale});
103+
rms_mul = std::make_shared<pattern::op::Or>(OutputVector{mul_with_gamma, mul_with_scale});
104+
} else {
105+
rms_mul = mul_with_gamma;
106+
}
103107

104108
std::shared_ptr<ov::Node> comp = rms_mul;
105109
if (force_tail_convert) {

src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest10) {
341341
auto mul2 = std::make_shared<ov::op::v1::Multiply>(mul1, scale);
342342

343343
model = std::make_shared<ov::Model>(ov::OutputVector{mul2}, ov::ParameterVector{input, scale});
344-
manager.register_pass<RMSFusion>(false);
344+
manager.register_pass<RMSFusion>(false, false, true);
345345
}
346346
{
347347
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 6});
@@ -373,7 +373,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest11) {
373373
auto mul2 = std::make_shared<ov::op::v1::Multiply>(mul1, scale);
374374

375375
model = std::make_shared<ov::Model>(ov::OutputVector{mul2}, ov::ParameterVector{input, scale});
376-
manager.register_pass<RMSFusion>(false);
376+
manager.register_pass<RMSFusion>(false, false, true);
377377
}
378378
{
379379
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});

src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
566566
const int32_t vec_size = 8;
567567
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
568568
});
569-
manager.register_pass<ov::pass::RMSFusion>(false, true);
569+
manager.register_pass<ov::pass::RMSFusion>(false, true, true);
570570
manager.register_pass<DisableFP16CompForGemma3RMSPattern>();
571571
manager.register_pass<DisableFP16ComForGPTOSSROPEPattern>();
572572
manager.register_pass<DisableFP16ComSinGenPatternForHiFiGAN>();

0 commit comments

Comments
 (0)