Skip to content

Commit dcd9bd5

Browse files
committed
Add gradient clipping
1 parent d52ba30 commit dcd9bd5

6 files changed

Lines changed: 160 additions & 18 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ sys.path.insert(0, os.path.expanduser("~/micrograd-pp/python"))
5252
* ☒ Stochastic Gradient Descent (SGD)
5353
* **Training**
5454
* ☐ Exponential moving average (EMA) of model weights
55-
* Gradient clipping
55+
* Gradient clipping
5656
* ☐ Learning rate schedules

src/micrograd_pp/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu, zero_grads
22
from ._func import cat, cross_entropy_loss, softmax
3+
from ._clip import clip_grad_norm_, clip_grad_value_
34
from ._nn import (
45
BatchNorm1d,
56
Dropout,
@@ -34,6 +35,8 @@
3435
"Sequential",
3536
"SGD",
3637
"cat",
38+
"clip_grad_norm_",
39+
"clip_grad_value_",
3740
"cross_entropy_loss",
3841
"datasets",
3942
"eval",

src/micrograd_pp/_clip.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import math
2+
from collections.abc import Iterable
3+
4+
import numpy.typing as npt
5+
6+
from ._expr import Expr
7+
from ._numpy import numpy as np
8+
9+
10+
def _get_grads(params: Iterable[Expr]) -> Iterable[npt.NDArray]:
11+
return (param.grad for param in params if param.requires_grad)
12+
13+
14+
def clip_grad_value_(params: Iterable[Expr], clip_value: float) -> None:
15+
"""Clip gradient values in-place.
16+
17+
Parameters
18+
----------
19+
params
20+
Parameters whose gradients should be clipped
21+
clip_value
22+
Maximum absolute gradient value
23+
"""
24+
if clip_value < 0.0:
25+
msg = "clip_value must be non-negative"
26+
raise ValueError(msg)
27+
for grad in _get_grads(params):
28+
np.clip(grad, -clip_value, clip_value, out=grad)
29+
30+
31+
def clip_grad_norm_(
32+
params: Iterable[Expr],
33+
max_norm: float,
34+
norm_type: float = 2.0,
35+
error_if_nonfinite: bool = False,
36+
eps: float = 1e-6,
37+
) -> float:
38+
"""Clip gradient norm in-place.
39+
40+
Parameters
41+
----------
42+
params
43+
Parameters whose gradients should be clipped
44+
max_norm
45+
Maximum allowed norm
46+
norm_type
47+
Type of p-norm to use. Supports ``math.inf`` for infinity norm.
48+
error_if_nonfinite
49+
If True, raises if the total norm is NaN or infinite
50+
eps
51+
Numerical stability term added to denominator
52+
"""
53+
if max_norm < 0.0:
54+
msg = "max_norm must be non-negative"
55+
raise ValueError(msg)
56+
if eps <= 0.0:
57+
msg = "eps must be positive"
58+
raise ValueError(msg)
59+
if norm_type <= 0.0:
60+
msg = "norm_type must be positive"
61+
raise ValueError(msg)
62+
63+
grads = list(_get_grads(params))
64+
if len(grads) == 0:
65+
return 0.0
66+
67+
if math.isinf(norm_type):
68+
total_norm = max(float(np.abs(grad).max()) for grad in grads)
69+
else:
70+
total_norm = 0.0
71+
for grad in grads:
72+
total_norm += float((np.abs(grad) ** norm_type).sum())
73+
total_norm = total_norm ** (1.0 / norm_type)
74+
75+
if error_if_nonfinite and not np.isfinite(total_norm):
76+
msg = f"The total norm of gradients is non-finite: {total_norm}"
77+
raise RuntimeError(msg)
78+
79+
clip_coef = max_norm / (total_norm + eps)
80+
if clip_coef < 1.0:
81+
for grad in grads:
82+
grad *= clip_coef
83+
84+
return total_norm

tests/test_clip.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pytest
2+
3+
import micrograd_pp as mpp
4+
5+
np = mpp.numpy
6+
7+
8+
@pytest.fixture(autouse=True)
9+
def run_before_and_after_tests():
10+
np.random.seed(0)
11+
yield
12+
13+
14+
def _set_grad(param: mpp.Expr, grad: np.ndarray) -> None:
15+
param.zero_grad()
16+
param.update_grad(lambda: grad)
17+
18+
19+
def test_clip_grad_value_clamps_each_element() -> None:
20+
param = mpp.Parameter(np.array([0.0, 0.0, 0.0]))
21+
_set_grad(param, np.array([-2.0, 0.25, 3.0]))
22+
23+
mpp.clip_grad_value_([param], clip_value=0.5)
24+
25+
np.testing.assert_allclose(param.grad, np.array([-0.5, 0.25, 0.5]))
26+
27+
28+
def test_clip_grad_norm_scales_all_grads_by_common_factor() -> None:
29+
p1 = mpp.Parameter(np.zeros((2,)))
30+
p2 = mpp.Parameter(np.zeros((1,)))
31+
_set_grad(p1, np.array([3.0, 4.0]))
32+
_set_grad(p2, np.array([12.0]))
33+
34+
total_norm = mpp.clip_grad_norm_([p1, p2], max_norm=6.5, norm_type=2.0)
35+
scale = 6.5 / (13.0 + 1e-6)
36+
37+
np.testing.assert_allclose(total_norm, 13.0)
38+
np.testing.assert_allclose(p1.grad, np.array([3.0, 4.0]) * scale, atol=1e-12, rtol=0.0)
39+
np.testing.assert_allclose(p2.grad, np.array([12.0]) * scale, atol=1e-12, rtol=0.0)
40+
41+
42+
def test_clip_grad_norm_noop_when_within_threshold() -> None:
43+
p1 = mpp.Parameter(np.zeros((2,)))
44+
p2 = mpp.Parameter(np.zeros((1,)))
45+
_set_grad(p1, np.array([3.0, 4.0]))
46+
_set_grad(p2, np.array([12.0]))
47+
48+
total_norm = mpp.clip_grad_norm_([p1, p2], max_norm=13.1, norm_type=2.0)
49+
50+
np.testing.assert_allclose(total_norm, 13.0)
51+
np.testing.assert_allclose(p1.grad, np.array([3.0, 4.0]), atol=1e-12, rtol=0.0)
52+
np.testing.assert_allclose(p2.grad, np.array([12.0]), atol=1e-12, rtol=0.0)
53+
54+
55+
def test_clip_grad_norm_errors_on_nonfinite_if_requested() -> None:
56+
p = mpp.Parameter(np.zeros((1,)))
57+
_set_grad(p, np.array([np.inf]))
58+
59+
with pytest.raises(RuntimeError):
60+
mpp.clip_grad_norm_([p], max_norm=1.0, error_if_nonfinite=True)

tests/test_mnist.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ def test_mnist(batch_sz: int = 64, n_epochs: int = 3):
5959
x = mpp.Constant(train_images[batch_index])
6060
y = train_labels[batch_index]
6161
loss = cross_entropy_loss(model(x), y)
62-
loss.backward(opt=opt)
62+
params = loss.params
63+
mpp.zero_grads(params)
64+
loss.backward()
65+
mpp.clip_grad_norm_(params, max_norm=5.0)
66+
opt.step(params)
6367
test_x = mpp.Constant(test_images)
6468
with mpp.eval(), mpp.no_grad():
6569
test_fx = model(test_x)

tests/test_opt.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,13 @@ def run_before_and_after_tests():
1212

1313

1414
@pytest.mark.parametrize(
15-
("opt_factory", "num_steps", "atol", "pass_opt_to_backward"),
16-
[
17-
(*cfg, pass_opt_to_backward)
18-
for cfg in (
19-
(lambda: mpp.SGD(lr=0.1), 150, 1e-8),
20-
(lambda: mpp.AdamW(lr=0.2, weight_decay=0.0), 600, 1e-8),
21-
)
22-
for pass_opt_to_backward in (False, True)
23-
],
15+
("opt_factory", "num_steps", "atol"),
16+
(
17+
(lambda: mpp.SGD(lr=0.1), 150, 1e-8),
18+
(lambda: mpp.AdamW(lr=0.2, weight_decay=0.0), 600, 1e-8),
19+
),
2420
)
25-
def test_mse(opt_factory, num_steps: int, atol: float, pass_opt_to_backward: bool):
21+
def test_mse(opt_factory, num_steps: int, atol: float):
2622
n = 10
2723
coef = np.random.randn(3, 1)
2824
coef_hat = np.random.randn(3, 1)
@@ -38,11 +34,6 @@ def test_mse(opt_factory, num_steps: int, atol: float, pass_opt_to_backward: boo
3834
for _ in range(num_steps):
3935
y_pred_ = x_ @ coef_hat_
4036
mse = ((y_pred_ - y_) ** 2).sum() / n
41-
if pass_opt_to_backward:
42-
mse.backward(opt=opt) # Automatically handles zeroing gradients and updating the optimizer state
43-
else:
44-
mpp.zero_grads(mse.params)
45-
mse.backward()
46-
opt.step(mse.params)
37+
mse.backward(opt=opt) # Automatically handles zeroing gradients and updating the optimizer state
4738

4839
np.testing.assert_allclose(coef, coef_hat, rtol=0.0, atol=atol)

0 commit comments

Comments
 (0)