Skip to content

Commit 021d2a7

Browse files
committed
Updating layer norm impl
1 parent 9e44581 commit 021d2a7

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/layers/regularizers/layer_norm.cu

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
////////////////////////////////////////////////////////////////////////////////
2-
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
2+
// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
33
// Produced at the Lawrence Livermore National Laboratory.
44
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
55
// the CONTRIBUTORS file. <[email protected]>
@@ -28,8 +28,8 @@
2828
#include "layer_norm_kernels.cuh"
2929
#include "lbann/comm_impl.hpp"
3030
#include "lbann/layers/regularizers/layer_norm.hpp"
31-
#include "lbann/optimizers/optimizer.hpp"
3231
#include "lbann/layers/regularizers/layer_norm_impl.hpp"
32+
#include "lbann/optimizers/optimizer.hpp"
3333
#include "lbann/utils/gpu/helpers.hpp"
3434

3535
#ifdef LBANN_HAS_DISTCONV
@@ -556,12 +556,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
556556
template <typename TensorDataType, data_layout Layout, El::Device Device>
557557
void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
558558
{
559-
#ifdef LBANN_HAS_DISTCONV
559+
#ifdef LBANN_HAS_DISTCONV
560560
if (this->distconv_enabled()) {
561561
this->get_distconv_adapter().fp_compute();
562562
return;
563563
}
564-
#endif // LBANN_HAS_DISTCONV
564+
#endif // LBANN_HAS_DISTCONV
565565

566566
int weight_idx = 0;
567567
const TensorDataType* scale_weights = nullptr;
@@ -575,6 +575,7 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
575575
El::Int norm_size, global_norm_size, num_norm, norm_stride;
576576
this->get_normdims(norm_size, global_norm_size, num_norm, norm_stride);
577577

578+
<<<<<<< HEAD
578579
<<<<<<< HEAD
579580
#ifdef LBANN_HAS_DISTCONV
580581
if (this->distconv_enabled()) {
@@ -585,6 +586,8 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
585586
=======
586587
587588
>>>>>>> f02146109 (Updated implementation with updating statistics tensors)
589+
=======
590+
>>>>>>> ecac28c9f (Updating layer norm impl)
588591
fp_impl(*this->get_comm(),
589592
this->m_epsilon,
590593
norm_size,
@@ -601,13 +604,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
601604
template <typename TensorDataType, data_layout Layout, El::Device Device>
602605
void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
603606
{
604-
#ifdef LBANN_HAS_DISTCONV
607+
#ifdef LBANN_HAS_DISTCONV
605608
if (this->distconv_enabled()) {
606609
this->get_distconv_adapter().bp_compute();
607610
return;
608611
}
609-
#endif // LBANN_HAS_DISTCONV
610-
612+
#endif // LBANN_HAS_DISTCONV
613+
611614
// Obtain optional buffers
612615
const TensorDataType* scale_weights = nullptr;
613616
TensorDataType* scale_grad = nullptr;

0 commit comments

Comments
 (0)