
Commit e18aed3
update run_adloc_cc.py
1 parent d691c02

File tree: 2 files changed, +119 -31 lines

examples/japan/run_adloc_cc.py

Lines changed: 111 additions & 31 deletions
@@ -75,13 +75,6 @@
 print(json.dumps(config, indent=4))
 config["use_amplitude"] = True
 
-# ## Eikonal for 1D velocity model
-zz = [0.0, 5.5, 16.0, 32.0]
-vp = [5.5, 5.5, 6.7, 7.8]
-vp_vs_ratio = 1.73
-vs = [v / vp_vs_ratio for v in vp]
-h = 0.3
-
 # %%
 if not os.path.exists(result_path):
     os.makedirs(result_path)
@@ -126,10 +119,11 @@
 # vp = [5.5, 5.5, 6.7, 7.8]
 # vp_vs_ratio = 1.73
 # vs = [v / vp_vs_ratio for v in vp]
-# zz = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 30.0]
-# vp = [4.746, 4.793, 4.799, 5.045, 5.721, 5.879, 6.504, 6.708, 6.725, 7.800]
+zz = [0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 17.0, 21.0, 31.00, 31.10]
+vp = [5.30, 5.65, 5.93, 6.20, 6.20, 6.20, 6.20, 6.20, 6.20, 6.20, 7.50, 8.11]
 # vs = [2.469, 2.470, 2.929, 2.930, 3.402, 3.403, 3.848, 3.907, 3.963, 4.500]
-# h = 0.3
+vs = [v / 1.73 for v in vp]
+h = 0.3
 vel = {"Z": zz, "P": vp, "S": vs}
 config["eikonal"] = {
     "vel": vel,
@@ -165,31 +159,52 @@
     # event_time=event_time,
     eikonal=config["eikonal"],
 )
+
+## invert loss
+######################################################################################################
+EPOCHS = 500
+lr = 0.01
+# optimizer = optim.Adam(params=travel_time.parameters(), lr=0.01)
+optimizer = optim.Adam(
+    [
+        {"params": travel_time.event_loc.parameters(), "lr": lr},  # learning rate for event_loc
+        {"params": travel_time.event_time.parameters(), "lr": lr * 0.1},  # learning rate for event_time
+    ],
+    lr=lr,
+)
+scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1)
+scaler = optim.lr_scheduler.ReduceLROnPlateau(
+    optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05
+)
+valid_index = np.ones(len(pairs), dtype=bool)
+
 if ddp:
     travel_time = DDP(travel_time)
 raw_travel_time = travel_time.module if ddp else travel_time
 
 if ddp_local_rank == 0:
     print(f"Dataset: {len(events)} events, {len(stations)} stations, {len(data_loader)} batches")
 
-## invert loss
-######################################################################################################
-optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
-valid_index = np.ones(len(pairs), dtype=bool)
-EPOCHS = 100
+NUM_PAIRS = len(data_loader)
 for epoch in range(EPOCHS):
     loss = 0
     optimizer.zero_grad()
     # for meta in tqdm(phase_dataset, desc=f"Epoch {i}"):
     for meta in data_loader:
+        if meta is None:
+            continue
+
         out = travel_time(
             meta["idx_sta"],
             meta["idx_eve"],
             meta["phase_type"],
             meta["phase_time"],
             meta["phase_weight"],
         )
-        pred_, loss_ = out["phase_time"], out["loss"]
+        if out is None:
+            continue
+
+        pred_, loss_ = out["phase_time"], out["loss"] / NUM_PAIRS
 
         loss_.backward()
 
@@ -201,18 +216,24 @@
 
     # torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0)
     optimizer.step()
+    scheduler.step()
+    scaler.step(loss)
     with torch.no_grad():
         raw_travel_time.event_loc.weight.data[:, 2].clamp_(
             min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
         )
-        raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0
+        # raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0
     if ddp_local_rank == 0:
-        print(f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}")
+        print(
+            f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}, lr {scheduler.get_last_lr()[0]:.5f}"
+        )
 
     ### filtering
     pred_time = []
+    weight = []
    phase_dataset.valid_index = np.ones(len(pairs), dtype=bool)
     for meta in phase_dataset:
+        weight.append(meta["phase_weight"].detach().numpy())
         meta = travel_time(
             meta["idx_sta"],
             meta["idx_eve"],
@@ -223,9 +244,15 @@
         pred_time.append(meta["phase_time"].detach().numpy())
 
     pred_time = np.concatenate(pred_time)
-    valid_index = (
-        np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * 3.0
-    )  # * (np.cos(epoch * np.pi / EPOCHS) + 2.0) # 3std -> 1std
+    weight = np.concatenate(weight)
+    # threshold_time = 6.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 2.0  # s
+    threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0  # s
+    print(f"Scaler: {scaler.get_last_lr()[0]}")
+    threshold_time *= scaler.get_last_lr()[0]
+    # valid_index = np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * threshold_time
+    # weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])[valid_index]) ** 2, weights=weight[valid_index]))
+    weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])) ** 2, weights=weight))
+    valid_index = np.abs(pred_time - pairs["dt"]) < weighted_std * threshold_time
 
     pairs_df = pd.DataFrame(
         {
@@ -234,11 +261,40 @@
             "station_index": pairs["idx_sta"],
         }
     )
+    num_picks = len(pairs_df)
     pairs_df = pairs_df[valid_index]
+    print(f"Filter by time: {num_picks} -> {len(pairs_df)} using threshold {threshold_time:.2f}")
+
+    event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
+    event_loc = pd.DataFrame(
+        {
+            "x_km": event_loc[:, 0],
+            "y_km": event_loc[:, 1],
+            "z_km": event_loc[:, 2],
+        }
+    )
+    pairs_df = pairs_df.merge(event_loc[["x_km", "y_km", "z_km"]], left_on="event_index1", right_index=True)
+    pairs_df.rename(columns={"x_km": "x_km_1", "y_km": "y_km_1", "z_km": "z_km_1"}, inplace=True)
+    pairs_df = pairs_df.merge(event_loc[["x_km", "y_km", "z_km"]], left_on="event_index2", right_index=True)
+    pairs_df.rename(columns={"x_km": "x_km_2", "y_km": "y_km_2", "z_km": "z_km_2"}, inplace=True)
+    pairs_df["dist_km"] = np.sqrt(
+        (pairs_df["x_km_1"] - pairs_df["x_km_2"]) ** 2
+        + (pairs_df["y_km_1"] - pairs_df["y_km_2"]) ** 2
+        + (pairs_df["z_km_1"] - pairs_df["z_km_2"]) ** 2
+    )
+    # threshold_space = 9.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 1.0  # km
+    threshold_space = 9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0  # km
+    threshold_space *= scaler.get_last_lr()[0]
+    num_picks = len(pairs_df)
+    pairs_df = pairs_df[pairs_df["dist_km"] < threshold_space]
+    print(f"Filter by space: {num_picks} -> {len(pairs_df)} using threshold {threshold_space:.2f}")
+
     config["MIN_OBS"] = 8
+    num_picks = len(pairs_df)
     pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(
         lambda x: len(x) >= config["MIN_OBS"]
     )
+    print(f"Filter by MIN_OBS: {num_picks} -> {len(pairs_df)} using threshold {config['MIN_OBS']:d}")
     valid_index = np.zeros(len(pairs), dtype=bool)
     valid_index[pairs_df.index] = True
 
@@ -252,6 +308,27 @@
     )
     valid_event_index = np.sort(np.unique(valid_event_index))
 
+    print(
+        f"{invert_event_time.shape = }, {invert_event_time.min() = }, {invert_event_time.max() = }, {np.median(invert_event_time) = }"
+    )
+
+    # # ## correct events time
+    # pairs_df = pd.DataFrame(
+    #     {
+    #         "event_index1": pairs["idx_eve1"],
+    #         "event_index2": pairs["idx_eve2"],
+    #         "resisual": pred_time - pairs["dt"],
+    #     }
+    # )
+    # # pair_df = pairs_df[valid_index]
+    # res1 = pairs_df.groupby("event_index1")["resisual"].median()
+    # res2 = pairs_df.groupby("event_index2")["resisual"].median()
+    # res = pd.Series(np.zeros(len(events_init)), index=events_init.index)
+    # res = res.add(-res1, fill_value=0)
+    # res = res.add(res2, fill_value=0)
+    # print(f"{res.describe() = }")
+    # raw_travel_time.event_time.weight.data = torch.tensor(res.values[:, np.newaxis] / 2.0, dtype=torch.float32)
+
     if ddp_local_rank == 0 and (epoch % 10 == 0):
         events = events_init.copy()
         events["time"] = events["time"] + pd.to_timedelta(np.squeeze(invert_event_time), unit="s")
@@ -271,8 +348,8 @@
         )
         plotting_dd(events, stations, config, figure_path, events_init, suffix=f"_ddcc_{epoch//10}")
 
-# ######################################################################################################
-# optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=10, line_search_fn="strong_wolfe")
+# # ######################################################################################################
+# optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=50, line_search_fn="strong_wolfe")
 
 # def closure():
 #     optimizer.zero_grad()
@@ -284,13 +361,16 @@
 #         if ddp_local_rank == 0:
 #             print(".", end="")
 
-#         loss_ = travel_time(
-#             meta["idx_sta"],
-#             meta["idx_eve"],
-#             meta["phase_type"],
-#             meta["phase_time"],
-#             meta["phase_weight"],
-#         )["loss"]
+#         loss_ = (
+#             travel_time(
+#                 meta["idx_sta"],
+#                 meta["idx_eve"],
+#                 meta["phase_type"],
+#                 meta["phase_time"],
+#                 meta["phase_weight"],
+#             )["loss"]
+#             / NUM_PAIRS
+#         )
#         loss_.backward()
 
 #     if ddp:
@@ -305,7 +385,7 @@
 #     return loss
 
 # optimizer.step(closure)
-# ######################################################################################################
+# # ######################################################################################################
 
 # %%
 if ddp_local_rank == 0:
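
Note on the adaptive cutoffs introduced above (a minimal standalone sketch, not part of the commit): the per-epoch filtering shrinks the time and distance cutoffs linearly over the 500 epochs and additionally multiplies them by the learning rate of a ReduceLROnPlateau scheduler attached to a throwaway SGD optimizer, so the cutoffs tighten further once the loss stops improving. The snippet below isolates that schedule; dummy_param and the synthetic loss curve are placeholders, and ReduceLROnPlateau.get_last_lr() assumes a PyTorch version recent enough to expose it (the commit relies on the same call).

import torch
import torch.optim as optim

EPOCHS = 500
dummy_param = torch.zeros(1, requires_grad=True)  # placeholder; only carries the lr used as a scale factor
scaler = optim.lr_scheduler.ReduceLROnPlateau(
    optim.SGD([dummy_param], lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05
)

for epoch in range(EPOCHS):
    loss = 1.0 / (1 + epoch // 50)  # synthetic loss that plateaus, standing in for the summed CC misfit
    scaler.step(loss)  # multiplies the scale factor by 0.9 after the loss has plateaued past `patience`
    scale = scaler.get_last_lr()[0]
    threshold_time = (6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0) * scale  # s, ~8 -> 2 before scaling
    threshold_space = (9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0) * scale  # km, ~10 -> 1 before scaling
    # pairs are then kept if |residual| < weighted_std * threshold_time and the inter-event distance
    # is below threshold_space, where weighted_std = sqrt(average(residual**2, weights=pick_weight))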

scripts/plot_catalog.py

Lines changed: 8 additions & 0 deletions
@@ -130,6 +130,8 @@ def parse_args():
 catalog_ct_hypodd = catalog_ct_hypodd[catalog_ct_hypodd["DEPTH"] != "*********"]
 catalog_ct_hypodd["DEPTH"] = catalog_ct_hypodd["DEPTH"].astype(float)
 
+catalog_ct_hypodd.to_csv(f"{root_path}/{region}/hypodd/hypodd_ct.csv", index=False)
+
 plt.figure()
 plt.scatter(catalog_ct_hypodd["LON"], catalog_ct_hypodd["LAT"], s=2)
 plt.show()
@@ -179,6 +181,8 @@ def parse_args():
 catalog_cc_hypodd = catalog_cc_hypodd[catalog_cc_hypodd["DEPTH"] != "*********"]
 catalog_cc_hypodd["DEPTH"] = catalog_cc_hypodd["DEPTH"].astype(float)
 
+catalog_cc_hypodd.to_csv(f"{root_path}/{region}/hypodd/hypodd_cc.csv", index=False)
+
 plt.figure()
 plt.scatter(catalog_cc_hypodd["LON"], catalog_cc_hypodd["LAT"], s=2)
 plt.show()
@@ -225,6 +229,8 @@ def parse_args():
 )
 growclust_ct_catalog = growclust_ct_catalog[growclust_ct_catalog["nbranch"] > 1]
 
+growclust_ct_catalog.to_csv(f"{root_path}/{region}/growclust/growclust_ct.csv", index=False)
+
 # %%
 growclust_file = f"{root_path}/{region}/growclust/growclust_cc_catalog.txt"
 growclust_cc_exist = False
@@ -267,6 +273,8 @@ def parse_args():
 )
 growclust_cc_catalog = growclust_cc_catalog[growclust_cc_catalog["nbranch"] > 1]
 
+growclust_cc_catalog.to_csv(f"{root_path}/{region}/growclust/growclust_cc.csv", index=False)
+
 
 # %% Debug
 # def load_Shelly2020():
