Skip to content

Commit 6bf3c4e

Browse files
committed
update japn
1 parent a0f9dd0 commit 6bf3c4e

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

examples/japan/run_adloc_cc.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,14 @@
168168
optimizer = optim.Adam(
169169
[
170170
{"params": travel_time.event_loc.parameters(), "lr": lr}, # learning rate for event_loc
171-
{"params": travel_time.event_time.parameters(), "lr": lr * 0.1}, # learning rate for event_time
171+
{"params": travel_time.event_time.parameters(), "lr": lr}, # learning rate for event_time
172172
],
173173
lr=lr,
174174
)
175175
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1)
176-
scaler = optim.lr_scheduler.ReduceLROnPlateau(
177-
optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05
178-
)
176+
# scaler = optim.lr_scheduler.ReduceLROnPlateau(
177+
# optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.95, patience=3, threshold=0.05
178+
# )
179179
valid_index = np.ones(len(pairs), dtype=bool)
180180

181181
if ddp:
@@ -217,7 +217,7 @@
217217
# torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0)
218218
optimizer.step()
219219
scheduler.step()
220-
scaler.step(loss)
220+
# scaler.step(loss)
221221
with torch.no_grad():
222222
raw_travel_time.event_loc.weight.data[:, 2].clamp_(
223223
min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
@@ -247,8 +247,8 @@
247247
weight = np.concatenate(weight)
248248
# threshold_time = 6.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 2.0 # s
249249
threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0 # s
250-
print(f"Scaler: {scaler.get_last_lr()[0]}")
251-
threshold_time *= scaler.get_last_lr()[0]
250+
# print(f"Scaler: {scaler.get_last_lr()[0]}")
251+
# threshold_time *= scaler.get_last_lr()[0]
252252
# valid_index = np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * threshold_time
253253
# weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])[valid_index]) ** 2, weights=weight[valid_index]))
254254
weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])) ** 2, weights=weight))
@@ -284,7 +284,7 @@
284284
)
285285
# threshold_space = 9.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 1.0 # km
286286
threshold_space = 9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0 # km
287-
threshold_space *= scaler.get_last_lr()[0]
287+
# threshold_space *= scaler.get_last_lr()[0]
288288
num_picks = len(pairs_df)
289289
pairs_df = pairs_df[pairs_df["dist_km"] < threshold_space]
290290
print(f"Filter by space: {num_picks} -> {len(pairs_df)} using threshold {threshold_space:.2f}")

0 commit comments

Comments
 (0)