Skip to content

Commit f021461

Browse files
committed
Updated the implementation to update the statistics tensors
1 parent 7e16367 commit f021461

File tree

6 files changed

+367
-202
lines changed

6 files changed

+367
-202
lines changed

include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
#ifndef LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
2828
#define LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
2929

30-
#if LBANN_HAS_DISTCONV
30+
#ifdef LBANN_HAS_DISTCONV
3131

3232
namespace distconv {
3333
template <typename Backend, typename DataType>
@@ -39,21 +39,17 @@ class LayerNormalization
3939
using DCTensor = tensor::Tensor<DataType, LocaleMPI, Allocator>;
4040

4141
public:
42-
LayerNormalization(Backend& backend,
43-
Datatype epsilon,
44-
size_t max_mini_batch_size)
45-
: m_backend(backend),
46-
m_epsilon(epsilon),
47-
m_max_mini_batch_size(max_mini_batch_size)
42+
LayerNormalization(Backend& backend, DataType epsilon)
43+
: m_backend(backend), m_epsilon(epsilon)
4844
{}
4945

5046
template <typename Allocator>
5147
void calculate_forward_stats(const DCTensor<Allocator>& input,
52-
DC<Allocator>& statistics);
48+
DCTensor<Allocator>& statistics);
5349

5450
template <typename Allocator>
5551
void apply_normalization(const DCTensor<Allocator>& input,
56-
const DCTensor<Allocator>& statistics,
52+
DCTensor<Allocator>& statistics,
5753
DCTensor<Allocator>& output);
5854

5955
template <typename Allocator>
@@ -74,10 +70,9 @@ class LayerNormalization
7470

7571
private:
7672
DataType m_epsilon;
77-
size_t m_max_mini_batch_size;
7873

7974
}; // class LayerNormalization
8075
} // namespace distconv
8176

82-
#endif // LBANN_HAS_DISTONV
83-
#endif // LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
77+
#endif // LBANN_HAS_DISTCONV
78+
#endif // LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM

include/lbann/layers/regularizers/layer_norm.hpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,20 @@
3636
#include <memory>
3737

3838
#ifdef LBANN_HAS_DISTCONV
39-
#include "lbann/utils/distconv.hpp"
4039
#include "lbann/layers/data_type_distconv_adapter.hpp"
4140
#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp"
41+
#include "lbann/utils/distconv.hpp"
4242
#endif // LBANN_HAS_DISTCONV
4343

