Skip to content

Commit 15bd242

Browse files
committed
add more optimizer
1 parent cb0f2f9 commit 15bd242

File tree

1 file changed

+46
-35
lines changed

1 file changed

+46
-35
lines changed

scripts/run_adloc_cc.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@
178178
## invert loss
179179
######################################################################################################
180180
# optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
181-
optimizer = optim.Adam(params=travel_time.parameters(), lr=0.01)
181+
# optimizer = optim.Adam(params=travel_time.parameters(), lr=0.001)
182+
optimizer = optim.RMSprop(params=travel_time.parameters(), lr=0.001)
183+
# optimizer = optim.SGD(params=travel_time.parameters(), lr=0.0003)
184+
182185
valid_index = np.ones(len(pairs), dtype=bool)
183186
EPOCHS = 100
184187
prev_loss = 1e10
@@ -298,40 +301,48 @@
298301
plotting_dd(events, stations, config, figure_path, events_init, suffix=f"_ddcc_{epoch//10}")
299302

300303
# ######################################################################################################
301-
# if len(pairs_df) < 1_000_000:
302-
# optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=100, line_search_fn="strong_wolfe")
303-
304-
# def closure():
305-
# optimizer.zero_grad()
306-
# loss = 0
307-
# # for meta in tqdm(phase_dataset, desc=f"BFGS"):
308-
# if ddp_local_rank == 0:
309-
# print(f"BFGS: ", end="")
310-
# for meta in phase_dataset:
311-
# if ddp_local_rank == 0:
312-
# print(".", end="")
313-
314-
# loss_ = travel_time(
315-
# meta["idx_sta"],
316-
# meta["idx_eve"],
317-
# meta["phase_type"],
318-
# meta["phase_time"],
319-
# meta["phase_weight"],
320-
# )["loss"]
321-
# loss_.backward()
322-
323-
# if ddp:
324-
# dist.all_reduce(loss_, op=dist.ReduceOp.SUM)
325-
# # loss_ /= ddp_world_size
326-
327-
# loss += loss_
328-
329-
# if ddp_local_rank == 0:
330-
# print(f"Loss: {loss}")
331-
# raw_travel_time.event_loc.weight.data[:, 2].clamp_(min=config["zlim_km"][0], max=config["zlim_km"][1])
332-
# return loss
333-
334-
# optimizer.step(closure)
304+
if len(pairs_df) < 1_000_000:
305+
optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=200, line_search_fn="strong_wolfe")
306+
307+
prev_loss = 1e10
308+
309+
def closure():
310+
optimizer.zero_grad()
311+
loss = 0
312+
# for meta in tqdm(phase_dataset, desc=f"BFGS"):
313+
if ddp_local_rank == 0:
314+
print(f"BFGS: ", end="")
315+
for meta in phase_dataset:
316+
if ddp_local_rank == 0:
317+
print(".", end="")
318+
319+
loss_ = travel_time(
320+
meta["idx_sta"],
321+
meta["idx_eve"],
322+
meta["phase_type"],
323+
meta["phase_time"],
324+
meta["phase_weight"],
325+
)["loss"]
326+
loss_.backward()
327+
328+
if ddp:
329+
dist.all_reduce(loss_, op=dist.ReduceOp.SUM)
330+
# loss_ /= ddp_world_size
331+
332+
loss += loss_
333+
334+
if ddp:
335+
dist.barrier()
336+
if prev_loss < loss:
337+
print(f"{prev_loss = } {loss = }")
338+
return loss
339+
prev_loss = loss.item()
340+
if ddp_local_rank == 0:
341+
print(f"Loss: {loss}")
342+
raw_travel_time.event_loc.weight.data[:, 2].clamp_(min=config["zlim_km"][0], max=config["zlim_km"][1])
343+
return loss
344+
345+
optimizer.step(closure)
335346
# ######################################################################################################
336347

337348
# %%

0 commit comments

Comments
 (0)