|
## invert loss
######################################################################################################
# Optimizer over travel_time's parameters. Previously tried alternatives are
# kept below, commented out, for reference.
# optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
# optimizer = optim.Adam(params=travel_time.parameters(), lr=0.001)
# optimizer = optim.SGD(params=travel_time.parameters(), lr=0.0003)
optimizer = optim.RMSprop(travel_time.parameters(), lr=0.001)

# Number of passes and a large starting value for the running loss
# (presumably compared against per-epoch losses further down -- confirm).
EPOCHS = 100
prev_loss = 1e10

# Boolean mask with one entry per element of `pairs`, initially all True.
valid_index = np.ones(len(pairs), dtype=bool)
|
298 | 301 | plotting_dd(events, stations, config, figure_path, events_init, suffix=f"_ddcc_{epoch//10}") |
299 | 302 |
|
300 | 303 | # ###################################################################################################### |
301 | | - # if len(pairs_df) < 1_000_000: |
302 | | - # optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=100, line_search_fn="strong_wolfe") |
303 | | - |
304 | | - # def closure(): |
305 | | - # optimizer.zero_grad() |
306 | | - # loss = 0 |
307 | | - # # for meta in tqdm(phase_dataset, desc=f"BFGS"): |
308 | | - # if ddp_local_rank == 0: |
309 | | - # print(f"BFGS: ", end="") |
310 | | - # for meta in phase_dataset: |
311 | | - # if ddp_local_rank == 0: |
312 | | - # print(".", end="") |
313 | | - |
314 | | - # loss_ = travel_time( |
315 | | - # meta["idx_sta"], |
316 | | - # meta["idx_eve"], |
317 | | - # meta["phase_type"], |
318 | | - # meta["phase_time"], |
319 | | - # meta["phase_weight"], |
320 | | - # )["loss"] |
321 | | - # loss_.backward() |
322 | | - |
323 | | - # if ddp: |
324 | | - # dist.all_reduce(loss_, op=dist.ReduceOp.SUM) |
325 | | - # # loss_ /= ddp_world_size |
326 | | - |
327 | | - # loss += loss_ |
328 | | - |
329 | | - # if ddp_local_rank == 0: |
330 | | - # print(f"Loss: {loss}") |
331 | | - # raw_travel_time.event_loc.weight.data[:, 2].clamp_(min=config["zlim_km"][0], max=config["zlim_km"][1]) |
332 | | - # return loss |
333 | | - |
334 | | - # optimizer.step(closure) |
# L-BFGS refinement, only run when the pair table has fewer than 1,000,000 rows.
if len(pairs_df) < 1_000_000:
    optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=200, line_search_fn="strong_wolfe")

    prev_loss = 1e10

    # The closure must remember the last accepted loss across calls. Assigning
    # to a bare `prev_loss` inside the closure would make the name local and
    # raise UnboundLocalError on the comparison, so the value lives in a
    # mutable holder that works at module scope and inside a function alike.
    lbfgs_state = {"prev_loss": prev_loss}

    def closure():
        """Evaluate the loss over `phase_dataset` and accumulate gradients.

        Passed to `optimizer.step`, which may call it several times.

        Returns:
            The loss summed over all batches of `phase_dataset` (and, when
            `ddp` is set, summed across ranks via `all_reduce`).
        """
        prev_loss = lbfgs_state["prev_loss"]

        optimizer.zero_grad()
        loss = 0
        # for meta in tqdm(phase_dataset, desc=f"BFGS"):
        if ddp_local_rank == 0:
            print(f"BFGS: ", end="")
        for meta in phase_dataset:
            if ddp_local_rank == 0:
                print(".", end="")

            loss_ = travel_time(
                meta["idx_sta"],
                meta["idx_eve"],
                meta["phase_type"],
                meta["phase_time"],
                meta["phase_weight"],
            )["loss"]
            loss_.backward()

            # Reduced after backward(), so this only changes the loss value
            # that is accumulated and returned, not the gradients just computed.
            if ddp:
                dist.all_reduce(loss_, op=dist.ReduceOp.SUM)
                # loss_ /= ddp_world_size

            loss += loss_

        if ddp:
            dist.barrier()

        # Loss went up relative to the last accepted evaluation: report it and
        # return without updating the stored loss or clamping depths.
        if prev_loss < loss:
            print(f"{prev_loss = } {loss = }")
            return loss

        lbfgs_state["prev_loss"] = loss.item()
        if ddp_local_rank == 0:
            print(f"Loss: {loss}")
        # Clamp column 2 of the event locations to the configured zlim_km range.
        raw_travel_time.event_loc.weight.data[:, 2].clamp_(min=config["zlim_km"][0], max=config["zlim_km"][1])
        return loss

    optimizer.step(closure)
335 | 346 | # ###################################################################################################### |
336 | 347 |
|
337 | 348 | # %% |
|
0 commit comments