11from __future__ import annotations
22
33import contextlib
4+ import functools
45import itertools
56from abc import ABC , abstractmethod
67from collections import deque
7- from typing import Any , Callable , Generator , Sequence
8+ from typing import Any , Callable , Generator , Iterable , Sequence
89
910import numpy .typing as npt
1011
@@ -150,6 +151,7 @@ def __sub__(self, other: int | float | Expr) -> Expr:
150151 def _backward (self , grad : npt .NDArray ) -> None :
151152 del grad
152153
154+ @functools .lru_cache (maxsize = 1 ) # Cache for when loss.params is called multiple times
153155 def _get_nodes (self ) -> deque [Expr ]:
154156 retval : deque [Expr ] = deque ()
155157 if not self ._requires_grad :
@@ -175,6 +177,10 @@ def visit(node: Expr) -> None:
175177 visit (self )
176178 return retval
177179
180+ @property
181+ def params (self ) -> list [Expr ]:
182+ return [node for node in self ._get_nodes () if len (node ._children ) == 0 ]
183+
178184 def backward (
179185 self ,
180186 init : np .ndarray | float = 1.0 ,
@@ -214,7 +220,7 @@ def backward(
214220 if not retain_grad :
215221 node ._grad = None
216222 if opt is not None :
217- opt .step ()
223+ opt .update_state ()
218224
219225 def exp (self ) -> Expr :
220226 """Return the element-wise exponential."""
@@ -356,6 +362,9 @@ def var(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -
356362 retval = _Squeeze (retval , dim = dim )
357363 return retval
358364
365+ def zero_grad (self ) -> None :
366+ self ._grad = None
367+
359368 @property
360369 def dtype (self ) -> npt .DTypeLike :
361370 """Data type."""
@@ -390,13 +399,23 @@ def shape(self) -> tuple[int, ...]:
390399 return self ._value .shape
391400
392401
402+ def zero_grads (params : Iterable [Expr ]) -> None :
403+ for param in params :
404+ param .zero_grad ()
405+
406+
393407class Opt (ABC ):
408+ def step (self , params : Iterable [Expr ]) -> None :
409+ for param in params :
410+ self .update_param (param )
411+ self .update_state ()
412+
394413 @abstractmethod
395414 def update_param (self , param : Expr ) -> None :
396415 pass
397416
398417 @abstractmethod
399- def step (self ) -> None :
418+ def update_state (self ) -> None :
400419 pass
401420
402421
0 commit comments