Skip to content

Commit cb5fb50

Browse files
committed
Add RMS operator constructor without gamma to avoid dummy constant creation
1 parent 4f171cd commit cb5fb50

File tree

13 files changed

+108
-70
lines changed

13 files changed

+108
-70
lines changed

src/common/transformations/include/ov_ops/rms.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ class TRANSFORMATIONS_API RMS : public ov::op::Op {
3131
const ov::element::Type output_type = ov::element::dynamic,
3232
bool elementwise_affine = true);
3333

34+
/// @brief Constructs an RMS operation without gamma.
35+
///
36+
/// @param data Input tensor with data
37+
/// @param eps Epsilon for not dividing by zero while normalizing the value
38+
/// @param output_type Output element type
39+
/// @param elementwise_affine A boolean value that when set to True, RMS has learnable affine parameters
40+
RMS(const Output<Node>& data,
41+
double epsilson,
42+
const ov::element::Type output_type = ov::element::dynamic,
43+
bool elementwise_affine = false);
44+
3445
bool visit_attributes(ov::AttributeVisitor& visitor) override;
3546

3647
void validate_and_infer_types() override;

src/common/transformations/src/ov_ops/rms.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ RMS::RMS(const Output<Node>& data,
2020
validate_and_infer_types();
2121
}
2222

23+
RMS::RMS(const Output<Node>& data,
24+
double epsilson,
25+
const ov::element::Type output_type,
26+
bool elementwise_affine)
27+
: Op({data}),
28+
m_epsilon(epsilson),
29+
m_output_type(output_type),
30+
m_elementwise_affine(elementwise_affine) {
31+
validate_and_infer_types();
32+
}
33+
2334
bool RMS::visit_attributes(ov::AttributeVisitor& visitor) {
2435
visitor.on_attribute("epsilon", m_epsilon);
2536
visitor.on_attribute("output_type", m_output_type);
@@ -34,6 +45,9 @@ void RMS::validate_and_infer_types() {
3445

3546
std::shared_ptr<Node> RMS::clone_with_new_inputs(const ov::OutputVector& new_args) const {
3647
check_new_args_count(this, new_args);
48+
if (new_args.size() == 1) {
49+
return std::make_shared<RMS>(new_args.at(0), m_epsilon, m_output_type);
50+
}
3751
return std::make_shared<RMS>(new_args.at(0), new_args.at(1), m_epsilon, m_output_type, m_elementwise_affine);
3852
}
3953

src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,6 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
133133
if (pattern_map.find(gamma_convert) != pattern_map.end()) {
134134
gamma_node = pattern_map.at(gamma_convert).get_node_shared_ptr();
135135
}
136-
} else {
137-
auto input_shape = x_output.get_partial_shape();
138-
if (input_shape.rank().is_dynamic() || input_shape[input_shape.size() - 1].is_dynamic()) {
139-
return false;
140-
}
141-
auto last_dim = input_shape[input_shape.size() - 1].get_length();
142-
auto gamma_shape = ov::Shape{static_cast<size_t>(last_dim)};
143-
auto output_type = mul_or_div_node->get_output_element_type(0);
144-
gamma_node = v0::Constant::create(output_type, gamma_shape, {1.0f});
145136
}
146137

147138
const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
@@ -156,7 +147,12 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x) {
156147

157148
auto output_type =
158149
has_gamma ? m.get_match_root()->get_output_element_type(0) : mul_or_div_node->get_output_element_type(0);
159-
auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type, has_gamma);
150+
std::shared_ptr<ov::op::internal::RMS> rms;
151+
if (has_gamma) {
152+
rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type, true);
153+
} else {
154+
rms = std::make_shared<ov::op::internal::RMS>(x_output, eps_value, output_type, false);
155+
}
160156
if (has_gamma) {
161157
rms->set_friendly_name(m.get_match_root()->get_friendly_name());
162158
ov::copy_runtime_info(m.get_matched_nodes(), rms);

src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest10) {
345345
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{1, 2, 6});
346346
auto scale = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{1, 2, 6});
347347

348-
auto rms_const =
349-
ov::opset10::Constant::create(ov::element::f32, ov::Shape{6}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
350-
auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-5f, ov::element::f32, false);
348+
auto rms = std::make_shared<ov::op::internal::RMS>(input, 1e-5f, ov::element::f32, false);
351349
auto mul = std::make_shared<ov::opset10::Multiply>(rms, scale);
352350

