Skip to content

Commit 4e92814

Browse files
authored
Transformer updates and memory optimizations (#2370)
* Save memory on backpropagation by releasing gradients in GradientManager as soon as possible
* Save more memory in FSDP by synchronizing previous outstanding async communication calls and freeing up local gradient contributions
* FSDP: release full weight views after backprop
* Minor tweaks to transformer training script
1 parent 5e92a57 commit 4e92814

File tree

6 files changed

+75
-29
lines changed

6 files changed

+75
-29
lines changed

applications/nlp/transformer/parallelism.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,11 @@ def apply_fsdp_allweights(model: lbann.Model, args: argparse.Namespace):
5656

5757
# Loop over all weights
5858
for layer in model.layers:
59+
# As a heuristic, only shard the first set of weights (i.e., no
60+
# biases) and skip layer normalization
61+
if 'LayerNorm' in str(type(layer)):
62+
continue
5963
if layer.weights:
60-
# As a heuristic, only shard the first set of weights (i.e., no
61-
# biases)
6264
if len(layer.weights) > 0:
6365
layer.weights[0].sharded = True
6466

applications/nlp/transformer/trainer.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -65,12 +65,13 @@ def construct_training_task(model: lbann.Model,
6565
# Data reader
6666
# ----------------------------------------------
6767
def make_data_reader(dataset_name: str, fraction: float, validate: bool,
68-
val_fraction: float):
68+
val_fraction: float, always_shuffle: bool):
6969
reader = lbann.reader_pb2.DataReader()
7070
_reader = reader.reader.add()
7171
_reader.name = 'python'
7272
_reader.role = 'train'
73-
_reader.shuffle = False if 'pretokenized' in dataset_name else True
73+
_reader.shuffle = (True if always_shuffle
74+
or 'pretokenized' not in dataset_name else False)
7475
_reader.fraction_of_data_to_use = fraction
7576
_reader.python.module = dataset_name
7677
_reader.python.module_dir = os.path.join(
@@ -124,7 +125,8 @@ def make_batch_script(model: lbann.Model,
124125
training_algo=algo)
125126
reader = make_data_reader(dataset_name, args.dataset_fraction,
126127
not args.skip_validation,
127-
args.validation_set_fraction)
128+
args.validation_set_fraction,
129+
args.always_shuffle)
128130

129131
# Optimizer with learning rate schedule
130132
if args.optimizer.lower() == 'adamw':
@@ -188,6 +190,10 @@ def make_batch_script(model: lbann.Model,
188190
epoch_interval=1,
189191
))
190192

193+
if args.validate_every > 0:
194+
model.callbacks.append(
195+
lbann.CallbackEvaluateProgress(batch_interval=args.validate_every))
196+
191197
# Print a progress bar
192198
if args.progress:
193199
model.callbacks.append(
@@ -237,6 +243,13 @@ def add_training_arguments(parser: argparse.ArgumentParser):
237243
action="store_true",
238244
default=False,
239245
help="Do not run validation (default: false)")
246+
parser.add_argument(
247+
"--always-shuffle",
248+
action="store_true",
249+
default=False,
250+
help=
251+
"Always shuffle training dataset, even if pretokenized (default: false)"
252+
)
240253
parser.add_argument(
241254
"--validation-set-fraction",
242255
type=float,
@@ -248,6 +261,10 @@ def add_training_arguments(parser: argparse.ArgumentParser):
248261
default=False,
249262
help="Save prototext experiment file instead of protobin (slower but "
250263
"debuggable) (default: false)")
264+
parser.add_argument("--validate-every",
265+
type=int,
266+
default=100,
267+
help="Run validation every N steps (default: 100)")
251268

252269

253270
# ----------------------------------------------

