Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/common/transformations/include/ov_ops/rms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
/// \param output_type Output element type
RMS(const Output<Node>& data,
const Output<Node>& gamma,
double epsilson,
double epsilon,
const ov::element::Type output_type = ov::element::dynamic);

/// @brief Constructs an RMS operation without gamma.
///
/// @param data Input tensor with data
/// @param eps Epsilon for not dividing by zero while normalizing the value
/// @param output_type Output element type
RMS(const Output<Node>& data, double epsilon, const ov::element::Type output_type = ov::element::dynamic);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
Expand All @@ -47,9 +54,18 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
m_output_type = output_type;
}

bool get_elementwise_affine() const {
return m_elementwise_affine;
}

void set_elementwise_affine(bool elementwise_affine) {
m_elementwise_affine = elementwise_affine;
}

private:
double m_epsilon{0};
ov::element::Type m_output_type;
bool m_elementwise_affine{true};
};

} // namespace internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace pass {
class RMSFusion : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("RMSFusion");
RMSFusion(bool force_tail_convert = true, bool enable_div_x = false);
RMSFusion(bool force_tail_convert = true, bool enable_div_x = false, bool enable_without_gamma = false);
};

} // namespace pass
Expand Down
19 changes: 16 additions & 3 deletions src/common/transformations/src/ov_ops/rms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,26 @@ namespace ov {
namespace op {
namespace internal {

RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilson, const ov::element::Type output_type)
RMS::RMS(const Output<Node>& data, const Output<Node>& gamma, double epsilon, const ov::element::Type output_type)
: Op({data, gamma}),
m_epsilon(epsilson),
m_output_type(output_type) {
m_epsilon(epsilon),
m_output_type(output_type),
m_elementwise_affine(true) {
validate_and_infer_types();
}

RMS::RMS(const Output<Node>& data, double epsilon, const ov::element::Type output_type)
: Op({data}),
m_epsilon(epsilon),
m_output_type(output_type),
m_elementwise_affine(false) {
validate_and_infer_types();
}

bool RMS::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("epsilon", m_epsilon);
visitor.on_attribute("output_type", m_output_type);
visitor.on_attribute("elementwise_affine", m_elementwise_affine);
return true;
}

Expand All @@ -28,6 +38,9 @@ void RMS::validate_and_infer_types() {

std::shared_ptr<Node> RMS::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);
if (new_args.size() == 1) {
return std::make_shared<RMS>(new_args.at(0), m_epsilon, m_output_type);
}
return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon, m_output_type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ std::function<bool(ov::Output<ov::Node>)> constant_value(const float target_valu
}
} // namespace

RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_without_gamma) {
// Detect RMS decomposition pattern
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
auto x = pattern::any_input();
Expand Down Expand Up @@ -87,16 +87,28 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
mul_or_div = std::make_shared<pattern::op::Or>(OutputVector{mul1});
}

// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
// Pattern 1: RMS with gamma (learnable parameter)
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma (gamma is constant)
auto gamma = pattern::wrap_type<v0::Constant>();
auto gamma_convert = pattern::optional<v0::Convert>(gamma);
auto mul_with_gamma = pattern::wrap_type<v1::Multiply>({gamma_convert, mul_or_div});

std::shared_ptr<ov::Node> rms_mul;
if (enable_without_gamma) {
// Pattern 2: RMS without gamma, but multiplied with dynamic input
// RMS(x) * scale where scale is non-constant (e.g., gate, activation, residual)
// This allows partial fusion: only fuse up to mul_or_div
auto scale = pattern::any_input(pattern::class_other_than<v0::Constant>());
auto mul_with_scale = pattern::wrap_type<v1::Multiply>({mul_or_div, scale});
rms_mul = std::make_shared<pattern::op::Or>(OutputVector{mul_with_gamma, mul_with_scale});
} else {
rms_mul = mul_with_gamma;
}

auto mul2 = pattern::wrap_type<v1::Multiply>({gamma_convert, mul_or_div});

std::shared_ptr<ov::Node> comp = mul2;
std::shared_ptr<ov::Node> comp = rms_mul;
if (force_tail_convert) {
// compress RMS result
comp = pattern::wrap_type<v0::Convert>({mul2});
comp = pattern::wrap_type<v0::Convert>({rms_mul});
}

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
Expand All @@ -114,9 +126,15 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
return false;
}

auto gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
if (pattern_map.find(gamma_convert) != pattern_map.end()) {
gamma_node = pattern_map.at(gamma_convert).get_node_shared_ptr();
auto mul_or_div_node = pattern_map.at(mul_or_div).get_node_shared_ptr();
bool elementwise_affine = pattern_map.count(mul_with_gamma);

std::shared_ptr<ov::Node> gamma_node;
if (elementwise_affine) {
gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
if (pattern_map.count(gamma_convert)) {
gamma_node = pattern_map.at(gamma_convert).get_node_shared_ptr();
}
}

const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
Expand All @@ -129,11 +147,27 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
return false;
}

auto output_type = m.get_match_root()->get_output_element_type(0);
auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type);
rms->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), rms);
ov::replace_node(m.get_match_root(), rms);
auto output_type = elementwise_affine ? m.get_match_root()->get_output_element_type(0)
: mul_or_div_node->get_output_element_type(0);
std::shared_ptr<ov::op::internal::RMS> rms =
elementwise_affine ? std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type)
: std::make_shared<ov::op::internal::RMS>(x_output, eps_value, output_type);
if (elementwise_affine) {
rms->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), rms);
ov::replace_node(m.get_match_root(), rms);
} else {
rms->set_friendly_name(mul_or_div_node->get_friendly_name());
NodeVector nodes_to_fuse;
auto mul_with_scale_node = m.get_match_root();
for (const auto& matched_node : m.get_matched_nodes()) {
if (matched_node != mul_with_scale_node) {
nodes_to_fuse.push_back(matched_node);
}
}
ov::copy_runtime_info(nodes_to_fuse, rms);
ov::replace_node(mul_or_div_node, rms);
}

return true;
};
Expand Down
Loading
Loading