Skip to content

Commit 826c935

Browse files
Merge pull request #6 from nlesc-dirac/linesearch_upgrade
keep sign of curvature
2 parents d54bc0e + c2d7391 commit 826c935

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

kan_pde.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from kan import KAN
66
from lbfgsb import LBFGSB
7+
from lbfgsnew import LBFGSNew
78
import torch
89
import matplotlib.pyplot as plt
910
from torch import autograd
@@ -22,7 +23,7 @@
2223
np_b = 21 # number of boundary points (along each dimension)
2324
ranges = [-1, 1]
2425

25-
model = KAN(width=[2,2,1], grid=5, k=3, grid_eps=1.0, noise_scale_base=0.25, device=mydevice)
26+
model = KAN(width=[2,2,1], grid=5, k=3, grid_eps=1.0, device=mydevice)
2627

2728
# get all parameters (all may not be trainable)
2829
n_params = sum([np.prod(p.size()) for p in model.parameters()])
@@ -70,6 +71,7 @@ def _func_sum(x):
7071
def train():
7172
# try running with batch_mode=True and batch_mode=False (both should work)
7273
optimizer = LBFGSB(model.parameters(), lower_bound=x_l, upper_bound=x_u, history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True)
74+
#optimizer = LBFGSNew(model.parameters(), history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, batch_mode=True, cost_use_gradient=True)
7375

7476
pbar = tqdm(range(steps), desc='description')
7577

lbfgsb.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -556,24 +556,25 @@ def step(self, closure):
556556
y=g-g_old
557557
x=torch.cat(self._copy_params_out(),0)
558558
s=x-x_old
559-
curv=abs(torch.dot(s,y))
559+
curv=(torch.dot(s,y))
560560
n_iter +=1
561561
state['n_iter'] +=1
562562

563563

564564
batch_changed=batch_mode and (n_iter==1 and state['n_iter']>1)
565565
if batch_changed:
566566
tmp_grad_1=g_old.clone(memory_format=torch.contiguous_format)
567-
tmp_grad_1.add_(self.running_avg,alpha=-1.0)
567+
tmp_grad_1.add_(self.running_avg,alpha=-1.0) # grad-oldmean
568568
self.running_avg.add_(tmp_grad_1,alpha=1.0/state['n_iter'])
569569
tmp_grad_2=g_old.clone(memory_format=torch.contiguous_format)
570-
tmp_grad_2.add_(self.running_avg,alpha=-1.0)
571-
self.running_avg_sq.addcmul_(tmp_grad_2,tmp_grad_1,value=1)
570+
tmp_grad_2.add_(self.running_avg,alpha=-1.0) # grad-newmean
571+
self.running_avg_sq.addcmul_(tmp_grad_2,tmp_grad_1,value=1) # +(grad-newmean)*(grad-oldmean)
572572
self.alphabar=1.0/(1.0+self.running_avg_sq.sum()/((state['n_iter']-1)*g_old.norm().item()))
573573

574574

575575
if (curv<self._eps):
576-
print('Warning: negative curvature detected, skipping update')
576+
if be_verbose:
577+
print('Warning: negative curvature detected, skipping update')
577578
n_iter+=1
578579
continue
579580
# in batch mode, do not update Y and S if the batch has changed
@@ -601,9 +602,6 @@ def step(self, closure):
601602
self._M=torch.linalg.pinv(MM)
602603

603604

604-
605-
606-
607605
if be_verbose and (n_iter==max_iter):
608606
print('Reached maximum number of iterations, stopping')
609607

0 commit comments

Comments (0)