include/lbann/optimizers/optimizer_impl.hpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,33 @@ class GradientHelperImpl : public optimizer::GradientHelper
4646
El::DistData grad_dist_data,
4747
bool sharded_weights)
4848
: local_gradient_contrib_{AbsDistMatType::Instantiate(dist_data)},
49-
local_contrib_dist_{dist_data},
5049
global_gradient_{AbsDistMatType::Instantiate(grad_dist_data)},
51-
global_dist_{grad_dist_data},
5250
sharded_weights_{sharded_weights}
5351
{
5452
ensure_gradient_memory(height, width);
5553
El::Zeros(*local_gradient_contrib_, height, width);
56-
if (grad_dist_data != dist_data) {
54+
if (sharded_weights) {
5755
El::Zeros(*global_gradient_, height, width);
5856
}
5957
}
6058

6159
void ensure_gradient_memory(El::Int height, El::Int width) override
6260
{
6361
#if defined(LBANN_HAS_GPU)
64-
local_gradient_contrib_->Matrix().SetMemoryMode(1);
62+
static const char* e = std::getenv("LBANN_USE_DIRECT_FOR_CONTRIB");
63+
if (e != nullptr && e[0] == '1') {
64+
local_gradient_contrib_->Matrix().SetMemoryMode(0);
65+
}
66+
else {
67+
local_gradient_contrib_->Matrix().SetMemoryMode(1);
68+
}
6569
#endif // LBANN_HAS_GPU
6670

6771
if (local_gradient_contrib_->Width() == 0) {
6872
local_gradient_contrib_->Resize(height, width);
6973
// If distribution is the same, have global gradient matrix view the
7074
// local contributions.
71-
if (local_contrib_dist_ == global_dist_) {
75+
if (!sharded_weights_) {
7276
El::View(*global_gradient_, *local_gradient_contrib_);
7377
}
7478
}
@@ -96,6 +100,13 @@ class GradientHelperImpl : public optimizer::GradientHelper
96100

97101
void start_sync(lbann_comm& comm) override
98102
{
103+
// Complete outstanding synchronization of the same data type
104+
static GradientHelperImpl<TensorDataType>* lastsync = nullptr;
105+
if (lastsync != nullptr) {
106+
lastsync->complete_sync(comm);
107+
lastsync = nullptr;
108+
}
109+
99110
switch (this->get_status()) {
100111
case optimizer_gradient_status::sync_needed:
101112
// Sharded gradients are produced from a reduce-scatter on the local
@@ -122,6 +133,7 @@ class GradientHelperImpl : public optimizer::GradientHelper
122133
*/
123134
}
124135
this->set_status(optimizer_gradient_status::sync_started);
136+
lastsync = this;
125137
break;
126138
case optimizer_gradient_status::ready:
127139
case optimizer_gradient_status::cleared:
@@ -166,19 +178,19 @@ class GradientHelperImpl : public optimizer::GradientHelper
166178
void clear() override
167179
{
168180
this->set_status(optimizer_gradient_status::cleared);
181+
local_gradient_contrib_->Empty();
182+
global_gradient_->Empty();
169183
}
170184

171185
private:
172186
/** Matches the distribution of gathered (unsharded) weights in backprop. */
173187
std::unique_ptr<AbsDistMatType> local_gradient_contrib_;
174-
El::DistData local_contrib_dist_;
175188

176189
/** Matches the distribution of data_type_optimizer<T>::m_gradient (i.e.,
177190
* post synchronization). Will view said matrix if only one data type
178191
* exists.
179192
*/
180193
std::unique_ptr<AbsDistMatType> global_gradient_;
181-
El::DistData global_dist_;
182194

183195
Al::request sync_req_;
184196
bool sharded_weights_;
@@ -218,6 +230,8 @@ optimizer::get_gradient_buffer(TensorDataType& buf_scale,
218230
// If the manager hasn't been created, let's make it.
219231
auto mat_info = this->get_matrix_info();
220232
if (!grad_mgr_ptr) {
233+
// If our optimizer contains a gradient of the same data type, reuse (view)
234+
// it in the gradient manager
221235
grad_mgr_ptr = std::make_unique<GradMgrType>(std::get<HEIGHT>(mat_info),
222236
std::get<WIDTH>(mat_info),
223237
std::get<DISTDATA_L>(mat_info),
@@ -319,13 +333,13 @@ void optimizer::accumulate_all_gradient_contributions(
319333
// Handle the case that only 1 update of a different type is needed.
320334
if (num_updates == 1UL &&
321335
this->m_local_gradient_contributions.size() == 1UL) {
322-
auto const& grad_mgr =
323-
*(this->m_local_gradient_contributions.begin()->second);
336+
auto& grad_mgr = *(this->m_local_gradient_contributions.begin()->second);
324337
if (grad_mgr.get_status() != optimizer_gradient_status::ready) {
325338
LBANN_ERROR("Expected ready status. Got: ",
326339
to_string(grad_mgr.get_status()));
327340
}
328341
El::Copy(grad_mgr.global_gradient(), gradient);
342+
grad_mgr.clear();
329343
}
330344
else if (this->m_local_gradient_contributions.size() > 1UL) {
331345
// Need a temporary matrix for the type-casted copy.
@@ -335,14 +349,15 @@ void optimizer::accumulate_all_gradient_contributions(
335349
for (auto const& grad_mgr_v : this->m_local_gradient_contributions) {
336350
if (grad_mgr_v.first == this_type_idx)
337351
continue;
338-
auto const& grad_mgr = *(grad_mgr_v.second);
352+
auto& grad_mgr = *(grad_mgr_v.second);
339353
if (grad_mgr.get_status() != optimizer_gradient_status::ready) {
340354
LBANN_ERROR("Expected ready status. Got: ",
341355
to_string(grad_mgr.get_status()));
342356
}
343357
auto const& grad_base = grad_mgr.global_gradient();
344358
El::Copy(grad_base, *tmp);
345359
El::Axpy(one, *tmp, gradient);
360+
grad_mgr.clear();
346361
}
347362
}
348363
}

python/lbann/models/transformer.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,23 @@ def __init__(self, normalized_shape, name=None, builtin=True):
2929
self.name = (name if name else f'layernorm{LayerNorm.global_count}')
3030
self.builtin = builtin
3131

32-
if not self.builtin:
33-
# Initialize weights
34-
self.weight = lbann.Weights(
35-
initializer=lbann.ConstantInitializer(value=1),
36-
name=f'{self.name}_weight',
37-
)
38-
self.bias = lbann.Weights(
39-
initializer=lbann.ConstantInitializer(value=0),
40-
name=f'{self.name}_bias',
41-
)
32+
# Initialize weights
33+
self.weight = lbann.Weights(
34+
initializer=lbann.ConstantInitializer(value=1),
35+
name=f'{self.name}_weight',
36+
)
37+
self.bias = lbann.Weights(
38+
initializer=lbann.ConstantInitializer(value=0),
39+
name=f'{self.name}_bias',
40+
)
4241

4342
def forward(self, x):
4443
if self.builtin:
45-
return lbann.LayerNorm(x, scale=True, bias=True, name=self.name)
44+
return lbann.LayerNorm(x,
45+
scale=True,
46+
bias=True,
47+
name=self.name,
48+
weights=[self.weight, self.bias])
4649

4750
# Normalization
4851
x = lbann.InstanceNorm(x)

src/layers/data_type_layer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,11 @@ void data_type_layer<InputTensorDataType,
329329
}
330330
#endif // defined(LBANN_HAS_GPU) && defined(LBANN_DEBUG)
331331

332+
// Release the now-unnecessary full weight views
333+
for (size_t i = 0; i < this->num_weights(); ++i) {
334+
this->get_weights(i).release_full_weights();
335+
}
336+
332337
// Release activation memory as necessary
333338
model* m = this->get_model();
334339
if (m != nullptr) {

src/weights/data_type_weights.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,17 @@ void data_type_weights<TensorDataType>::do_setup_()
239239
}
240240

241241
// Construct matrix for weight values
242-
// If sharded, use STAR_VC distribution (column distributed)
242+
// If sharded, use STAR_VC distribution (column distributed) or VC_STAR (row
243+
// distributed) if width=1.
243244
auto matrix_dist = this->get_matrix_distribution();
245+
bool must_use_vc_star = (this->get_matrix_width() == 1);
244246
m_values.reset(AbsDistMatrixType::Instantiate(
245247
*matrix_dist.grid,
246248
matrix_dist.root,
247-
this->is_sharded() ? El::STAR : matrix_dist.colDist,
248-
this->is_sharded() ? El::VC : matrix_dist.rowDist,
249+
this->is_sharded() ? (must_use_vc_star ? El::VC : El::STAR)
250+
: matrix_dist.colDist,
251+
this->is_sharded() ? (must_use_vc_star ? El::STAR : El::VC)
252+
: matrix_dist.rowDist,
249253
(matrix_dist.blockHeight == 1 && matrix_dist.blockWidth == 1 ? El::ELEMENT
250254
: El::BLOCK),
251255
matrix_dist.device));

0 commit comments

Comments
 (0)