Skip to content

Commit ecac28c

Browse files
committed
Updating layer norm impl
1 parent 01a361a commit ecac28c

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

src/layers/regularizers/layer_norm.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
////////////////////////////////////////////////////////////////////////////////
2-
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
2+
// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
33
// Produced at the Lawrence Livermore National Laboratory.
44
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
55
// the CONTRIBUTORS file. <[email protected]>
@@ -28,8 +28,8 @@
2828
#include "layer_norm_kernels.cuh"
2929
#include "lbann/comm_impl.hpp"
3030
#include "lbann/layers/regularizers/layer_norm.hpp"
31-
#include "lbann/optimizers/optimizer.hpp"
3231
#include "lbann/layers/regularizers/layer_norm_impl.hpp"
32+
#include "lbann/optimizers/optimizer.hpp"
3333
#include "lbann/utils/gpu/helpers.hpp"
3434

3535
#ifdef LBANN_HAS_DISTCONV
@@ -654,12 +654,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
654654
template <typename TensorDataType, data_layout Layout, El::Device Device>
655655
void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
656656
{
657-
#ifdef LBANN_HAS_DISTCONV
657+
#ifdef LBANN_HAS_DISTCONV
658658
if (this->distconv_enabled()) {
659659
this->get_distconv_adapter().fp_compute();
660660
return;
661661
}
662-
#endif // LBANN_HAS_DISTCONV
662+
#endif // LBANN_HAS_DISTCONV
663663

664664
int weight_idx = 0;
665665
const TensorDataType* scale_weights = nullptr;
@@ -671,7 +671,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
671671
bias_weights =
672672
this->weights_values(weight_idx).LockedMatrix().LockedBuffer();
673673

674-
675674
fp_impl(*this->get_comm(),
676675
this->m_epsilon,
677676
this->get_prev_activations(),
@@ -684,13 +683,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
684683
template <typename TensorDataType, data_layout Layout, El::Device Device>
685684
void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
686685
{
687-
#ifdef LBANN_HAS_DISTCONV
686+
#ifdef LBANN_HAS_DISTCONV
688687
if (this->distconv_enabled()) {
689688
this->get_distconv_adapter().bp_compute();
690689
return;
691690
}
692-
#endif // LBANN_HAS_DISTCONV
693-
691+
#endif // LBANN_HAS_DISTCONV
692+
694693
// Obtain optional buffers
695694
const TensorDataType* scale_weights = nullptr;
696695
TensorDataType* scale_grad = nullptr;

0 commit comments

Comments
 (0)