3535#include < memory>
3636
3737#ifdef LBANN_HAS_DISTCONV
38+ #include " lbann/utils/distconv.hpp"
3839#include " lbann/layers/data_type_distconv_adapter.hpp"
39- #include " lbann/layers/regularizeres /distconv/distconv_layer_norm.hpp"
40+ #include " lbann/layers/regularizers /distconv/distconv_layer_norm.hpp"
4041#endif // LBANN_HAS_DISTCONV
4142
4243namespace lbann {
4344
4445#ifdef LBANN_HAS_DISTCONV
4546namespace dc {
47+ using Shape = ::distconv::tensor::Shape;
48+ using Backend= ::distconv::BackendDNNLib;
4649template <typename TensorDataType>
47- using LayerNorm = ::distconv::LayerNorm <Backend, TensorDataType>;
50+ using LayerNormalization = ::distconv::LayerNormalization <Backend, TensorDataType>;
4851} // namespace dc
4952
5053template <typename TensorDataType, data_layout Layout, El::Device Device>
@@ -63,11 +66,15 @@ class layer_norm_distconv_adapter
6366
6467 void setup_distributions (tensor_overlap_constraints& constraints) override ;
6568 void setup_layer (size_t workspace_capacity) override ;
69+ void setup_fp_tensors () override ;
70+ void setup_bp_tensors () override ;
6671
6772 void fp_compute ();
6873 void bp_compute ();
69-
70- std::unique_ptr<dc::LayerNorm<TensorDataType>> m_layer_norm_operator;
74+
75+ TensorDevType m_statistics;
76+ TensorDevType m_statistics_grad;
77+ std::unique_ptr<dc::LayerNormalization<TensorDataType>> m_layer_norm_operator;
7178}; // class definition channelwise_fully_connected_distconv_adapter
7279
7380#endif // LBANN_HAS_DISTCONV
@@ -140,7 +147,7 @@ class layer_norm_layer : public data_type_layer<TensorDataType>
140147 get_distconv_adapter () override ;
141148 const layer_norm_distconv_adapter<TensorDataType, Layout, Device>&
142149 get_distconv_adapter () const override ;
143- #endif
150+ #endif // LBANN_HAS_DISTCONV
144151
145152private:
146153 using AbsDistMatType = El::AbstractDistMatrix<TensorDataType>;
@@ -160,216 +167,14 @@ class layer_norm_layer : public data_type_layer<TensorDataType>
160167 std::unique_ptr<AbsDistMatType> m_statistics_gradient;
161168};
162169
163- // =========================================================
164- // Implementation
165- // =========================================================
166-
167- template <typename T, data_layout L, El::Device D>
168- void layer_norm_layer<T, L, D>::write_specific_proto(
169- lbann_data::Layer& proto) const
170- {
171- proto.set_datatype (proto::ProtoDataType<T>);
172- auto * msg = proto.mutable_layer_norm ();
173- msg->mutable_epsilon ()->set_value (m_epsilon);
174- }
175-
176- template <typename TensorDataType, data_layout Layout, El::Device Device>
177- layer_norm_layer<TensorDataType, Layout, Device>::layer_norm_layer(
178- TensorDataType epsilon)
179- : data_type_layer<TensorDataType>(nullptr ), m_epsilon(epsilon)
180- {}
181-
182- template <typename TensorDataType, data_layout Layout, El::Device Device>
183- layer_norm_layer<TensorDataType, Layout, Device>::layer_norm_layer(
184- const layer_norm_layer<TensorDataType, Layout, Device>& other)
185- : data_type_layer<TensorDataType>(other),
186- m_epsilon (other.m_epsilon),
187- m_statistics(other.m_statistics ? other.m_statistics->Copy () : nullptr),
188- m_statistics_gradient(other.m_statistics_gradient
189- ? other.m_statistics_gradient->Copy ()
190- : nullptr)
191- {}
192-
193- template <typename TensorDataType, data_layout Layout, El::Device Device>
194- layer_norm_layer<TensorDataType, Layout, Device>&
195- layer_norm_layer<TensorDataType, Layout, Device>::operator =(
196- const layer_norm_layer<TensorDataType, Layout, Device>& other)
197- {
198- data_type_layer<TensorDataType>::operator =(other);
199- m_epsilon = other.m_epsilon ;
200- m_statistics.reset (other.m_statistics ? other.m_statistics ->Copy () : nullptr );
201- m_statistics_gradient.reset (other.m_statistics_gradient
202- ? other.m_statistics_gradient ->Copy ()
203- : nullptr );
204- return *this ;
205- }
206-
207- template <typename TensorDataType, data_layout Layout, El::Device Device>
208- layer_norm_layer<TensorDataType, Layout, Device>*
209- layer_norm_layer<TensorDataType, Layout, Device>::copy() const
210- {
211- return new layer_norm_layer (*this );
212- }
213-
214- template <typename TensorDataType, data_layout Layout, El::Device Device>
215- std::string layer_norm_layer<TensorDataType, Layout, Device>::get_type() const
216- {
217- return " layer norm" ;
218- }
219-
220- template <typename TensorDataType, data_layout Layout, El::Device Device>
221- data_layout
222- layer_norm_layer<TensorDataType, Layout, Device>::get_data_layout() const
223- {
224- return Layout;
225- }
226-
227- template <typename TensorDataType, data_layout Layout, El::Device Device>
228- El::Device
229- layer_norm_layer<TensorDataType, Layout, Device>::get_device_allocation() const
230- {
231- return Device;
232- }
233-
234- template <typename TensorDataType, data_layout Layout, El::Device Device>
235- description
236- layer_norm_layer<TensorDataType, Layout, Device>::get_description() const
237- {
238- auto desc = data_type_layer<TensorDataType>::get_description ();
239- desc.add (" Epsilon" , m_epsilon);
240- return desc;
241- }
242-
243- template <typename TensorDataType, data_layout Layout, El::Device Device>
244- void layer_norm_layer<TensorDataType, Layout, Device>::setup_dims(
245- DataReaderMetaData& dr_metadata)
246- {
247- data_type_layer<TensorDataType>::setup_dims (dr_metadata);
248- this ->set_output_dims (this ->get_input_dims ());
249- }
250-
251- template <typename TensorDataType, data_layout Layout, El::Device Device>
252- void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
253- size_t max_mini_batch_size)
254- {
255- data_type_layer<TensorDataType>::setup_data (max_mini_batch_size);
256- auto dist = this ->get_prev_activations ().DistData ();
257- dist.colDist = El::STAR;
258- m_statistics.reset (AbsDistMatrixType::Instantiate (dist));
259- m_statistics_gradient.reset (AbsDistMatrixType::Instantiate (dist));
260- }
261-
262- #ifdef LBANN_HAS_DISTCONV
263-
264- // =============================================================
265- // DistConv-enabled Scatter member functions
266- // =============================================================
267-
268- template <typename TensorDataType, data_layout Layout, El::Device Device>
269- bool
270- layer_norm_layer<TensorDataType, Layout, Device>
271- ::is_distconv_supported () const {
272- return Device==El::Device::GPU && Layout == data_layout::DATA_PARALLEL;
273- }
274-
275- template <typename TensorDataType, data_layout Layout, El::Device Device>
276- void
277- layer_norm_layer<TensorDataType,Layout,Device>
278- ::setup_distconv_adapter (const DataReaderMetaData& dr_metadata){
279- this ->get_distconv_adapter_ptr () = std::make_unique<layer_norm_distconv_adapter<
280- TensorDataType, Layout, Device>>(*this );
281- }
282-
283- template <typename TensorDataType, data_layout Layout, El::Device Device>
284- const layer_norm_distconv_adapter <TensorDataType, Layout, Device>&
285- layer_norm_layer<TensorDataType, Layout, Device>
286- ::get_distconv_adapter () const {
287- return dynamic_cast <const layer_norm_distconv_adapter<
288- TensorDataType, Layout, Device>&>(data_type_layer<TensorDataType>::get_distconv_adapter ());
289- }
290-
291- template <typename TensorDataType, data_layout Layout, El::Device Device>
292- layer_norm_distconv_adapter <TensorDataType, Layout, Device>&
293- layer_norm_layer<TensorDataType, Layout, Device>
294- ::get_distconv_adapter (){
295- return const_cast <layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
296- static_cast <const layer_norm_layer<TensorDataType, Layout, Device>&>(*this ).get_distconv_adapter ());
297-
298-
299- // =============================================================
300- // Scatter DistConv Adapter implementation
301- // =============================================================
302-
303- template <typename TensorDataType, data_layout Layout, El::Device Device>
304- void
305- layer_norm_distconv_adapter<TensorDataType, Layout, Device>
306- ::setup_distributions (tensor_overlap_constraints &constraints){
307- data_type_distconv_adapter<TensorDataType>::setup_distributions (constraints);
308- // no overlap needed
309- for (auto &d: this ->m_prev_activations_dists ) {
310- d.clear_overlap ();
311- constraints.mark_updated (d);
312- constraints.mark_invariant (d);
313- }
314- for (auto &d: this ->m_activations_dists ) {
315- d.clear_overlap ();
316- constraints.mark_updated (d);
317- constraints.mark_invariant (d);
318- }
319- for (auto &d: this ->m_prev_error_signals_dists ) {
320- d.clear_overlap ();
321- constraints.mark_updated (d);
322- constraints.mark_invariant (d);
323- }
324- for (auto &d: this ->m_error_signals_dists ) {
325- d.clear_overlap ();
326- constraints.mark_updated (d);
327- constraints.mark_invariant (d);
328- }
329- }
330-
331- template <typename TensorDataType, data_layout Layout, El::Device Device>
332- void
333- layer_norm_distconv_adapter<TensorDataType, Layout, Device>
334- ::setup_layer (size_t workspace_capacity){
335- data_type_distconv_adapter<TensorDataType>::setup_layer (workspace_capacity);
336- auto &layer = dynamic_cast <channelwise_fully_connected_layer
337- <TensorDataType, Layout, Device>&>(this ->layer ());
338- m_layer_norm_operator = make_unique<dc::Scatter<TensorDataType>>(dc::get_backend (),
339- layer.m_epsilon );
340- }
341170
342-
343- template <typename TensorDataType, data_layout Layout, El::Device Device>
344- void
345- layer_norm_distconv_adapter<TensorDataType, Layout, Device>
346- ::fp_compute (){
347- // Compute the forward pass
348- m_layer_norm_operator->forward (this ->get_prev_activations (0 ),
349- this -m_epsilon);
350- }
351-
352-
353- template <typename TensorDataType, data_layout Layout, El::Device Device>
354- void
355- layer_norm_distconv_adapter<TensorDataType, Layout, Device>
356- ::bp_compute (){
357- // Compute the backward pass
358- m_layer_norm_operator->backward (this ->get_prev_error_signals (0 )); // Indices gradient. Will be 0'ed out
359- }
360-
361- #define PROTO_DEVICE (T, Device ) \
362- template class layer_norm_distconv_adapter < \
363- T,data_layout::DATA_PARALLEL, Device>
364- #include " lbann/macros/instantiate_device.hpp"
365- #undef PROTO_DEVICE
366171#endif // LBANN_HAS_DISTCONV
367172
368- LBANN_DEFINE_LAYER_BUILDER (layer_norm);
173+ LBANN_DEFINE_LAYER_BUILDER (layer_norm);
369174
370- // =========================================================
371- // Explicit template instantiation
372- // =========================================================
175+ // =========================================================
176+ // Explicit template instantiation
177+ // =========================================================
373178
374179#ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE
375180#define PROTO_DEVICE (T, Device ) \
0 commit comments