
Commit 710d0fe

Separated out layer norm implementation from class definition
1 parent 59723d2 commit 710d0fe

File tree

3 files changed: +303 -211 lines changed

include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp

Lines changed: 57 additions & 0 deletions
@@ -24,3 +24,60 @@
 // permissions and limitations under the license.
 ////////////////////////////////////////////////////////////////////////////////
 
+#ifndef LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
+#define LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
+
+#ifdef LBANN_HAS_DISTCONV
+
+namespace distconv {
+template <typename Backend, typename DataType>
+class LayerNormalization
+{
+  using LocaleMPI = tensor::LocaleMPI;
+
+  template <typename Allocator>
+  using DCTensor = tensor::Tensor<DataType, LocaleMPI, Allocator>;
+
+public:
+  LayerNormalization(Backend& backend,
+                     DataType epsilon,
+                     size_t max_mini_batch_size)
+    : m_backend(backend),
+      m_epsilon(epsilon),
+      m_max_mini_batch_size(max_mini_batch_size)
+  {}
+
+  template <typename Allocator>
+  void calculate_forward_stats(const DCTensor<Allocator>& input,
+                               DCTensor<Allocator>& statistics);
+
+  template <typename Allocator>
+  void apply_normalization(const DCTensor<Allocator>& input,
+                           const DCTensor<Allocator>& statistics,
+                           DCTensor<Allocator>& output);
+
+  template <typename Allocator>
+  void calculate_backward_stats(const DCTensor<Allocator>& input,
+                                const DCTensor<Allocator>& output_grad,
+                                const DCTensor<Allocator>& statistics,
+                                DCTensor<Allocator>& statistics_grad);
+
+  template <typename Allocator>
+  void apply_grad(const DCTensor<Allocator>& input,
+                  const DCTensor<Allocator>& output_grad,
+                  const DCTensor<Allocator>& statistics,
+                  const DCTensor<Allocator>& statistics_grad,
+                  DCTensor<Allocator>& input_grad);
+
+protected:
+  Backend& m_backend;
+
+private:
+  DataType m_epsilon;
+  size_t m_max_mini_batch_size;
+
+}; // class LayerNormalization
+} // namespace distconv
+
+#endif // LBANN_HAS_DISTCONV
+#endif // LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
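The interface above splits layer norm into a statistics pass and a normalization pass (and, symmetrically, calculate_backward_stats and apply_grad for the gradient). As a rough illustration of the math this decomposition implies, here is a minimal single-process C++ sketch; it is not the distconv implementation, and the Stats struct and vector-based signatures are simplified stand-ins for the DCTensor API, named after the class methods only for clarity.

#include <cmath>
#include <cstddef>
#include <vector>

// Per-sample statistics: mean and uncentered second moment; the variance
// is recovered later as E[x^2] - E[x]^2.
struct Stats {
  double mean;
  double sqmean;
};

// Analogue of calculate_forward_stats: one Stats entry per sample.
std::vector<Stats> calculate_forward_stats(
  const std::vector<std::vector<double>>& input)
{
  std::vector<Stats> stats;
  stats.reserve(input.size());
  for (const auto& sample : input) {
    double sum = 0.0, sqsum = 0.0;
    for (double x : sample) {
      sum += x;
      sqsum += x * x;
    }
    const double n = static_cast<double>(sample.size());
    stats.push_back({sum / n, sqsum / n});
  }
  return stats;
}

// Analogue of apply_normalization: y = (x - mean) / sqrt(var + epsilon).
std::vector<std::vector<double>> apply_normalization(
  const std::vector<std::vector<double>>& input,
  const std::vector<Stats>& stats,
  double epsilon)
{
  auto output = input;
  for (std::size_t i = 0; i < input.size(); ++i) {
    const double var = stats[i].sqmean - stats[i].mean * stats[i].mean;
    const double inv_stdev = 1.0 / std::sqrt(var + epsilon);
    for (double& x : output[i]) {
      x = (x - stats[i].mean) * inv_stdev;
    }
  }
  return output;
}

Keeping the statistics in their own tensor is what makes the split useful in a distributed setting: the small statistics tensor can be reduced across ranks between the two calls, rather than communicating whole activation tensors.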

include/lbann/layers/regularizers/layer_norm.hpp

Lines changed: 16 additions & 211 deletions
@@ -35,16 +35,19 @@
 #include <memory>
 
 #ifdef LBANN_HAS_DISTCONV
+#include "lbann/utils/distconv.hpp"
 #include "lbann/layers/data_type_distconv_adapter.hpp"
-#include "lbann/layers/regularizeres/distconv/distconv_layer_norm.hpp"
+#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp"
 #endif // LBANN_HAS_DISTCONV
 
 namespace lbann {
 
 #ifdef LBANN_HAS_DISTCONV
 namespace dc {
+using Shape = ::distconv::tensor::Shape;
+using Backend = ::distconv::BackendDNNLib;
 template <typename TensorDataType>
-using LayerNorm = ::distconv::LayerNorm<Backend, TensorDataType>;
+using LayerNormalization = ::distconv::LayerNormalization<Backend, TensorDataType>;
 } // namespace dc
 
 template <typename TensorDataType, data_layout Layout, El::Device Device>
@@ -63,11 +66,15 @@ class layer_norm_distconv_adapter
 
   void setup_distributions(tensor_overlap_constraints& constraints) override;
   void setup_layer(size_t workspace_capacity) override;
+  void setup_fp_tensors() override;
+  void setup_bp_tensors() override;
 
   void fp_compute();
   void bp_compute();
-
-  std::unique_ptr<dc::LayerNorm<TensorDataType>> m_layer_norm_operator;
+
+  TensorDevType m_statistics;
+  TensorDevType m_statistics_grad;
+  std::unique_ptr<dc::LayerNormalization<TensorDataType>> m_layer_norm_operator;
 }; // class definition channelwise_fully_connected_distconv_adapter
 
 #endif // LBANN_HAS_DISTCONV
@@ -140,7 +147,7 @@ class layer_norm_layer : public data_type_layer<TensorDataType>
   get_distconv_adapter() override;
   const layer_norm_distconv_adapter<TensorDataType, Layout, Device>&
   get_distconv_adapter() const override;
-#endif
+#endif // LBANN_HAS_DISTCONV
 
 private:
   using AbsDistMatType = El::AbstractDistMatrix<TensorDataType>;
@@ -160,216 +167,14 @@ class layer_norm_layer : public data_type_layer<TensorDataType>
   std::unique_ptr<AbsDistMatType> m_statistics_gradient;
 };
 
-// =========================================================
-// Implementation
-// =========================================================
-
-template <typename T, data_layout L, El::Device D>
-void layer_norm_layer<T, L, D>::write_specific_proto(
-  lbann_data::Layer& proto) const
-{
-  proto.set_datatype(proto::ProtoDataType<T>);
-  auto* msg = proto.mutable_layer_norm();
-  msg->mutable_epsilon()->set_value(m_epsilon);
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-layer_norm_layer<TensorDataType, Layout, Device>::layer_norm_layer(
-  TensorDataType epsilon)
-  : data_type_layer<TensorDataType>(nullptr), m_epsilon(epsilon)
-{}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-layer_norm_layer<TensorDataType, Layout, Device>::layer_norm_layer(
-  const layer_norm_layer<TensorDataType, Layout, Device>& other)
-  : data_type_layer<TensorDataType>(other),
-    m_epsilon(other.m_epsilon),
-    m_statistics(other.m_statistics ? other.m_statistics->Copy() : nullptr),
-    m_statistics_gradient(other.m_statistics_gradient
-                            ? other.m_statistics_gradient->Copy()
-                            : nullptr)
-{}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-layer_norm_layer<TensorDataType, Layout, Device>&
-layer_norm_layer<TensorDataType, Layout, Device>::operator=(
-  const layer_norm_layer<TensorDataType, Layout, Device>& other)
-{
-  data_type_layer<TensorDataType>::operator=(other);
-  m_epsilon = other.m_epsilon;
-  m_statistics.reset(other.m_statistics ? other.m_statistics->Copy() : nullptr);
-  m_statistics_gradient.reset(other.m_statistics_gradient
-                                ? other.m_statistics_gradient->Copy()
-                                : nullptr);
-  return *this;
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-layer_norm_layer<TensorDataType, Layout, Device>*
-layer_norm_layer<TensorDataType, Layout, Device>::copy() const
-{
-  return new layer_norm_layer(*this);
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-std::string layer_norm_layer<TensorDataType, Layout, Device>::get_type() const
-{
-  return "layer norm";
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-data_layout
-layer_norm_layer<TensorDataType, Layout, Device>::get_data_layout() const
-{
-  return Layout;
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-El::Device
-layer_norm_layer<TensorDataType, Layout, Device>::get_device_allocation() const
-{
-  return Device;
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-description
-layer_norm_layer<TensorDataType, Layout, Device>::get_description() const
-{
-  auto desc = data_type_layer<TensorDataType>::get_description();
-  desc.add("Epsilon", m_epsilon);
-  return desc;
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_layer<TensorDataType, Layout, Device>::setup_dims(
-  DataReaderMetaData& dr_metadata)
-{
-  data_type_layer<TensorDataType>::setup_dims(dr_metadata);
-  this->set_output_dims(this->get_input_dims());
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
-  size_t max_mini_batch_size)
-{
-  data_type_layer<TensorDataType>::setup_data(max_mini_batch_size);
-  auto dist = this->get_prev_activations().DistData();
-  dist.colDist = El::STAR;
-  m_statistics.reset(AbsDistMatrixType::Instantiate(dist));
-  m_statistics_gradient.reset(AbsDistMatrixType::Instantiate(dist));
-}
-
-#ifdef LBANN_HAS_DISTCONV
-
-// =============================================================
-// DistConv-enabled Scatter member functions
-// =============================================================
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-bool
-layer_norm_layer<TensorDataType, Layout, Device>
-::is_distconv_supported() const {
-  return Device==El::Device::GPU && Layout == data_layout::DATA_PARALLEL;
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void
-layer_norm_layer<TensorDataType,Layout,Device>
-::setup_distconv_adapter(const DataReaderMetaData& dr_metadata){
-  this->get_distconv_adapter_ptr() = std::make_unique<layer_norm_distconv_adapter<
-    TensorDataType, Layout, Device>>(*this);
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-const layer_norm_distconv_adapter <TensorDataType, Layout, Device>&
-layer_norm_layer<TensorDataType, Layout, Device>
-::get_distconv_adapter() const{
-  return dynamic_cast<const layer_norm_distconv_adapter<
-    TensorDataType, Layout, Device>&>(data_type_layer<TensorDataType>::get_distconv_adapter());
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-layer_norm_distconv_adapter <TensorDataType, Layout, Device>&
-layer_norm_layer<TensorDataType, Layout, Device>
-::get_distconv_adapter(){
-  return const_cast<layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
-    static_cast<const layer_norm_layer<TensorDataType, Layout, Device>&>(*this).get_distconv_adapter());
-
-
-// =============================================================
-// Scatter DistConv Adapter implementation
-// =============================================================
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void
-layer_norm_distconv_adapter<TensorDataType, Layout, Device>
-::setup_distributions(tensor_overlap_constraints &constraints){
-  data_type_distconv_adapter<TensorDataType>::setup_distributions(constraints);
-  // no overlap needed
-  for (auto &d: this->m_prev_activations_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto &d: this->m_activations_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto &d: this->m_prev_error_signals_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto &d: this->m_error_signals_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void
-layer_norm_distconv_adapter<TensorDataType, Layout, Device>
-::setup_layer(size_t workspace_capacity){
-  data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
-  auto &layer = dynamic_cast<channelwise_fully_connected_layer
-    <TensorDataType, Layout, Device>&>(this->layer());
-  m_layer_norm_operator = make_unique<dc::Scatter<TensorDataType>>(dc::get_backend(),
-                                                                   layer.m_epsilon);
-}
 
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void
-layer_norm_distconv_adapter<TensorDataType, Layout, Device>
-::fp_compute(){
-  // Compute the forward pass
-  m_layer_norm_operator->forward(this->get_prev_activations(0),
-                                 this-m_epsilon);
-}
-
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void
-layer_norm_distconv_adapter<TensorDataType, Layout, Device>
-::bp_compute(){
-  // Compute the backward pass
-  m_layer_norm_operator->backward(this->get_prev_error_signals(0)); // Indices gradient. Will be 0'ed out
-}
-
-#define PROTO_DEVICE(T, Device) \
-  template class layer_norm_distconv_adapter< \
-    T,data_layout::DATA_PARALLEL, Device>
-#include "lbann/macros/instantiate_device.hpp"
-#undef PROTO_DEVICE
 #endif // LBANN_HAS_DISTCONV
 
-LBANN_DEFINE_LAYER_BUILDER(layer_norm);
+LBANN_DEFINE_LAYER_BUILDER(layer_norm);
 
-// =========================================================
-// Explicit template instantiation
-// =========================================================
+// =========================================================
+// Explicit template instantiation
+// =========================================================
 
 #ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE
 #define PROTO_DEVICE(T, Device) \
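The commit title describes moving the layer norm implementation out of the class definition: the usual C++ pattern of keeping a class template's declaration in the header and compiling its member definitions, with explicit instantiations, in a separate translation unit (here via the PROTO_DEVICE machinery). A minimal sketch of that pattern, with a hypothetical Example class and file names that are not part of this commit:

// example.hpp -- declaration only; member bodies are not visible here.
template <typename T>
class Example {
public:
  explicit Example(T epsilon);
  T epsilon() const;

private:
  T m_epsilon;
};

// example.cpp -- member definitions plus explicit instantiations, so other
// translation units can link against Example<float> and Example<double>
// without ever seeing the bodies.
template <typename T>
Example<T>::Example(T epsilon) : m_epsilon(epsilon) {}

template <typename T>
T Example<T>::epsilon() const
{
  return m_epsilon;
}

template class Example<float>;
template class Example<double>;

The PROTO_DEVICE / instantiate_device.hpp block in the diff above plays the role of the two "template class" lines, stamping out the adapter for each supported data type and device.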
