
Commit 02b7f79

Gradient Norm Clipping (#2344)
* Implement gradient norm clipping callback
* Default to loss/metric when not using gradient checking

Co-authored-by: Pier Fiedorowicz <[email protected]>
Co-authored-by: Pier Fiedorowicz <[email protected]>
1 parent e692ec3 commit 02b7f79

File tree

12 files changed (+364, -6 lines)

applications/nlp/transformer/pretrain_gpt.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def main():
         beta1=0.9,
         beta2=0.95,
         eps=1e-8,
-        clip_gradient=0.0,
+        clip_gradient=1.0,
         lr_decay='cosine',
         lr_decay_steps=int((260 * 1e9) // tokens_per_step),
         end_learning_rate=chosen_config.lr / 10,

applications/nlp/transformer/trainer.py

Lines changed: 0 additions & 1 deletion
@@ -148,7 +148,6 @@ def make_batch_script(model: lbann.Model,
         ))

     if clip_gradient > 0:
-        raise NotImplementedError('Gradient norm clipping not yet implemented')
         model.callbacks.append(
             lbann.CallbackClipGradientNorm(global_norm=True,
                                            value=clip_gradient))
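With the NotImplementedError removed, any clip_gradient > 0 now attaches the new callback, and pretrain_gpt.py enables it by default with a maximum global norm of 1.0. As a minimal sketch of the technique (plain NumPy, not LBANN's implementation), global-norm clipping rescales every gradient by min(1, clip / total_norm):

# Minimal NumPy sketch of global gradient norm clipping, the mode the trainer
# enables above (CallbackClipGradientNorm with global_norm=True).
# Illustrative only; this is not LBANN's implementation.
import numpy as np

def clip_global_grad_norm(grads, clip=1.0):
    """Scale all gradients together so their combined L2 norm is at most `clip`."""
    total_norm = np.sqrt(sum(np.square(g).sum() for g in grads))
    if total_norm > clip:
        grads = [g * (clip / total_norm) for g in grads]
    return grads

# Example: two gradient tensors whose combined norm exceeds 1.0 get scaled down.
grads = [np.full((4, 4), 0.5), np.full((8,), -0.25)]
clipped = clip_global_grad_norm(grads, clip=1.0)
print(np.sqrt(sum(np.square(g).sum() for g in clipped)))  # ~1.0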

ci_test/common_python/test_util.py

Lines changed: 3 additions & 3 deletions
@@ -80,13 +80,13 @@ def wrapped(*args, **kwargs):
                 error_on_failure=True,
                 execution_modes='train' if train else 'test'))

-        obj_func = None
+        check_grad_obj_func = None
         if check_gradients:
             if tester.check_gradients_tensor is None:
                 raise ValueError(
                     'LBANN test did not set a tensor for checking gradients, '
                     'use ``ModelTester.set_check_gradients_tensor``.')
-            obj_func = tester.check_gradients_tensor
+            check_grad_obj_func = tester.check_gradients_tensor
             callbacks.append(
                 lbann.CallbackCheckGradients(error_on_failure=True))
         callbacks.extend(tester.extra_callbacks)
@@ -95,7 +95,7 @@ def wrapped(*args, **kwargs):
         metrics.extend(tester.extra_metrics)
         model = lbann.Model(epochs=1 if train else 0,
                             layers=full_graph,
-                            objective_function=obj_func,
+                            objective_function=check_grad_obj_func if check_gradients else tester.loss,
                             metrics=metrics,
                             callbacks=callbacks)


New file (unit test for gradient norm clipping; path not shown)

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+import lbann
+import numpy as np
+import test_util
+from glob import glob
+import functools
+import os
+
+
+def check_gradients(global_norm=True, clip=1.0):
+
+    def decorator(f):
+
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            # Clear any gradient outputs from previous runs.
+            grad_files = glob(
+                os.path.join(test_util._get_work_dir(__file__),
+                             'gradients*.txt'))
+            for gf in grad_files:
+                os.remove(gf)
+
+            # Run the model.
+            f(*args, **kwargs)
+
+            eps = np.finfo(np.float32).eps
+            grad_files = glob(
+                os.path.join(test_util._get_work_dir(__file__),
+                             'gradients*.txt'))
+
+            # Compute the weight gradient norms, check they are less than
+            # "clip", and update global gradient norm.
+            norm = 0
+            for gf in grad_files:
+                weight_norm = np.square(np.loadtxt(gf)).sum()
+                assert np.sqrt(weight_norm) <= clip + 8 * eps
+                norm += weight_norm
+
+            # Check the global gradient norm is less than "clip" if requested.
+            if global_norm:
+                assert np.sqrt(norm) <= clip + 8 * eps
+
+        return wrapper
+
+    return decorator
+
+
+def setup_tester(scale, global_norm, clip):
+    np.random.seed(20231018)
+    x = np.random.normal(scale=scale, size=[8, 16]).astype(np.float32)
+    ref = np.zeros_like(x)
+
+    tester = test_util.ModelTester()
+    x = tester.inputs(x)
+    ref = tester.make_reference(ref)
+
+    y = lbann.FullyConnected(x, num_neurons=16, has_bias=True)
+
+    z = lbann.FullyConnected(y, num_neurons=16, has_bias=True)
+
+    tester.set_loss(lbann.MeanSquaredError(z, ref), tolerance=10 * scale**2)
+    tester.extra_callbacks = [
+        lbann.CallbackClipGradientNorm(global_norm=global_norm, value=clip),
+        lbann.CallbackDumpGradients(basename='gradients')
+    ]
+    return tester
+
+
+# Case where no clipping is needed.
+@check_gradients(global_norm=True)
+@test_util.lbann_test(train=True)
+def test_gradient_no_clipping():
+    return setup_tester(scale=0.1, global_norm=True, clip=1.0)
+
+
+# Case with global clipping.
+@check_gradients(global_norm=True)
+@test_util.lbann_test(train=True)
+def test_gradient_clipping():
+    return setup_tester(scale=1, global_norm=True, clip=1.0)
+
+
+# Case with global clipping and another clip value.
+@check_gradients(global_norm=True, clip=0.3)
+@test_util.lbann_test(train=True)
+def test_gradient_clipping_diffclip():
+    return setup_tester(scale=1, global_norm=True, clip=0.3)
+
+
+# Case with per-weight clipping only.
+@check_gradients(global_norm=False)
+@test_util.lbann_test(train=True)
+def test_gradient_clipping_local():
+    return setup_tester(scale=10, global_norm=False, clip=1.0)
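These tests validate clipping by dumping each weight's gradient with CallbackDumpGradients and re-checking the norms inside the check_gradients decorator. The standalone NumPy sketch below mirrors the per-weight mode (global_norm=False) exercised by test_gradient_clipping_local and applies the same bound the decorator asserts, norm <= clip plus a small 8*eps float32 tolerance; it is illustrative and not part of the test suite.

# Hedged, standalone sketch of the property the decorator above checks for the
# per-weight mode (global_norm=False): after clipping, every gradient tensor's
# L2 norm is at most `clip`, within an 8*eps float32 tolerance.
import numpy as np

def clip_per_weight(grads, clip=1.0):
    """Clip each gradient tensor's L2 norm independently."""
    out = []
    for g in grads:
        norm = np.sqrt(np.square(g).sum())
        out.append(g * (clip / norm) if norm > clip else g)
    return out

rng = np.random.default_rng(20231018)
grads = [rng.normal(scale=10, size=(16, 16)).astype(np.float32) for _ in range(4)]
eps = np.finfo(np.float32).eps
for g in clip_per_weight(grads, clip=1.0):
    assert np.sqrt(np.square(g).sum()) <= 1.0 + 8 * eps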

include/lbann/callbacks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ set_full_path(THIS_DIR_HEADERS
   dump_weights.hpp
   early_stopping.hpp
   gpu_memory_usage.hpp
+  gradient_clipping.hpp
   hang.hpp
   learning_rate.hpp
   ltfb.hpp
include/lbann/callbacks/gradient_clipping.hpp (new file)

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
+// Produced at the Lawrence Livermore National Laboratory.
+// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+// the CONTRIBUTORS file. <[email protected]>
+//
+// LLNL-CODE-697807.
+// All rights reserved.
+//
+// This file is part of LBANN: Livermore Big Artificial Neural Network
+// Toolkit. For details, see http://software.llnl.gov/LBANN or
+// https://github.com/LLNL/LBANN.
+//
+// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
+// may not use this file except in compliance with the License. You may
+// obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the license.
+//
+// gradient_clipping .hpp .cpp - Callbacks to clip gradient values in training
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef LBANN_CALLBACKS_CALLBACK_GRADIENT_CLIPPING_HPP_INCLUDED
+#define LBANN_CALLBACKS_CALLBACK_GRADIENT_CLIPPING_HPP_INCLUDED
+
+#include <unordered_set>
+#include <utility>
+
+#include "lbann/callbacks/callback.hpp"
+
+namespace lbann {
+namespace callback {
+
+/**
+ * @brief Clip gradients whose norm is larger than a user-defined value by
+ * dividing them.
+ */
+class clip_gradient_norm : public callback_base
+{
+public:
+  using callback_base::on_backward_prop_end;
+
+  /**
+   * @param weights Parameters whose gradient to clip, or empty for all
+   * @param global_norm Whether to clip according to the norm of all parameters
+   * or each one separately
+   * @param value Value to clip to
+   */
+  clip_gradient_norm(std::vector<std::string> weights,
+                     bool global_norm = false,
+                     float value = 1.0f)
+    : callback_base(1),
+      m_weight_names(std::move(weights)),
+      m_global_norm(global_norm),
+      m_value(value)
+  {}
+  clip_gradient_norm(const clip_gradient_norm&) = default;
+  clip_gradient_norm& operator=(const clip_gradient_norm&) = default;
+  void setup(model* m) override;
+  clip_gradient_norm* copy() const override
+  {
+    return new clip_gradient_norm(*this);
+  }
+  void on_backward_prop_end(model* m) override;
+  std::string name() const override { return "clip gradient norm"; }
+
+  /** @name Serialization */
+  ///@{
+
+  /** @brief Store state to archive for checkpoint and restart */
+  template <class Archive>
+  void serialize(Archive& ar);
+
+  ///@}
+
+private:
+  /** Add callback specific data to prototext */
+  void write_specific_proto(lbann_data::Callback& proto) const final;
+
+  friend class cereal::access;
+  clip_gradient_norm();
+
+  /** @brief Parameter names whose gradients to clip. */
+  std::vector<std::string> m_weight_names;
+
+  /** @brief Whether to clip according to the norm of all parameters. */
+  bool m_global_norm;
+
+  /** @brief Value to clip to. */
+  float m_value;
+
+  /** Weights to update. */
+  std::unordered_set<weights*> m_weights;
+};
+
+// Builder function
+std::unique_ptr<callback_base> build_clip_gradient_norm_callback_from_pbuf(
+  const google::protobuf::Message&,
+  std::shared_ptr<lbann_summary> const&);
+
+} // namespace callback
+} // namespace lbann
+
+#endif // LBANN_CALLBACKS_CALLBACK_GRADIENT_CLIPPING_HPP_INCLUDED
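The companion gradient_clipping.cpp (not shown here) implements setup() and on_backward_prop_end() for this class. As a rough, hedged illustration of the behavior the header documents, the NumPy sketch below applies the same semantics to a dict of gradient arrays: an empty weights list selects every parameter, and clipping is done either against the combined norm of the selection or per tensor. The function and argument names are illustrative, not LBANN's API.

# Hedged, NumPy-only model of the documented behavior; the real logic lives
# in gradient_clipping.cpp.
import numpy as np

def clip_gradient_norm(grads, weight_names=(), global_norm=False, value=1.0):
    """grads: dict mapping weight name -> gradient array; returns clipped copies."""
    selected = list(weight_names) or list(grads)  # empty selection means "all weights"
    if global_norm:
        total = np.sqrt(sum(np.square(grads[n]).sum() for n in selected))
        scale = value / total if total > value else 1.0
        return {n: g * scale if n in selected else g for n, g in grads.items()}
    out = {}
    for n, g in grads.items():
        norm = np.sqrt(np.square(g).sum())
        out[n] = g * (value / norm) if (n in selected and norm > value) else g
    return out

# Example: clip only 'fc1' against its own norm.
grads = {'fc1': np.ones((4, 4)), 'fc2': np.full((8,), 2.0)}
clipped = clip_gradient_norm(grads, weight_names=['fc1'], global_norm=False, value=1.0)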

include/lbann/lbann.hpp

Lines changed: 1 addition & 0 deletions
@@ -176,6 +176,7 @@
 #include "lbann/callbacks/dump_weights.hpp"
 #include "lbann/callbacks/early_stopping.hpp"
 #include "lbann/callbacks/gpu_memory_usage.hpp"
+#include "lbann/callbacks/gradient_clipping.hpp"
 #include "lbann/callbacks/hang.hpp"
 #include "lbann/callbacks/learning_rate.hpp"
 #include "lbann/callbacks/load_model.hpp"

src/base.cpp

Lines changed: 1 addition & 0 deletions
@@ -651,6 +651,7 @@ CEREAL_FORCE_DYNAMIC_INIT(callback_dump_outputs);
 CEREAL_FORCE_DYNAMIC_INIT(callback_dump_weights);
 CEREAL_FORCE_DYNAMIC_INIT(callback_early_stopping);
 CEREAL_FORCE_DYNAMIC_INIT(callback_gpu_memory_usage);
+CEREAL_FORCE_DYNAMIC_INIT(callback_clip_gradient_norm);
 CEREAL_FORCE_DYNAMIC_INIT(callback_hang);
 CEREAL_FORCE_DYNAMIC_INIT(callback_load_model);
 CEREAL_FORCE_DYNAMIC_INIT(callback_mixup);

src/callbacks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ set_full_path(THIS_DIR_SOURCES
   dump_weights.cpp
   early_stopping.cpp
   gpu_memory_usage.cpp
+  gradient_clipping.cpp
   hang.cpp
   learning_rate.cpp
   load_model.cpp
