Skip to content

Commit ba0e44c

Browse files
committed
Add tqdm progress bar to L-BFGS
1 parent c40f58f commit ba0e44c

File tree

2 files changed

+142
-138
lines changed

2 files changed

+142
-138
lines changed

tensordiffeq/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def fit(obj, tf_iter=0, newton_iter=0, batch_sz=None, newton_eager=True):
3434
# these cant be tf.functions on initialization since the distributed strategy requires its own
3535
# graph using grad and adaptgrad, so they cant be compiled as tf.functions until we know dist/non-dist
3636
obj.grad = tf.function(obj.grad)
37-
print_screen(obj, obj.domain)
37+
print_screen(obj)
3838
print("starting Adam training")
3939
# tf.profiler.experimental.start('../cache/tblogdir1')
4040
train_op_fn = train_op_inner(obj)

tensordiffeq/optimizers.py

Lines changed: 141 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tensorflow as tf
55
import tensorflow_probability as tfp
66
from matplotlib import pyplot
7+
from tqdm.auto import tqdm, trange
78
import time
89

910

@@ -140,144 +141,147 @@ def eager_lbfgs(opfunc, x, state, maxIter=100, learningRate=1, do_verbose=True):
140141
# optimize for a max of maxIter iterations
141142
nIter = 0
142143
times = []
143-
while nIter < maxIter:
144-
start_time = time.time()
145-
if state.nIter == 1:
146-
tmp1 = tf.abs(g)
147-
t = min(1, 1 / tf.reduce_sum(tmp1))
148-
else:
149-
t = learningRate
150-
# keep track of nb of iterations
151-
nIter = nIter + 1
152-
state.nIter = state.nIter + 1
153-
154-
############################################################
155-
## compute gradient descent direction
156-
############################################################
157-
if state.nIter == 1:
158-
d = -g
159-
old_dirs = []
160-
old_stps = []
161-
Hdiag = 1
162-
else:
163-
# do lbfgs update (update memory)
164-
y = g - g_old
165-
s = d * t
166-
ys = dot(y, s)
167-
168-
if ys > 1e-10:
169-
# updating memory
170-
if len(old_dirs) == nCorrection:
171-
# shift history by one (limited-memory)
172-
del old_dirs[0]
173-
del old_stps[0]
174-
175-
# store new direction/step
176-
old_dirs.append(s)
177-
old_stps.append(y)
178-
179-
# update scale of initial Hessian approximation
180-
Hdiag = ys / dot(y, y)
181-
182-
# compute the approximate (L-BFGS) inverse Hessian
183-
# multiplied by the gradient
184-
k = len(old_dirs)
185-
186-
# need to be accessed element-by-element, so don't re-type tensor:
187-
ro = [0] * nCorrection
188-
for i in range(k):
189-
ro[i] = 1 / dot(old_stps[i], old_dirs[i])
190-
191-
# iteration in L-BFGS loop collapsed to use just one buffer
192-
# need to be accessed element-by-element, so don't re-type tensor:
193-
al = [0] * nCorrection
194-
195-
q = -g
196-
for i in range(k - 1, -1, -1):
197-
al[i] = dot(old_dirs[i], q) * ro[i]
198-
q = q - al[i] * old_stps[i]
199-
200-
# multiply by initial Hessian
201-
r = q * Hdiag
202-
for i in range(k):
203-
be_i = dot(old_stps[i], r) * ro[i]
204-
r += (al[i] - be_i) * old_dirs[i]
205-
206-
d = r
207-
# final direction is in r/d (same object)
208-
209-
g_old = g
210-
f_old = f
211-
212-
############################################################
213-
## compute step length
214-
############################################################
215-
# directional derivative
216-
gtd = dot(g, d)
217-
218-
# check that progress can be made along that direction
219-
if gtd > -tolX:
220-
verbose("Can not make progress along direction.")
221-
break
222-
223-
# reset initial guess for step size
224-
if state.nIter == 1:
144+
with trange(maxIter) as t_:
145+
for epoch in t_:
146+
start_time = time.time()
147+
if state.nIter == 1:
148+
tmp1 = tf.abs(g)
149+
t = min(1, 1 / tf.reduce_sum(tmp1))
150+
else:
151+
t = learningRate
152+
# keep track of nb of iterations
153+
nIter = nIter + 1
154+
state.nIter = state.nIter + 1
155+
156+
############################################################
157+
## compute gradient descent direction
158+
############################################################
159+
if state.nIter == 1:
160+
d = -g
161+
old_dirs = []
162+
old_stps = []
163+
Hdiag = 1
164+
else:
165+
# do lbfgs update (update memory)
166+
y = g - g_old
167+
s = d * t
168+
ys = dot(y, s)
169+
170+
if ys > 1e-10:
171+
# updating memory
172+
if len(old_dirs) == nCorrection:
173+
# shift history by one (limited-memory)
174+
del old_dirs[0]
175+
del old_stps[0]
176+
177+
# store new direction/step
178+
old_dirs.append(s)
179+
old_stps.append(y)
180+
181+
# update scale of initial Hessian approximation
182+
Hdiag = ys / dot(y, y)
183+
184+
# compute the approximate (L-BFGS) inverse Hessian
185+
# multiplied by the gradient
186+
k = len(old_dirs)
187+
188+
# need to be accessed element-by-element, so don't re-type tensor:
189+
ro = [0] * nCorrection
190+
for i in range(k):
191+
ro[i] = 1 / dot(old_stps[i], old_dirs[i])
192+
193+
# iteration in L-BFGS loop collapsed to use just one buffer
194+
# need to be accessed element-by-element, so don't re-type tensor:
195+
al = [0] * nCorrection
196+
197+
q = -g
198+
for i in range(k - 1, -1, -1):
199+
al[i] = dot(old_dirs[i], q) * ro[i]
200+
q = q - al[i] * old_stps[i]
201+
202+
# multiply by initial Hessian
203+
r = q * Hdiag
204+
for i in range(k):
205+
be_i = dot(old_stps[i], r) * ro[i]
206+
r += (al[i] - be_i) * old_dirs[i]
207+
208+
d = r
209+
# final direction is in r/d (same object)
210+
211+
g_old = g
212+
f_old = f
213+
214+
############################################################
215+
## compute step length
216+
############################################################
217+
# directional derivative
218+
gtd = dot(g, d)
219+
220+
# check that progress can be made along that direction
221+
if gtd > -tolX:
222+
verbose("Can not make progress along direction.")
223+
break
224+
225+
# reset initial guess for step size
226+
if state.nIter == 1:
227+
tmp1 = tf.abs(g)
228+
t = min(1, 1 / tf.reduce_sum(tmp1))
229+
else:
230+
t = learningRate
231+
232+
x += t * d
233+
234+
if nIter != maxIter:
235+
# re-evaluate function only if not in last iteration
236+
# the reason we do this: in a stochastic setting,
237+
# no use to re-evaluate that function here
238+
f, g = opfunc(x)
239+
240+
lsFuncEval = 1
241+
f_hist.append(f)
242+
243+
# update func eval
244+
currentFuncEval = currentFuncEval + lsFuncEval
245+
state.funcEval = state.funcEval + lsFuncEval
246+
247+
############################################################
248+
## check conditions
249+
############################################################
250+
if nIter == maxIter:
251+
break
252+
253+
if currentFuncEval >= maxEval:
254+
# max nb of function evals
255+
print('max nb of function evals')
256+
break
257+
225258
tmp1 = tf.abs(g)
226-
t = min(1, 1 / tf.reduce_sum(tmp1))
227-
else:
228-
t = learningRate
229-
230-
x += t * d
231-
232-
if nIter != maxIter:
233-
# re-evaluate function only if not in last iteration
234-
# the reason we do this: in a stochastic setting,
235-
# no use to re-evaluate that function here
236-
f, g = opfunc(x)
237-
238-
lsFuncEval = 1
239-
f_hist.append(f)
240-
241-
# update func eval
242-
currentFuncEval = currentFuncEval + lsFuncEval
243-
state.funcEval = state.funcEval + lsFuncEval
244-
245-
############################################################
246-
## check conditions
247-
############################################################
248-
if nIter == maxIter:
249-
break
250-
251-
if currentFuncEval >= maxEval:
252-
# max nb of function evals
253-
print('max nb of function evals')
254-
break
255-
256-
tmp1 = tf.abs(g)
257-
if tf.reduce_sum(tmp1) <= tolFun:
258-
# check optimality
259-
print('optimality condition below tolFun')
260-
break
261-
262-
tmp1 = tf.abs(d * t)
263-
if tf.reduce_sum(tmp1) <= tolX:
264-
# step size below tolX
265-
print('step size below tolX')
266-
break
267-
268-
if tf.abs(f, f_old) < tolX:
269-
# function value changing less than tolX
270-
print('function value changing less than tolX' + str(tf.abs(f - f_old)))
271-
break
272-
273-
if do_verbose:
274-
if nIter % 100 == 0:
275-
elapsed = time.time() - state.start_time
276-
print("Step: %3d, loss: %9.8f, time: " % (nIter, f.numpy()), elapsed)
277-
state.start_time = time.time()
278-
279-
if nIter == maxIter - 1:
280-
final_loss = f.numpy()
259+
if tf.reduce_sum(tmp1) <= tolFun:
260+
# check optimality
261+
print('optimality condition below tolFun')
262+
break
263+
264+
tmp1 = tf.abs(d * t)
265+
if tf.reduce_sum(tmp1) <= tolX:
266+
# step size below tolX
267+
print('step size below tolX')
268+
break
269+
270+
if tf.abs(f, f_old) < tolX:
271+
# function value changing less than tolX
272+
print('function value changing less than tolX' + str(tf.abs(f - f_old)))
273+
break
274+
275+
t_.set_description('L-BFGS epoch %i' % epoch)
276+
if do_verbose:
277+
if nIter % 10 == 0:
278+
t_.set_postfix(loss=f.numpy())
279+
elapsed = time.time() - state.start_time
280+
#print("Step: %3d, loss: %9.8f, time: " % (nIter, f.numpy()), elapsed)
281+
state.start_time = time.time()
282+
283+
if nIter == maxIter - 1:
284+
final_loss = f.numpy()
281285

282286
# save state
283287
state.old_dirs = old_dirs

0 commit comments

Comments
 (0)