4444
namespace lbann {
4545

4646
#ifdef LBANN_HAS_DISTCONV
4747
namespace dc {
48-
using Shape = ::distconv::tensor::Shape;
49-
using Backend= ::distconv::BackendDNNLib;
48+
using Shape = ::distconv::tensor::Shape;
49+
using Backend = ::distconv::BackendDNNLib;
5050
template <typename TensorDataType>
51-
using LayerNormalization = ::distconv::LayerNormalization<Backend, TensorDataType>;
51+
using LayerNormalization =
52+
::distconv::LayerNormalization<Backend, TensorDataType>;
5253
} // namespace dc
5354

5455
template <typename TensorDataType, data_layout Layout, El::Device Device>
@@ -67,12 +68,10 @@ class layer_norm_distconv_adapter
6768

6869
void setup_distributions(tensor_overlap_constraints& constraints) override;
6970
void setup_layer(size_t workspace_capacity) override;
70-
void setup_fp_tensors() override;
71-
void setup_bp_tensors() override;
7271

7372
void fp_compute();
7473
void bp_compute();
75-
74+
7675
TensorDevType m_statistics;
7776
TensorDevType m_statistics_grad;
7877
std::unique_ptr<dc::LayerNormalization<TensorDataType>> m_layer_norm_operator;
@@ -419,13 +418,9 @@ ::get_distconv_adapter(){
419418
// LayerNorm DistConv Adapter implementation
420419
// =============================================================
421420

422-
#endif // LBANN_HAS_DISTCONV
423-
424-
LBANN_DEFINE_LAYER_BUILDER(layer_norm);
425-
426-
// =========================================================
427-
// Explicit template instantiation
428-
// =========================================================
421+
// =========================================================
422+
// Explicit template instantiation
423+
// =========================================================
429424

430425
#ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE
431426
#define PROTO_DEVICE(T, Device) \

include/lbann/layers/regularizers/layer_norm_impl.hpp

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@
3333
#include "lbann/layers/data_type_distconv_adapter.hpp"
3434
#endif
3535

36-
namespace lbann{
37-
36+
namespace lbann {
3837

3938
// =========================================================
4039
// Implementation
@@ -135,7 +134,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
135134
m_statistics_gradient.reset(AbsDistMatrixType::Instantiate(dist));
136135
}
137136

138-
139137
#ifdef LBANN_HAS_DISTCONV
140138

141139
// =============================================================
@@ -174,57 +172,53 @@ layer_norm_layer<TensorDataType, Layout, Device>::get_distconv_adapter()
174172
layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
175173
static_cast<const layer_norm_layer<TensorDataType, Layout, Device>&>(*this)
176174
.get_distconv_adapter());
175+
}
177176

178177
// =============================================================
179178
// LayerNorm DistConv Adapter implementation
180179
// =============================================================
181180

182-
template <typename TensorDataType, data_layout Layout, El::Device Device>
183-
void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
184-
setup_distributions(tensor_overlap_constraints & constraints)
185-
{
186-
data_type_distconv_adapter<TensorDataType>::setup_distributions(
187-
constraints);
188-
// no overlap needed
189-
for (auto& d : this->m_prev_activations_dists) {
190-
d.clear_overlap();
191-
constraints.mark_updated(d);
192-
constraints.mark_invariant(d);
193-
}
194-
for (auto& d : this->m_activations_dists) {
195-
d.clear_overlap();
196-
constraints.mark_updated(d);
197-
constraints.mark_invariant(d);
198-
}
199-
for (auto& d : this->m_prev_error_signals_dists) {
200-
d.clear_overlap();
201-
constraints.mark_updated(d);
202-
constraints.mark_invariant(d);
203-
}
204-
for (auto& d : this->m_error_signals_dists) {
205-
d.clear_overlap();
206-
constraints.mark_updated(d);
207-
constraints.mark_invariant(d);
208-
}
181+
template <typename TensorDataType, data_layout Layout, El::Device Device>
182+
void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
183+
setup_distributions(tensor_overlap_constraints& constraints)
184+
{
185+
data_type_distconv_adapter<TensorDataType>::setup_distributions(constraints);
186+
// no overlap needed
187+
for (auto& d : this->m_prev_activations_dists) {
188+
d.clear_overlap();
189+
constraints.mark_updated(d);
190+
constraints.mark_invariant(d);
209191
}
210-
211-
template <typename TensorDataType, data_layout Layout, El::Device Device>
212-
void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
213-
size_t workspace_capacity)
214-
{
215-
data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
216-
auto& layer = dynamic_cast<
217-
channelwise_fully_connected_layer<TensorDataType, Layout, Device>&>(
218-
this->layer());
219-
const auto max_mini_batch_size =
220-
layer.get_model()->m_max_mini_batch_size_distconv;
221-
222-
m_layer_norm_operator =
223-
make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
224-
layer.m_epsilon,
225-
max_mini_batch_size);
192+
for (auto& d : this->m_activations_dists) {
193+
d.clear_overlap();
194+
constraints.mark_updated(d);
195+
constraints.mark_invariant(d);
196+
}
197+
for (auto& d : this->m_prev_error_signals_dists) {
198+
d.clear_overlap();
199+
constraints.mark_updated(d);
200+
constraints.mark_invariant(d);
226201
}
202+
for (auto& d : this->m_error_signals_dists) {
203+
d.clear_overlap();
204+
constraints.mark_updated(d);
205+
constraints.mark_invariant(d);
206+
}
207+
}
208+
209+
template <typename TensorDataType, data_layout Layout, El::Device Device>
210+
void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
211+
size_t workspace_capacity)
212+
{
213+
data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
214+
auto& layer = dynamic_cast<layer_norm_layer<TensorDataType, Layout, Device>&>(
215+
this->layer());
216+
217+
m_layer_norm_operator =
218+
make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
219+
layer.m_epsilon);
220+
}
227221

228-
#endif LBANN_HAS_DISTCONV
222+
#endif // LBANN_HAS_DISTCONV
229223
} // namespace lbann
230224
#endif // LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED

0 commit comments

Comments
 (0)