353351
model_ref = std::make_shared<ov::Model>(ov::OutputVector{mul}, ov::ParameterVector{input, scale});
@@ -379,9 +377,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest11) {
379377
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
380378
auto scale = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
381379

382-
auto rms_const =
383-
ov::opset10::Constant::create(ov::element::f32, ov::Shape{6}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
384-
auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-6f, ov::element::f32, false);
380+
auto rms = std::make_shared<ov::op::internal::RMS>(input, 1e-6f, ov::element::f32, false);
385381
auto mul = std::make_shared<ov::opset10::Multiply>(rms, scale);
386382

387383
model_ref = std::make_shared<ov::Model>(ov::OutputVector{mul}, ov::ParameterVector{input, scale});

src/plugins/intel_gpu/include/intel_gpu/primitives/rms.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ struct rms : public primitive_base<rms> {
2929
epsilon(epsilon),
3030
elementwise_affine(elementwise_affine) {}
3131

32+
/// @brief Constructs rms primitive without gamma
33+
/// @param id This primitive id
34+
/// @param input Input primitive id
35+
/// @param epsilon Epsilon for not dividing by zero while normalizing
36+
/// @param elementwise_affine A boolean value that when set to True, RMS has learnable affine parameters
37+
rms(const primitive_id& id,
38+
const input_info& input,
39+
const float epsilon,
40+
const bool elementwise_affine = false)
41+
: primitive_base(id, {input}),
42+
epsilon(epsilon),
43+
elementwise_affine(elementwise_affine) {}
44+
3245
/// @brief Epsilon for not dividing by zero while normalizing
3346
float epsilon;
3447
/// @brief A boolean value that when set to True, RMS has learnable affine parameters (gamma)

src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ struct rms_impl : typed_primitive_impl_ocl<rms> {
3636
const auto& primitive = impl_param.typed_desc<rms>();
3737
auto params = get_default_params<kernel_selector::rms_params>(impl_param, is_shape_agnostic);
3838

39-
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(1)));
39+
if (primitive->elementwise_affine) {
40+
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(1)));
41+
}
4042
params.epsilon = primitive->epsilon;
4143
params.ov_input_rank = static_cast<int32_t>(impl_param.get_input_layout().get_partial_shape().size());
4244
params.elementwise_affine = primitive->elementwise_affine;

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rms_gpu_bfyx_opt.cl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
2727
KERNEL(rms_gpu_bfyx_opt)(
2828
OPTIONAL_SHAPE_INFO_ARG
2929
const __global INPUT0_TYPE* input,
30+
#if ELEMENTWISE_AFFINE
3031
const __global INPUT1_TYPE* gamma,
32+
#endif
3133
__global OUTPUT_TYPE* output
3234
#if HAS_FUSED_OPS_DECLS
3335
, FUSED_OPS_DECLS

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rms_gpu_ref.cl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
KERNEL(rms_gpu_ref)(
88
OPTIONAL_SHAPE_INFO_ARG
99
const __global INPUT0_TYPE* input,
10+
#if ELEMENTWISE_AFFINE
1011
const __global INPUT1_TYPE* gamma,
12+
#endif
1113
__global OUTPUT_TYPE* output
1214
#if HAS_FUSED_OPS_DECLS
1315
, FUSED_OPS_DECLS

src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_base.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ KernelsData RMSKernelBase::GetCommonKernelsData(const Params& params) const {
6767
GetUpdateDispatchDataFunc(kd);
6868

6969
auto& kernel = kd.kernels[0];
70+
auto inputs_count = orgParams.elementwise_affine ? 2 : 1;
7071
FillCLKernelData(kernel,
7172
dispatchData,
7273
params.engineInfo,
@@ -76,7 +77,7 @@ KernelsData RMSKernelBase::GetCommonKernelsData(const Params& params) const {
7677
EXE_MODE_DEFAULT,
7778
false,
7879
false,
79-
2,
80+
inputs_count,
8081
GetFusedPrimitiveInputsCount(params),
8182
1,
8283
orgParams.is_shape_agnostic);

src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,14 @@ bool RMSKernelBfyxOpt::Validate(const Params& p) const {
197197
DO_NOT_USE_THIS_KERNEL(p.layerID);
198198

199199
const rms_params& params = static_cast<const rms_params&>(p);
200-
const auto& gamma = params.inputs[1];
201-
202-
if (!gamma.is_dynamic()) {
203-
size_t data_size = gamma.LogicalSize();
204-
if (data_size < subgroup_size) {
205-
DO_NOT_USE_THIS_KERNEL(p.layerID);
200+
if (params.elementwise_affine) {
201+
const auto& gamma = params.inputs[1];
202+
203+
if (!gamma.is_dynamic()) {
204+
size_t data_size = gamma.LogicalSize();
205+
if (data_size < subgroup_size) {
206+
DO_NOT_USE_THIS_KERNEL(p.layerID);
207+
}
206208
}
207209
}
208210
return true;

0 commit comments

Comments (0)