4 | 4 | import tensorflow as tf |
5 | 5 | import tensorflow_probability as tfp |
6 | 6 | from matplotlib import pyplot |
| 7 | +from tqdm.auto import tqdm, trange |
7 | 8 | import time |
8 | 9 |
9 | 10 |
@@ -140,144 +141,147 @@ def eager_lbfgs(opfunc, x, state, maxIter=100, learningRate=1, do_verbose=True): |
140 | 141 | # optimize for a max of maxIter iterations |
141 | 142 | nIter = 0 |
142 | 143 | times = [] |
143 | | - while nIter < maxIter: |
144 | | - start_time = time.time() |
145 | | - if state.nIter == 1: |
146 | | - tmp1 = tf.abs(g) |
147 | | - t = min(1, 1 / tf.reduce_sum(tmp1)) |
148 | | - else: |
149 | | - t = learningRate |
150 | | - # keep track of nb of iterations |
151 | | - nIter = nIter + 1 |
152 | | - state.nIter = state.nIter + 1 |
153 | | - |
154 | | - ############################################################ |
155 | | - ## compute gradient descent direction |
156 | | - ############################################################ |
157 | | - if state.nIter == 1: |
158 | | - d = -g |
159 | | - old_dirs = [] |
160 | | - old_stps = [] |
161 | | - Hdiag = 1 |
162 | | - else: |
163 | | - # do lbfgs update (update memory) |
164 | | - y = g - g_old |
165 | | - s = d * t |
166 | | - ys = dot(y, s) |
167 | | - |
168 | | -        if ys > 1e-10:  # curvature condition: only update memory when ys = dot(y, s) is sufficiently positive |
169 | | - # updating memory |
170 | | - if len(old_dirs) == nCorrection: |
171 | | - # shift history by one (limited-memory) |
172 | | - del old_dirs[0] |
173 | | - del old_stps[0] |
174 | | - |
175 | | - # store new direction/step |
176 | | - old_dirs.append(s) |
177 | | - old_stps.append(y) |
178 | | - |
179 | | - # update scale of initial Hessian approximation |
180 | | - Hdiag = ys / dot(y, y) |
181 | | - |
182 | | - # compute the approximate (L-BFGS) inverse Hessian |
183 | | - # multiplied by the gradient |
184 | | - k = len(old_dirs) |
185 | | - |
186 | | -            # accessed element-by-element, so keep it a plain Python list rather than a tensor: |
187 | | - ro = [0] * nCorrection |
188 | | - for i in range(k): |
189 | | - ro[i] = 1 / dot(old_stps[i], old_dirs[i]) |
190 | | - |
191 | | - # iteration in L-BFGS loop collapsed to use just one buffer |
192 | | -            # accessed element-by-element, so keep it a plain Python list rather than a tensor: |
193 | | - al = [0] * nCorrection |
194 | | - |
195 | | - q = -g |
196 | | - for i in range(k - 1, -1, -1): |
197 | | - al[i] = dot(old_dirs[i], q) * ro[i] |
198 | | - q = q - al[i] * old_stps[i] |
199 | | - |
200 | | - # multiply by initial Hessian |
201 | | - r = q * Hdiag |
202 | | - for i in range(k): |
203 | | - be_i = dot(old_stps[i], r) * ro[i] |
204 | | - r += (al[i] - be_i) * old_dirs[i] |
205 | | - |
206 | | - d = r |
207 | | - # final direction is in r/d (same object) |
208 | | - |
209 | | - g_old = g |
210 | | - f_old = f |
211 | | - |
212 | | - ############################################################ |
213 | | - ## compute step length |
214 | | - ############################################################ |
215 | | - # directional derivative |
216 | | - gtd = dot(g, d) |
217 | | - |
218 | | - # check that progress can be made along that direction |
219 | | - if gtd > -tolX: |
220 | | -            verbose("Cannot make progress along this direction.") |
221 | | - break |
222 | | - |
223 | | - # reset initial guess for step size |
224 | | - if state.nIter == 1: |
| 144 | + with trange(maxIter) as t_: |
| 145 | + for epoch in t_: |
| 146 | + start_time = time.time() |
| 147 | + if state.nIter == 1: |
| 148 | + tmp1 = tf.abs(g) |
| 149 | + t = min(1, 1 / tf.reduce_sum(tmp1)) |
| 150 | + else: |
| 151 | + t = learningRate |
| 152 | + # keep track of nb of iterations |
| 153 | + nIter = nIter + 1 |
| 154 | + state.nIter = state.nIter + 1 |
| 155 | + |
| 156 | + ############################################################ |
| 157 | + ## compute gradient descent direction |
| 158 | + ############################################################ |
| 159 | + if state.nIter == 1: |
| 160 | + d = -g |
| 161 | + old_dirs = [] |
| 162 | + old_stps = [] |
| 163 | + Hdiag = 1 |
| 164 | + else: |
| 165 | + # do lbfgs update (update memory) |
| 166 | + y = g - g_old |
| 167 | + s = d * t |
| 168 | + ys = dot(y, s) |
| 169 | + |
| 170 | +                if ys > 1e-10:  # curvature condition: only update memory when ys = dot(y, s) is sufficiently positive |
| 171 | + # updating memory |
| 172 | + if len(old_dirs) == nCorrection: |
| 173 | + # shift history by one (limited-memory) |
| 174 | + del old_dirs[0] |
| 175 | + del old_stps[0] |
| 176 | + |
| 177 | + # store new direction/step |
| 178 | + old_dirs.append(s) |
| 179 | + old_stps.append(y) |
| 180 | + |
| 181 | + # update scale of initial Hessian approximation |
| 182 | + Hdiag = ys / dot(y, y) |
| 183 | + |
| 184 | + # compute the approximate (L-BFGS) inverse Hessian |
| 185 | + # multiplied by the gradient |
| 186 | + k = len(old_dirs) |
| 187 | + |
| 188 | +                # accessed element-by-element, so keep it a plain Python list rather than a tensor: |
| 189 | + ro = [0] * nCorrection |
| 190 | + for i in range(k): |
| 191 | + ro[i] = 1 / dot(old_stps[i], old_dirs[i]) |
| 192 | + |
| 193 | + # iteration in L-BFGS loop collapsed to use just one buffer |
| 194 | +                # accessed element-by-element, so keep it a plain Python list rather than a tensor: |
| 195 | + al = [0] * nCorrection |
| 196 | + |
| 197 | + q = -g |
| 198 | + for i in range(k - 1, -1, -1): |
| 199 | + al[i] = dot(old_dirs[i], q) * ro[i] |
| 200 | + q = q - al[i] * old_stps[i] |
| 201 | + |
| 202 | + # multiply by initial Hessian |
| 203 | + r = q * Hdiag |
| 204 | + for i in range(k): |
| 205 | + be_i = dot(old_stps[i], r) * ro[i] |
| 206 | + r += (al[i] - be_i) * old_dirs[i] |
| 207 | + |
| 208 | + d = r |
| 209 | + # final direction is in r/d (same object) |
| 210 | + |
| 211 | + g_old = g |
| 212 | + f_old = f |
| 213 | + |
| 214 | + ############################################################ |
| 215 | + ## compute step length |
| 216 | + ############################################################ |
| 217 | + # directional derivative |
| 218 | + gtd = dot(g, d) |
| 219 | + |
| 220 | + # check that progress can be made along that direction |
| 221 | + if gtd > -tolX: |
| 222 | +                verbose("Cannot make progress along this direction.") |
| 223 | + break |
| 224 | + |
| 225 | + # reset initial guess for step size |
| 226 | + if state.nIter == 1: |
| 227 | + tmp1 = tf.abs(g) |
| 228 | + t = min(1, 1 / tf.reduce_sum(tmp1)) |
| 229 | + else: |
| 230 | + t = learningRate |
| 231 | + |
| 232 | + x += t * d |
| 233 | + |
| 234 | + if nIter != maxIter: |
| 235 | + # re-evaluate function only if not in last iteration |
| 236 | + # the reason we do this: in a stochastic setting, |
| 237 | + # no use to re-evaluate that function here |
| 238 | + f, g = opfunc(x) |
| 239 | + |
| 240 | + lsFuncEval = 1 |
| 241 | + f_hist.append(f) |
| 242 | + |
| 243 | + # update func eval |
| 244 | + currentFuncEval = currentFuncEval + lsFuncEval |
| 245 | + state.funcEval = state.funcEval + lsFuncEval |
| 246 | + |
| 247 | + ############################################################ |
| 248 | + ## check conditions |
| 249 | + ############################################################ |
| 250 | + if nIter == maxIter: |
| 251 | + break |
| 252 | + |
| 253 | + if currentFuncEval >= maxEval: |
| 254 | + # max nb of function evals |
| 255 | + print('max nb of function evals') |
| 256 | + break |
| 257 | + |
225 | 258 | tmp1 = tf.abs(g) |
226 | | - t = min(1, 1 / tf.reduce_sum(tmp1)) |
227 | | - else: |
228 | | - t = learningRate |
229 | | - |
230 | | - x += t * d |
231 | | - |
232 | | - if nIter != maxIter: |
233 | | - # re-evaluate function only if not in last iteration |
234 | | - # the reason we do this: in a stochastic setting, |
235 | | - # no use to re-evaluate that function here |
236 | | - f, g = opfunc(x) |
237 | | - |
238 | | - lsFuncEval = 1 |
239 | | - f_hist.append(f) |
240 | | - |
241 | | - # update func eval |
242 | | - currentFuncEval = currentFuncEval + lsFuncEval |
243 | | - state.funcEval = state.funcEval + lsFuncEval |
244 | | - |
245 | | - ############################################################ |
246 | | - ## check conditions |
247 | | - ############################################################ |
248 | | - if nIter == maxIter: |
249 | | - break |
250 | | - |
251 | | - if currentFuncEval >= maxEval: |
252 | | - # max nb of function evals |
253 | | - print('max nb of function evals') |
254 | | - break |
255 | | - |
256 | | - tmp1 = tf.abs(g) |
257 | | - if tf.reduce_sum(tmp1) <= tolFun: |
258 | | - # check optimality |
259 | | - print('optimality condition below tolFun') |
260 | | - break |
261 | | - |
262 | | - tmp1 = tf.abs(d * t) |
263 | | - if tf.reduce_sum(tmp1) <= tolX: |
264 | | - # step size below tolX |
265 | | - print('step size below tolX') |
266 | | - break |
267 | | - |
268 | | -        if tf.abs(f - f_old) < tolX: |
269 | | -            # function value changing less than tolX |
270 | | -            print('function value changing less than tolX: ' + str(tf.abs(f - f_old))) |
271 | | - break |
272 | | - |
273 | | - if do_verbose: |
274 | | - if nIter % 100 == 0: |
275 | | - elapsed = time.time() - state.start_time |
276 | | - print("Step: %3d, loss: %9.8f, time: " % (nIter, f.numpy()), elapsed) |
277 | | - state.start_time = time.time() |
278 | | - |
279 | | - if nIter == maxIter - 1: |
280 | | - final_loss = f.numpy() |
| 259 | + if tf.reduce_sum(tmp1) <= tolFun: |
| 260 | + # check optimality |
| 261 | + print('optimality condition below tolFun') |
| 262 | + break |
| 263 | + |
| 264 | + tmp1 = tf.abs(d * t) |
| 265 | + if tf.reduce_sum(tmp1) <= tolX: |
| 266 | + # step size below tolX |
| 267 | + print('step size below tolX') |
| 268 | + break |
| 269 | + |
| 270 | +                if tf.abs(f - f_old) < tolX: |
| 271 | +                    # function value changing less than tolX |
| 272 | +                    print('function value changing less than tolX: ' + str(tf.abs(f - f_old))) |
| 273 | + break |
| 274 | + |
| 275 | + t_.set_description('L-BFGS epoch %i' % epoch) |
| 276 | + if do_verbose: |
| 277 | + if nIter % 10 == 0: |
| 278 | + t_.set_postfix(loss=f.numpy()) |
| 279 | + elapsed = time.time() - state.start_time |
| 280 | + #print("Step: %3d, loss: %9.8f, time: " % (nIter, f.numpy()), elapsed) |
| 281 | + state.start_time = time.time() |
| 282 | + |
| 283 | + if nIter == maxIter - 1: |
| 284 | + final_loss = f.numpy() |
281 | 285 |
282 | 286 | # save state |
283 | 287 | state.old_dirs = old_dirs |
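A note on the history update in the hunk above: `old_dirs` and `old_stps` keep the naming quirk of Torch's original `lbfgs.lua`, so `old_dirs` actually stores the steps `s` and `old_stps` the gradient differences `y`; the two loops over them are the standard L-BFGS two-loop recursion. A minimal standalone sketch of that recursion, with hypothetical names, in case you want to sanity-check it in isolation:

```python
import tensorflow as tf

def two_loop_direction(g, s_hist, y_hist):
    """Hypothetical helper mirroring the recursion in eager_lbfgs:
    returns d = -H g, where H approximates the inverse Hessian built
    from the stored (s, y) pairs."""
    k = len(s_hist)
    rho = [1.0 / tf.reduce_sum(y_hist[i] * s_hist[i]) for i in range(k)]
    al = [None] * k
    q = -g
    for i in range(k - 1, -1, -1):      # newest pair to oldest
        al[i] = rho[i] * tf.reduce_sum(s_hist[i] * q)
        q = q - al[i] * y_hist[i]
    # scale by gamma = dot(s, y) / dot(y, y), the usual initial Hessian guess
    gamma = (tf.reduce_sum(s_hist[-1] * y_hist[-1]) /
             tf.reduce_sum(y_hist[-1] * y_hist[-1])) if k > 0 else 1.0
    r = q * gamma
    for i in range(k):                  # oldest pair to newest
        be_i = rho[i] * tf.reduce_sum(y_hist[i] * r)
        r = r + (al[i] - be_i) * s_hist[i]
    return r
```

Seeding the first loop with `q = -g` is why the code never negates `r`: the result is already the descent direction `d`.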
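The substantive change in this hunk is progress reporting: the `while` loop becomes a `tqdm` `trange`, `set_description` and `set_postfix` take over from the periodic `print`, and the loss readout refreshes every 10 iterations instead of every 100. The same pattern in isolation, with a stand-in loss:

```python
from tqdm.auto import trange

with trange(100) as t_:
    for epoch in t_:
        loss = 1.0 / (epoch + 1)        # stand-in for f.numpy()
        t_.set_description('L-BFGS epoch %i' % epoch)
        if epoch % 10 == 0:
            t_.set_postfix(loss=loss)   # shown at the right edge of the bar
```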
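The diff does not show how `eager_lbfgs` is called, so the sketch below is an assumption rather than the project's actual driver: `opfunc` must return the `(loss, gradient)` pair at `x`, and `state` must expose at least the `nIter`, `funcEval`, and `start_time` fields this loop touches (the real state object may carry more):

```python
import time
import types
import tensorflow as tf

def opfunc(x):
    # hypothetical objective; eager_lbfgs expects (loss, gradient) back
    with tf.GradientTape() as tape:
        tape.watch(x)
        loss = tf.reduce_sum(tf.square(x - 3.0))
    return loss, tape.gradient(loss, x)

x = tf.zeros([10])
# nIter=0 so the first pass takes the state.nIter == 1 branch
state = types.SimpleNamespace(nIter=0, funcEval=0, start_time=time.time())
eager_lbfgs(opfunc, x, state, maxIter=50)
```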