Skip to content

Commit d52ba30

Browse files
committed
Add zero_grads and change the meaning of step
1 parent 94ba6a5 commit d52ba30

4 files changed

Lines changed: 41 additions & 13 deletions

File tree

src/micrograd_pp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu
1+
from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu, zero_grads
22
from ._func import cat, cross_entropy_loss, softmax
33
from ._nn import (
44
BatchNorm1d,
@@ -44,4 +44,5 @@
4444
"no_grad",
4545
"relu",
4646
"softmax",
47+
"zero_grads",
4748
)

src/micrograd_pp/_expr.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import contextlib
4+
import functools
45
import itertools
56
from abc import ABC, abstractmethod
67
from collections import deque
7-
from typing import Any, Callable, Generator, Sequence
8+
from typing import Any, Callable, Generator, Iterable, Sequence
89

910
import numpy.typing as npt
1011

@@ -150,6 +151,7 @@ def __sub__(self, other: int | float | Expr) -> Expr:
150151
def _backward(self, grad: npt.NDArray) -> None:
151152
del grad
152153

154+
@functools.lru_cache(maxsize=1) # Cache for when loss.params is called multiple times
153155
def _get_nodes(self) -> deque[Expr]:
154156
retval: deque[Expr] = deque()
155157
if not self._requires_grad:
@@ -175,6 +177,10 @@ def visit(node: Expr) -> None:
175177
visit(self)
176178
return retval
177179

180+
@property
181+
def params(self) -> list[Expr]:
182+
return [node for node in self._get_nodes() if len(node._children) == 0]
183+
178184
def backward(
179185
self,
180186
init: np.ndarray | float = 1.0,
@@ -214,7 +220,7 @@ def backward(
214220
if not retain_grad:
215221
node._grad = None
216222
if opt is not None:
217-
opt.step()
223+
opt.update_state()
218224

219225
def exp(self) -> Expr:
220226
"""Return the element-wise exponential."""
@@ -356,6 +362,9 @@ def var(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -
356362
retval = _Squeeze(retval, dim=dim)
357363
return retval
358364

365+
def zero_grad(self) -> None:
366+
self._grad = None
367+
359368
@property
360369
def dtype(self) -> npt.DTypeLike:
361370
"""Data type."""
@@ -390,13 +399,23 @@ def shape(self) -> tuple[int, ...]:
390399
return self._value.shape
391400

392401

402+
def zero_grads(params: Iterable[Expr]) -> None:
403+
for param in params:
404+
param.zero_grad()
405+
406+
393407
class Opt(ABC):
408+
def step(self, params: Iterable[Expr]) -> None:
409+
for param in params:
410+
self.update_param(param)
411+
self.update_state()
412+
394413
@abstractmethod
395414
def update_param(self, param: Expr) -> None:
396415
pass
397416

398417
@abstractmethod
399-
def step(self) -> None:
418+
def update_state(self) -> None:
400419
pass
401420

402421

src/micrograd_pp/_opt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, lr: float) -> None:
1717
def update_param(self, param: Expr) -> None:
1818
param.update_value(-self._lr * param.grad)
1919

20-
def step(self) -> None:
20+
def update_state(self) -> None:
2121
pass
2222

2323

@@ -79,7 +79,7 @@ def __init__(
7979
self._moments: dict[Expr, tuple[np.ndarray, np.ndarray]] = {}
8080

8181
self._t = 0
82-
self.step()
82+
self.update_state()
8383

8484
def update_param(self, param: Expr) -> None:
8585
if param not in self._moments:
@@ -95,7 +95,7 @@ def update_param(self, param: Expr) -> None:
9595
update = -self._lr * (corrected_moment_1 / denom + self._weight_decay * param.value)
9696
param.update_value(update)
9797

98-
def step(self):
98+
def update_state(self):
9999
self._t += 1
100100
self._bias_correction_1 = 1.0 - self._beta_1**self._t
101101
self._bias_correction_2 = 1.0 - self._beta_2**self._t

tests/test_opt.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@ def run_before_and_after_tests():
1212

1313

1414
@pytest.mark.parametrize(
15-
("opt_factory", "num_steps", "atol"),
15+
("opt_factory", "num_steps", "atol", "pass_opt_to_backward"),
1616
[
17-
(lambda: mpp.SGD(lr=0.1), 150, 1e-8),
18-
(lambda: mpp.AdamW(lr=0.2, weight_decay=0.0), 600, 1e-8),
17+
(*cfg, pass_opt_to_backward)
18+
for cfg in (
19+
(lambda: mpp.SGD(lr=0.1), 150, 1e-8),
20+
(lambda: mpp.AdamW(lr=0.2, weight_decay=0.0), 600, 1e-8),
21+
)
22+
for pass_opt_to_backward in (False, True)
1923
],
20-
ids=("sgd", "adamw"),
2124
)
22-
def test_mse(opt_factory, num_steps: int, atol: float):
25+
def test_mse(opt_factory, num_steps: int, atol: float, pass_opt_to_backward: bool):
2326
n = 10
2427
coef = np.random.randn(3, 1)
2528
coef_hat = np.random.randn(3, 1)
@@ -35,6 +38,11 @@ def test_mse(opt_factory, num_steps: int, atol: float):
3538
for _ in range(num_steps):
3639
y_pred_ = x_ @ coef_hat_
3740
mse = ((y_pred_ - y_) ** 2).sum() / n
38-
mse.backward(opt=opt)
41+
if pass_opt_to_backward:
42+
mse.backward(opt=opt) # Automatically handles zeroing gradients and updating the optimizer state
43+
else:
44+
mpp.zero_grads(mse.params)
45+
mse.backward()
46+
opt.step(mse.params)
3947

4048
np.testing.assert_allclose(coef, coef_hat, rtol=0.0, atol=atol)

0 commit comments

Comments
 (0)