1
1
/* ******************************************************************************
2
- * Copyright 2019-2024 Intel Corporation
2
+ * Copyright 2019-2025 Intel Corporation
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
34
34
35
35
const bool use_sc = prb->use_sc ();
36
36
const bool use_sh = prb->use_sh ();
37
+ const bool use_rms_norm = prb->use_rms_norm ();
37
38
38
39
const bool has_src_scale = !prb->attr .scales .get (DNNL_ARG_SRC).is_def ();
39
40
const bool has_dst_scale = !prb->attr .scales .get (DNNL_ARG_DST).is_def ();
@@ -48,7 +49,7 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
48
49
prb->ndims , dnnl_layer_normalization);
49
50
50
51
benchdnn_parallel_nd (prb->n , [&](int64_t n) {
51
- float smean = mean.get_elem (n);
52
+ float smean = use_rms_norm ? 0 . f : mean.get_elem (n);
52
53
float svar = var.get_elem (n);
53
54
float sqrt_var = sqrtf (svar + prb->eps );
54
55
@@ -79,14 +80,15 @@ void compute_ref_bwd(const prb_t *prb, const args_t &args) {
79
80
80
81
const bool use_sc = prb->use_sc ();
81
82
const bool use_sh = prb->use_sh ();
83
+ const bool use_rms_norm = prb->use_rms_norm ();
82
84
83
85
if ((use_sc || use_sh) && (prb->dir & FLAG_WEI)) {
84
86
benchdnn_parallel_nd (prb->c , [&](int64_t c) {
85
87
float d_gamma = 0 ;
86
88
float d_beta = 0 ;
87
89
88
90
for (int64_t n = 0 ; n < prb->n ; ++n) {
89
- float smean = mean.get_elem (n);
91
+ float smean = use_rms_norm ? 0 . f : mean.get_elem (n);
90
92
float svar = var.get_elem (n);
91
93
float rcp_denom = 1 .f / sqrtf (svar + prb->eps );
92
94
auto off = n * prb->c + c;
@@ -101,7 +103,7 @@ void compute_ref_bwd(const prb_t *prb, const args_t &args) {
101
103
}
102
104
103
105
benchdnn_parallel_nd (prb->n , [&](int64_t n) {
104
- float smean = mean.get_elem (n);
106
+ float smean = use_rms_norm ? 0 . 0f : mean.get_elem (n);
105
107
float svar = var.get_elem (n);
106
108
float rcp_denom = 1 .f / sqrtf (svar + prb->eps );
107
109
float dd_gamma = 0 , dd_gamma_x = 0 ;
0 commit comments