3333#include " lbann/layers/data_type_distconv_adapter.hpp"
3434#endif
3535
36- namespace lbann {
37-
36+ namespace lbann {
3837
3938// =========================================================
4039// Implementation
@@ -135,7 +134,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
135134 m_statistics_gradient.reset (AbsDistMatrixType::Instantiate (dist));
136135}
137136
138-
139137#ifdef LBANN_HAS_DISTCONV
140138
141139// =============================================================
@@ -174,57 +172,53 @@ layer_norm_layer<TensorDataType, Layout, Device>::get_distconv_adapter()
174172 layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
175173 static_cast <const layer_norm_layer<TensorDataType, Layout, Device>&>(*this )
176174 .get_distconv_adapter ());
175+ }
177176
178177// =============================================================
179178// LayerNorm DistConv Adapter implementation
180179// =============================================================
181180
182- template <typename TensorDataType, data_layout Layout, El::Device Device>
183- void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
184- setup_distributions (tensor_overlap_constraints & constraints)
185- {
186- data_type_distconv_adapter<TensorDataType>::setup_distributions (
187- constraints);
188- // no overlap needed
189- for (auto & d : this ->m_prev_activations_dists ) {
190- d.clear_overlap ();
191- constraints.mark_updated (d);
192- constraints.mark_invariant (d);
193- }
194- for (auto & d : this ->m_activations_dists ) {
195- d.clear_overlap ();
196- constraints.mark_updated (d);
197- constraints.mark_invariant (d);
198- }
199- for (auto & d : this ->m_prev_error_signals_dists ) {
200- d.clear_overlap ();
201- constraints.mark_updated (d);
202- constraints.mark_invariant (d);
203- }
204- for (auto & d : this ->m_error_signals_dists ) {
205- d.clear_overlap ();
206- constraints.mark_updated (d);
207- constraints.mark_invariant (d);
208- }
181+ template <typename TensorDataType, data_layout Layout, El::Device Device>
182+ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
183+ setup_distributions (tensor_overlap_constraints& constraints)
184+ {
185+ data_type_distconv_adapter<TensorDataType>::setup_distributions (constraints);
186+ // no overlap needed
187+ for (auto & d : this ->m_prev_activations_dists ) {
188+ d.clear_overlap ();
189+ constraints.mark_updated (d);
190+ constraints.mark_invariant (d);
209191 }
210-
211- template <typename TensorDataType, data_layout Layout, El::Device Device>
212- void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer (
213- size_t workspace_capacity)
214- {
215- data_type_distconv_adapter<TensorDataType>::setup_layer (workspace_capacity);
216- auto & layer = dynamic_cast <
217- channelwise_fully_connected_layer<TensorDataType, Layout, Device>&>(
218- this ->layer ());
219- const auto max_mini_batch_size =
220- layer.get_model ()->m_max_mini_batch_size_distconv ;
221-
222- m_layer_norm_operator =
223- make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend (),
224- layer.m_epsilon ,
225- max_mini_batch_size);
192+ for (auto & d : this ->m_activations_dists ) {
193+ d.clear_overlap ();
194+ constraints.mark_updated (d);
195+ constraints.mark_invariant (d);
196+ }
197+ for (auto & d : this ->m_prev_error_signals_dists ) {
198+ d.clear_overlap ();
199+ constraints.mark_updated (d);
200+ constraints.mark_invariant (d);
226201 }
202+ for (auto & d : this ->m_error_signals_dists ) {
203+ d.clear_overlap ();
204+ constraints.mark_updated (d);
205+ constraints.mark_invariant (d);
206+ }
207+ }
208+
209+ template <typename TensorDataType, data_layout Layout, El::Device Device>
210+ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
211+ size_t workspace_capacity)
212+ {
213+ data_type_distconv_adapter<TensorDataType>::setup_layer (workspace_capacity);
214+ auto & layer = dynamic_cast <layer_norm_layer<TensorDataType, Layout, Device>&>(
215+ this ->layer ());
216+
217+ m_layer_norm_operator =
218+ make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend (),
219+ layer.m_epsilon );
220+ }
227221
228- #endif LBANN_HAS_DISTCONV
222+ #endif // LBANN_HAS_DISTCONV
229223} // namespace lbann
230224#endif // LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED
0 commit comments