@@ -46,29 +46,33 @@ class GradientHelperImpl : public optimizer::GradientHelper
                      El::DistData grad_dist_data,
                      bool sharded_weights)
     : local_gradient_contrib_{AbsDistMatType::Instantiate(dist_data)},
-      local_contrib_dist_{dist_data},
       global_gradient_{AbsDistMatType::Instantiate(grad_dist_data)},
-      global_dist_{grad_dist_data},
       sharded_weights_{sharded_weights}
   {
     ensure_gradient_memory(height, width);
     El::Zeros(*local_gradient_contrib_, height, width);
-    if (grad_dist_data != dist_data) {
+    if (sharded_weights) {
       El::Zeros(*global_gradient_, height, width);
     }
   }

   void ensure_gradient_memory(El::Int height, El::Int width) override
   {
 #if defined(LBANN_HAS_GPU)
-    local_gradient_contrib_->Matrix().SetMemoryMode(1);
+    static const char* e = std::getenv("LBANN_USE_DIRECT_FOR_CONTRIB");
+    if (e != nullptr && e[0] == '1') {
+      local_gradient_contrib_->Matrix().SetMemoryMode(0);
+    }
+    else {
+      local_gradient_contrib_->Matrix().SetMemoryMode(1);
+    }
 #endif // LBANN_HAS_GPU

     if (local_gradient_contrib_->Width() == 0) {
       local_gradient_contrib_->Resize(height, width);
       // If distribution is the same, have global gradient matrix view the
       // local contributions.
-      if (local_contrib_dist_ == global_dist_) {
+      if (!sharded_weights_) {
         El::View(*global_gradient_, *local_gradient_contrib_);
       }
     }
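
In the hunk above, the helper now decides whether the global gradient needs its own storage from the sharded_weights flag instead of comparing distributions, and ensure_gradient_memory() gains a run-time switch for the GPU buffer: setting LBANN_USE_DIRECT_FOR_CONTRIB=1 in the environment selects memory mode 0 for the local contribution matrix, otherwise mode 1 is used as before. A minimal standalone sketch of that environment-variable gate (the helper name is hypothetical; the mode values are taken from the hunk):

    #include <cstdlib>

    // Returns the memory mode requested for the local contribution buffer:
    // 0 when LBANN_USE_DIRECT_FOR_CONTRIB=1 is set, 1 otherwise.
    inline int contrib_memory_mode()
    {
      static const char* e = std::getenv("LBANN_USE_DIRECT_FOR_CONTRIB");
      return (e != nullptr && e[0] == '1') ? 0 : 1;
    }
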
@@ -96,6 +100,13 @@ class GradientHelperImpl : public optimizer::GradientHelper

   void start_sync(lbann_comm& comm) override
   {
+    // Complete outstanding synchronization of the same data type
+    static GradientHelperImpl<TensorDataType>* lastsync = nullptr;
+    if (lastsync != nullptr) {
+      lastsync->complete_sync(comm);
+      lastsync = nullptr;
+    }
+
     switch (this->get_status()) {
     case optimizer_gradient_status::sync_needed:
       // Sharded gradients are produced from a reduce-scatter on the local
@@ -122,6 +133,7 @@ class GradientHelperImpl : public optimizer::GradientHelper
       */
       }
       this->set_status(optimizer_gradient_status::sync_started);
+      lastsync = this;
       break;
     case optimizer_gradient_status::ready:
     case optimizer_gradient_status::cleared:
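
The lastsync additions above keep at most one outstanding non-blocking synchronization per TensorDataType: before a helper starts a new sync it first completes the previous helper's pending one, and it registers itself once its own sync has started. Because the static lives inside the templated class, there is one such pointer per data-type instantiation. A minimal standalone sketch of that deferred-completion pattern, with hypothetical names and the communication itself stubbed out:

    // One start_sync() completes the previously started sync, if any,
    // before launching its own; the function-local static tracks the
    // most recent object with a sync in flight.
    struct SyncHelper {
      void start_sync()
      {
        static SyncHelper* lastsync = nullptr;
        if (lastsync != nullptr) {
          lastsync->complete_sync();
          lastsync = nullptr;
        }
        // ... launch the non-blocking reduction here ...
        lastsync = this;
      }
      void complete_sync()
      {
        // ... wait on the outstanding request here ...
      }
    };
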
@@ -166,19 +178,19 @@ class GradientHelperImpl : public optimizer::GradientHelper
   void clear() override
   {
     this->set_status(optimizer_gradient_status::cleared);
+    local_gradient_contrib_->Empty();
+    global_gradient_->Empty();
   }

 private:
   /** Matches the distribution of gathered (unsharded) weights in backprop. */
   std::unique_ptr<AbsDistMatType> local_gradient_contrib_;
-  El::DistData local_contrib_dist_;

   /** Matches the distribution of data_type_optimizer<T>::m_gradient (i.e.,
    * post synchronization). Will view said matrix if only one data type
    * exists.
    */
   std::unique_ptr<AbsDistMatType> global_gradient_;
-  El::DistData global_dist_;

   Al::request sync_req_;
   bool sharded_weights_;
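
With the two Empty() calls above, clear() now releases the helper's matrix storage instead of keeping it allocated between steps; ensure_gradient_memory() re-allocates the local contribution buffer on the next use, once its width is back to zero. A rough standalone sketch of that release-on-clear lifecycle, with std::vector standing in for the distributed matrices:

    #include <cstddef>
    #include <vector>

    // Cleared buffers give their memory back; the next ensure() call
    // re-allocates (zero-filled here for simplicity).
    struct ContribBuffer {
      std::vector<float> data;

      void ensure(std::size_t n)
      {
        if (data.empty()) {
          data.assign(n, 0.0f);
        }
      }

      void clear()
      {
        data.clear();
        data.shrink_to_fit(); // actually release storage, analogous to Empty()
      }
    };
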
@@ -218,6 +230,8 @@ optimizer::get_gradient_buffer(TensorDataType& buf_scale,
   // If the manager hasn't been created, let's make it.
   auto mat_info = this->get_matrix_info();
   if (!grad_mgr_ptr) {
+    // If our optimizer contains a gradient of the same data type, reuse (view)
+    // it in the gradient manager
     grad_mgr_ptr = std::make_unique<GradMgrType>(std::get<HEIGHT>(mat_info),
                                                  std::get<WIDTH>(mat_info),
                                                  std::get<DISTDATA_L>(mat_info),
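
For context on the hunk above: the per-type gradient managers are created lazily, one per TensorDataType, on the first call to get_gradient_buffer() for that type, and looked up again on later calls. A self-contained sketch of that lazy, type-keyed creation, assuming a std::map keyed by std::type_index (the exact container behind m_local_gradient_contributions is an assumption here):

    #include <map>
    #include <memory>
    #include <typeindex>
    #include <typeinfo>

    // Hypothetical per-data-type manager map: created on first request,
    // reused afterwards.
    struct GradMgrBase {
      virtual ~GradMgrBase() = default;
    };

    template <typename T>
    struct GradMgr : GradMgrBase {};

    using MgrMap = std::map<std::type_index, std::unique_ptr<GradMgrBase>>;

    template <typename T>
    GradMgr<T>& get_manager(MgrMap& managers)
    {
      auto& ptr = managers[std::type_index(typeid(T))];
      if (!ptr) {
        ptr = std::make_unique<GradMgr<T>>();
      }
      return static_cast<GradMgr<T>&>(*ptr);
    }
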
@@ -319,13 +333,13 @@ void optimizer::accumulate_all_gradient_contributions(
   // Handle the case that only 1 update of a different type is needed.
   if (num_updates == 1UL &&
       this->m_local_gradient_contributions.size() == 1UL) {
-    auto const& grad_mgr =
-      *(this->m_local_gradient_contributions.begin()->second);
+    auto& grad_mgr = *(this->m_local_gradient_contributions.begin()->second);
     if (grad_mgr.get_status() != optimizer_gradient_status::ready) {
       LBANN_ERROR("Expected ready status. Got: ",
                   to_string(grad_mgr.get_status()));
     }
     El::Copy(grad_mgr.global_gradient(), gradient);
+    grad_mgr.clear();
   }
   else if (this->m_local_gradient_contributions.size() > 1UL) {
     // Need a temporary matrix for the type-casted copy.
@@ -335,14 +349,15 @@ void optimizer::accumulate_all_gradient_contributions(
     for (auto const& grad_mgr_v : this->m_local_gradient_contributions) {
       if (grad_mgr_v.first == this_type_idx)
         continue;
-      auto const& grad_mgr = *(grad_mgr_v.second);
+      auto& grad_mgr = *(grad_mgr_v.second);
       if (grad_mgr.get_status() != optimizer_gradient_status::ready) {
         LBANN_ERROR("Expected ready status. Got: ",
                     to_string(grad_mgr.get_status()));
       }
       auto const& grad_base = grad_mgr.global_gradient();
       El::Copy(grad_base, *tmp);
       El::Axpy(one, *tmp, gradient);
+      grad_mgr.clear();
     }
   }
 }
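
The grad_mgr.clear() calls added above release each contribution as soon as it has been folded into the optimizer's gradient. Contributions held in a different data type are first copied into a temporary of the gradient's type and then accumulated with an axpy. A standalone sketch of that cast / accumulate / clear step, with std::vector standing in for the matrices (assumes both buffers have the same length):

    #include <cstddef>
    #include <vector>

    template <typename SrcT, typename DstT>
    void accumulate_and_clear(std::vector<SrcT>& contrib, std::vector<DstT>& gradient)
    {
      std::vector<DstT> tmp(contrib.begin(), contrib.end()); // type-casted copy
      for (std::size_t i = 0; i < gradient.size(); ++i) {
        gradient[i] += tmp[i]; // axpy with alpha = 1
      }
      contrib.clear();
      contrib.shrink_to_fit(); // release the contribution's storage
    }
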