Skip to content

Commit 0531ca3

Browse files
committed
tests: lnorm: update reference implementations to use RMS norm
1 parent dc667a3 commit 0531ca3

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/benchdnn/lnorm/ref_lnorm.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2024 Intel Corporation
2+
* Copyright 2019-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
3434

3535
const bool use_sc = prb->use_sc();
3636
const bool use_sh = prb->use_sh();
37+
const bool use_rms_norm = prb->use_rms_norm();
3738

3839
const bool has_src_scale = !prb->attr.scales.get(DNNL_ARG_SRC).is_def();
3940
const bool has_dst_scale = !prb->attr.scales.get(DNNL_ARG_DST).is_def();
@@ -48,7 +49,7 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
4849
prb->ndims, dnnl_layer_normalization);
4950

5051
benchdnn_parallel_nd(prb->n, [&](int64_t n) {
51-
float smean = mean.get_elem(n);
52+
float smean = use_rms_norm ? 0.f : mean.get_elem(n);
5253
float svar = var.get_elem(n);
5354
float sqrt_var = sqrtf(svar + prb->eps);
5455

@@ -79,14 +80,15 @@ void compute_ref_bwd(const prb_t *prb, const args_t &args) {
7980

8081
const bool use_sc = prb->use_sc();
8182
const bool use_sh = prb->use_sh();
83+
const bool use_rms_norm = prb->use_rms_norm();
8284

8385
if ((use_sc || use_sh) && (prb->dir & FLAG_WEI)) {
8486
benchdnn_parallel_nd(prb->c, [&](int64_t c) {
8587
float d_gamma = 0;
8688
float d_beta = 0;
8789

8890
for (int64_t n = 0; n < prb->n; ++n) {
89-
float smean = mean.get_elem(n);
91+
float smean = use_rms_norm ? 0.f : mean.get_elem(n);
9092
float svar = var.get_elem(n);
9193
float rcp_denom = 1.f / sqrtf(svar + prb->eps);
9294
auto off = n * prb->c + c;
@@ -101,7 +103,7 @@ void compute_ref_bwd(const prb_t *prb, const args_t &args) {
101103
}
102104

103105
benchdnn_parallel_nd(prb->n, [&](int64_t n) {
104-
float smean = mean.get_elem(n);
106+
float smean = use_rms_norm ? 0.0f : mean.get_elem(n);
105107
float svar = var.get_elem(n);
106108
float rcp_denom = 1.f / sqrtf(svar + prb->eps);
107109
float dd_gamma = 0, dd_gamma_x = 0;

0 commit comments

Comments
 (0)