11// //////////////////////////////////////////////////////////////////////////////
2- // Copyright (c) 2014-2023 , Lawrence Livermore National Security, LLC.
2+ // Copyright (c) 2014-2024 , Lawrence Livermore National Security, LLC.
33// Produced at the Lawrence Livermore National Laboratory.
44// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
55// the CONTRIBUTORS file. <[email protected] >2828#include " layer_norm_kernels.cuh"
2929#include " lbann/comm_impl.hpp"
3030#include " lbann/layers/regularizers/layer_norm.hpp"
31- #include " lbann/optimizers/optimizer.hpp"
3231#include " lbann/layers/regularizers/layer_norm_impl.hpp"
32+ #include " lbann/optimizers/optimizer.hpp"
3333#include " lbann/utils/gpu/helpers.hpp"
3434
3535#ifdef LBANN_HAS_DISTCONV
@@ -654,12 +654,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
654654template <typename TensorDataType, data_layout Layout, El::Device Device>
655655void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
656656{
657- #ifdef LBANN_HAS_DISTCONV
657+ #ifdef LBANN_HAS_DISTCONV
658658 if (this ->distconv_enabled ()) {
659659 this ->get_distconv_adapter ().fp_compute ();
660660 return ;
661661 }
662- #endif // LBANN_HAS_DISTCONV
662+ #endif // LBANN_HAS_DISTCONV
663663
664664 int weight_idx = 0 ;
665665 const TensorDataType* scale_weights = nullptr ;
@@ -671,7 +671,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
671671 bias_weights =
672672 this ->weights_values (weight_idx).LockedMatrix ().LockedBuffer ();
673673
674-
675674 fp_impl (*this ->get_comm (),
676675 this ->m_epsilon ,
677676 this ->get_prev_activations (),
@@ -684,13 +683,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
684683template <typename TensorDataType, data_layout Layout, El::Device Device>
685684void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
686685{
687- #ifdef LBANN_HAS_DISTCONV
686+ #ifdef LBANN_HAS_DISTCONV
688687 if (this ->distconv_enabled ()) {
689688 this ->get_distconv_adapter ().bp_compute ();
690689 return ;
691690 }
692- #endif // LBANN_HAS_DISTCONV
693-
691+ #endif // LBANN_HAS_DISTCONV
692+
694693 // Obtain optional buffers
695694 const TensorDataType* scale_weights = nullptr ;
696695 TensorDataType* scale_grad = nullptr ;
0 commit comments