Add NNCG to optimizers submodule #1661
Merged
Commits (29, all by pratikrathore8):

d0e22bb  Add NNCG to optimizers submodule
3afed9f  Update nys_newton_cg.py
081d5f6  Moved NNCG to pytorch folder
03a77a1  Minor formatting changes in NNCG
88d2f7e  Update nys_newton_cg.py
fff6a91  Fix Codacy issues
19490ea  Fix more Codacy issues
ec59a99  Added NNCG to config.py and optimizers.py
8995aad  Clean up NNCG integration in optimizers.py
1b13a08  Fixed import order in optimizers.py
2d63ba3  Made demo with NNCG and Burgers equation
c356c90  Merge branch 'lululxvi:master' into master
8d0210a  refactor nncg integration
5c9beba  clean up commented code
ce1dcc8  format with black
d97ca16  remove unnecessary error checks
81c4452  fix some codacy issues in nncg
60d3ff3  further improvements to nncg integration
6e739e1  add train_step_nncg
f4e322e  improve documentation in nncg config
6338f59  added doc for nncg demo
838eaa8  added demo file to pinn_forward.rst
e0bb44d  Merge branch 'master' into master
61f08f9  Merge branch 'lululxvi:master' into master
7d58317  change ordering in pinn_forward.rst
5467524  merge burgers_nncg demo into burgers demo
f2d5c31  more cleanup in demo
74ace2c  comment out nncg code block
c1a6365  minor cleanup in burgers example
New file: nys_newton_cg.py (312 added lines)
import torch
from torch.optim import Optimizer
from torch.func import vmap
from functools import reduce


def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5):
    """Line search to find a step size that satisfies the Armijo condition."""
    f0 = f(x, 0, dx)
    f1 = f(x, t, dx)
    while f1 > f0 + alpha * t * gx.dot(dx):
        t *= beta
        f1 = f(x, t, dx)
    return t


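# _apply_nys_precond_inv below applies the inverse of the shifted Nystrom
# preconditioner: given the approximate top-r eigenpairs (U, S) produced by
# NysNewtonCG.update_preconditioner, it returns
#     (lambd_r + mu) * U @ ((S + mu)^{-1} * (U.T @ x)) + (x - U @ (U.T @ x)),
# i.e. the inverse of the Nystrom preconditioner of Frangella et al. (2023)
# applied to the vector x.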
def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x):
    """Applies the inverse of the Nystrom approximation of the Hessian to a vector."""
    z = U.T @ x
    z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z)
    return z


def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters):
    """Solves a positive-definite linear system using Nyström PCG.

    `Frangella et al. Randomized Nyström Preconditioning.
    SIAM Journal on Matrix Analysis and Applications, 2023.
    <https://epubs.siam.org/doi/10.1137/21M1466244>`"""
    lambd_r = S[r - 1]
    S_mu_inv = (S + mu) ** (-1)

    resid = b - (hess(x) + mu * x)
    with torch.no_grad():
        z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
        p = z.clone()

    i = 0

    while torch.norm(resid) > tol and i < max_iters:
        v = hess(p) + mu * p
        with torch.no_grad():
            alpha = torch.dot(resid, z) / torch.dot(p, v)
            x += alpha * p

            rTz = torch.dot(resid, z)
            resid -= alpha * v
            z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
            beta = torch.dot(resid, z) / rTz

            p = z + beta * p

        i += 1

    if torch.norm(resid) > tol:
        print(
            f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}"
        )

    return x


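# In NysNewtonCG.step below, _nystrom_pcg is used to approximately solve the
# damped Newton system (H + mu * I) d = g for the search direction d, where H
# is accessed only through Hessian-vector products and the solve is warm-started
# from the direction computed at the previous step.

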
class NysNewtonCG(Optimizer):
    """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning.

    `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective.
    Preprint, 2024. <https://arxiv.org/abs/2402.01868>`

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    NOTE: This optimizer is currently a beta version.

    Our implementation is inspired by the PyTorch implementation of `L-BFGS
    <https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS>`.

    The parameters rank and mu will probably need to be tuned for your specific problem.
    If the optimizer is running very slowly, you can try one of the following:
    - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG)
    - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution)
    - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations)

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1.0)
        rank (int, optional): rank of the Nyström approximation (default: 10)
        mu (float, optional): damping parameter (default: 1e-4)
        chunk_size (int, optional): number of Hessian-vector products to be
            computed in parallel (default: 1)
        cg_tol (float, optional): tolerance for PCG (default: 1e-16)
        cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000)
        line_search_fn (str, optional): either 'armijo' or None (default: None)
        verbose (bool, optional): verbosity (default: False)
    """

    def __init__(
        self,
        params,
        lr=1.0,
        rank=10,
        mu=1e-4,
        chunk_size=1,
        cg_tol=1e-16,
        cg_max_iters=1000,
        line_search_fn=None,
        verbose=False,
    ):
        defaults = dict(
            lr=lr,
            rank=rank,
            chunk_size=chunk_size,
            mu=mu,
            cg_tol=cg_tol,
            cg_max_iters=cg_max_iters,
            line_search_fn=line_search_fn,
        )
        self.rank = rank
        self.mu = mu
        self.chunk_size = chunk_size
        self.cg_tol = cg_tol
        self.cg_max_iters = cg_max_iters
        self.line_search_fn = line_search_fn
        self.verbose = verbose
        self.U = None
        self.S = None
        self.n_iters = 0
        super(NysNewtonCG, self).__init__(params, defaults)

        if len(self.param_groups) > 1:
            raise ValueError(
                "NysNewtonCG doesn't currently support per-parameter options (parameter groups)"
            )

        if self.line_search_fn is not None and self.line_search_fn != "armijo":
            raise ValueError("NysNewtonCG only supports Armijo line search")

        self._params = self.param_groups[0]["params"]
        self._params_list = list(self._params)
        self._numel_cache = None

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and
                returns (i) the loss and (ii) the gradient w.r.t. the parameters.
                The closure can compute the gradient w.r.t. the parameters by
                calling torch.autograd.grad on the loss with create_graph=True.
        """
        if self.n_iters == 0:
            # Store the previous direction for warm starting PCG
            self.old_dir = torch.zeros(self._numel(), device=self._params[0].device)

        # NOTE: The closure must return both the loss and the gradient
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss, grad_tuple = closure()

        g = torch.cat([grad.view(-1) for grad in grad_tuple if grad is not None])

        # One step update
        for group_idx, group in enumerate(self.param_groups):

            def hvp_temp(x):
                return self._hvp(g, self._params_list, x)

            # Calculate the Newton direction
            d = _nystrom_pcg(
                hvp_temp,
                g,
                self.old_dir,
                self.mu,
                self.U,
                self.S,
                self.rank,
                self.cg_tol,
                self.cg_max_iters,
            )

            # Store the previous direction for warm starting PCG
            self.old_dir = d

            # Check if d is a descent direction
            if torch.dot(d, g) <= 0:
                print("Warning: d is not a descent direction")

            if self.line_search_fn == "armijo":
                x_init = self._clone_param()

                def obj_func(x, t, dx):
                    self._add_grad(t, dx)
                    loss = float(closure()[0])
                    self._set_param(x)
                    return loss

                # Use -d for convention
                t = _armijo(obj_func, x_init, g, -d, group["lr"])
            else:
                t = group["lr"]

            self.state[group_idx]["t"] = t

            # update parameters
            ls = 0
            for p in group["params"]:
                np = torch.numel(p)
                dp = d[ls : ls + np].view(p.shape)
                ls += np
                p.data.add_(-dp, alpha=t)

        self.n_iters += 1

        return loss, g

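    # NOTE: self.U and self.S start out as None, so update_preconditioner below
    # must be called (with a gradient tuple obtained via torch.autograd.grad with
    # create_graph=True) before the first call to step, and can be called again
    # periodically to refresh the Nystrom approximation.
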
    def update_preconditioner(self, grad_tuple):
        """Update the Nystrom approximation of the Hessian.

        Args:
            grad_tuple (tuple): tuple of Tensors containing the gradients of the
                loss w.r.t. the parameters. This tuple can be obtained by calling
                torch.autograd.grad on the loss with create_graph=True.
        """
        # Flatten and concatenate the gradients
        gradsH = torch.cat(
            [gradient.view(-1) for gradient in grad_tuple if gradient is not None]
        )

        # Generate test matrix (NOTE: This is the transposed test matrix)
        p = gradsH.shape[0]
        Phi = torch.randn((self.rank, p), device=gradsH.device) / (p**0.5)
        Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t()

        Y = self._hvp_vmap(gradsH, self._params_list)(Phi)

        # Calculate shift
        shift = torch.finfo(Y.dtype).eps
        Y_shifted = Y + shift * Phi

        # Calculate Phi^T * H * Phi (w/ shift) for Cholesky
        choleskytarget = torch.mm(Y_shifted, Phi.t())

        # Perform Cholesky; if it fails, fall back to an eigendecomposition.
        # The new shift is the abs of the smallest (negative) eigenvalue plus the original shift
        try:
            C = torch.linalg.cholesky(choleskytarget)
        except RuntimeError:
            # eigendecomposition: eigenvalues and eigenvector matrix
            eigs, eigvectors = torch.linalg.eigh(choleskytarget)
            shift = shift + torch.abs(torch.min(eigs))
            # add shift to eigenvalues
            eigs = eigs + shift
            # rebuild the matrix for Cholesky as eigvectors * diag(shifted eigs) * eigvectors^T
            C = torch.linalg.cholesky(
                torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T))
            )

        try:
            B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True)
        except RuntimeError:
            # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211
            B = torch.linalg.solve_triangular(
                C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True
            ).to(C.device)

        # B = V * S * U^T b/c we have been using the transposed sketch
        _, S, UT = torch.linalg.svd(B, full_matrices=False)
        self.U = UT.t()
        self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0))

        self.rho = self.S[-1]

        if self.verbose:
            print(f"Approximate eigenvalues = {self.S}")

    def _hvp_vmap(self, grad_params, params):
        return vmap(
            lambda v: self._hvp(grad_params, params, v),
            in_dims=0,
            chunk_size=self.chunk_size,
        )

    def _hvp(self, grad_params, params, v):
        Hv = torch.autograd.grad(
            grad_params, params, grad_outputs=v, retain_graph=True
        )
        Hv = tuple(Hvi.detach() for Hvi in Hv)
        return torch.cat([Hvi.reshape(-1) for Hvi in Hv])

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(
                lambda total, p: total + p.numel(), self._params, 0
            )
        return self._numel_cache

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # Avoid in-place operation by creating a new tensor
            p.data = p.data.add(
                update[offset : offset + numel].view_as(p), alpha=step_size
            )
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            # Replace the .data attribute of the tensor
            p.data = pdata.data
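
For context, the sketch below shows one way the new optimizer might be driven on its own, outside of DeepXDE. The toy model, data, import path, and refresh schedule are illustrative assumptions; the closure contract (return the loss together with the gradient tuple from torch.autograd.grad with create_graph=True) and the update_preconditioner call follow the docstrings above.

import torch

from nys_newton_cg import NysNewtonCG  # import path is an assumption

# Placeholder model and data for a toy regression problem.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)
X, y = torch.randn(256, 10), torch.randn(256, 1)
loss_fn = torch.nn.MSELoss()

opt = NysNewtonCG(
    model.parameters(), lr=1.0, rank=10, mu=1e-4, line_search_fn="armijo"
)


def closure():
    loss = loss_fn(model(X), y)
    # Return the loss and the gradient tuple; create_graph=True lets the
    # optimizer take Hessian-vector products through these gradients.
    grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
    return loss, grads


for it in range(100):
    _, grads = closure()
    if it % 20 == 0:
        # Refresh the Nystrom preconditioner; the frequency here is arbitrary.
        opt.update_preconditioner(grads)
    loss, _ = opt.step(closure)

Within DeepXDE itself, the integration added in this PR (config.py, optimizers.py, and train_step_nncg) presumably wires up an equivalent closure and preconditioner refresh, so end users would select NNCG through the normal compile/train workflow rather than constructing NysNewtonCG directly.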