11// //////////////////////////////////////////////////////////////////////////////
2- // Copyright (c) 2014-2023 , Lawrence Livermore National Security, LLC.
2+ // Copyright (c) 2014-2024 , Lawrence Livermore National Security, LLC.
33// Produced at the Lawrence Livermore National Laboratory.
44// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
55// the CONTRIBUTORS file. <[email protected] >2828#include " layer_norm_kernels.cuh"
2929#include " lbann/comm_impl.hpp"
3030#include " lbann/layers/regularizers/layer_norm.hpp"
31- #include " lbann/optimizers/optimizer.hpp"
3231#include " lbann/layers/regularizers/layer_norm_impl.hpp"
32+ #include " lbann/optimizers/optimizer.hpp"
3333#include " lbann/utils/gpu/helpers.hpp"
3434
3535#ifdef LBANN_HAS_DISTCONV
@@ -556,12 +556,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
556556template <typename TensorDataType, data_layout Layout, El::Device Device>
557557void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
558558{
559- #ifdef LBANN_HAS_DISTCONV
559+ #ifdef LBANN_HAS_DISTCONV
560560 if (this ->distconv_enabled ()) {
561561 this ->get_distconv_adapter ().fp_compute ();
562562 return ;
563563 }
564- #endif // LBANN_HAS_DISTCONV
564+ #endif // LBANN_HAS_DISTCONV
565565
566566 int weight_idx = 0 ;
567567 const TensorDataType* scale_weights = nullptr ;
@@ -575,6 +575,7 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
575575 El::Int norm_size, global_norm_size, num_norm, norm_stride;
576576 this ->get_normdims (norm_size, global_norm_size, num_norm, norm_stride);
577577
578+ <<<<<<< HEAD
578579<<<<<<< HEAD
579580#ifdef LBANN_HAS_DISTCONV
580581 if (this ->distconv_enabled ()) {
@@ -585,6 +586,8 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
585586=======
586587
587588>>> >>> > f02146109 (Updated implementation with updating statistics tensors)
589+ =======
590+ >>>>>>> ecac28c9f (Updating layer norm impl)
588591 fp_impl(*this ->get_comm (),
589592 this->m_epsilon,
590593 norm_size,
@@ -601,13 +604,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
601604template <typename TensorDataType, data_layout Layout, El::Device Device>
602605void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
603606{
604- #ifdef LBANN_HAS_DISTCONV
607+ #ifdef LBANN_HAS_DISTCONV
605608 if (this ->distconv_enabled ()) {
606609 this ->get_distconv_adapter ().bp_compute ();
607610 return ;
608611 }
609- #endif // LBANN_HAS_DISTCONV
610-
612+ #endif // LBANN_HAS_DISTCONV
613+
611614 // Obtain optional buffers
612615 const TensorDataType* scale_weights = nullptr ;
613616 TensorDataType* scale_grad = nullptr ;
0 commit comments