|
75 | 75 | print(json.dumps(config, indent=4)) |
76 | 76 | config["use_amplitude"] = True |
77 | 77 |
|
78 | | - # ## Eikonal for 1D velocity model |
79 | | - zz = [0.0, 5.5, 16.0, 32.0] |
80 | | - vp = [5.5, 5.5, 6.7, 7.8] |
81 | | - vp_vs_ratio = 1.73 |
82 | | - vs = [v / vp_vs_ratio for v in vp] |
83 | | - h = 0.3 |
84 | | - |
85 | 78 | # %% |
86 | 79 | if not os.path.exists(result_path): |
87 | 80 | os.makedirs(result_path) |
|
126 | 119 | # vp = [5.5, 5.5, 6.7, 7.8] |
127 | 120 | # vp_vs_ratio = 1.73 |
128 | 121 | # vs = [v / vp_vs_ratio for v in vp] |
129 | | - # zz = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 30.0] |
130 | | - # vp = [4.746, 4.793, 4.799, 5.045, 5.721, 5.879, 6.504, 6.708, 6.725, 7.800] |
| 122 | + zz = [0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 17.0, 21.0, 31.00, 31.10] |
| 123 | + vp = [5.30, 5.65, 5.93, 6.20, 6.20, 6.20, 6.20, 6.20, 6.20, 6.20, 7.50, 8.11] |
131 | 124 | # vs = [2.469, 2.470, 2.929, 2.930, 3.402, 3.403, 3.848, 3.907, 3.963, 4.500] |
132 | | - # h = 0.3 |
| 125 | + vs = [v / 1.73 for v in vp] |
| 126 | + h = 0.3 |
133 | 127 | vel = {"Z": zz, "P": vp, "S": vs} |
134 | 128 | config["eikonal"] = { |
135 | 129 | "vel": vel, |
|
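The replacement velocity model above lists layer-top depths `zz` (km) and P velocities `vp` (km/s) explicitly, derives S velocities from a constant Vp/Vs ratio of 1.73, and hands the layered model to the eikonal solver together with `h`. A minimal sketch of how the model dictionary is assembled (values are illustrative, and reading `h` as the eikonal grid spacing in km is an assumption):

import json

zz = [0.0, 1.0, 3.0]         # layer-top depths (km)
vp = [5.30, 5.65, 5.93]      # P-wave velocities (km/s)
vs = [v / 1.73 for v in vp]  # ~[3.06, 3.27, 3.43] km/s, assuming Vp/Vs = 1.73
h = 0.3                      # assumed eikonal grid spacing (km)
vel = {"Z": zz, "P": vp, "S": vs}
print(json.dumps({"vel": vel, "h": h}, indent=4))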
165 | 159 | # event_time=event_time, |
166 | 160 | eikonal=config["eikonal"], |
167 | 161 | ) |
| 162 | + |
| 163 | + ## invert loss |
| 164 | + ###################################################################################################### |
| 165 | + EPOCHS = 500 |
| 166 | + lr = 0.01 |
| 167 | + # optimizer = optim.Adam(params=travel_time.parameters(), lr=0.01) |
| 168 | + optimizer = optim.Adam( |
| 169 | + [ |
| 170 | + {"params": travel_time.event_loc.parameters(), "lr": lr}, # learning rate for event_loc |
| 171 | + {"params": travel_time.event_time.parameters(), "lr": lr * 0.1}, # learning rate for event_time |
| 172 | + ], |
| 173 | + lr=lr, |
| 174 | + ) |
| 175 | + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=lr * 0.1) |
| 176 | +     scaler = optim.lr_scheduler.ReduceLROnPlateau(  # not an AMP GradScaler: a plateau-driven decay factor for the filtering thresholds
| 177 | + optim.SGD(params=travel_time.parameters(), lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05 |
| 178 | + ) |
| 179 | + valid_index = np.ones(len(pairs), dtype=bool) |
| 180 | + |
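Two schedules are attached to the inversion above. `scheduler` is a standard CosineAnnealingLR on the Adam optimizer (which uses a smaller learning rate for `event_time` than for `event_loc`), while `scaler`, despite its name, is not an AMP GradScaler: it is a ReduceLROnPlateau wrapped around a throwaway SGD optimizer whose learning rate starts at 1.0 and is read back as a decay factor for the outlier thresholds used later. A minimal sketch of that trick, assuming a PyTorch version (2.2+) where ReduceLROnPlateau inherits from LRScheduler and therefore exposes get_last_lr():

import torch
from torch import optim

dummy = torch.nn.Parameter(torch.zeros(1))  # the SGD optimizer only exists to hold a decayable "lr"
decay = optim.lr_scheduler.ReduceLROnPlateau(
    optim.SGD([dummy], lr=1.0), mode="min", factor=0.9, patience=3, threshold=0.05
)
for epoch, epoch_loss in enumerate([1.0, 0.99, 0.99, 0.99, 0.99, 0.99]):
    decay.step(epoch_loss)           # multiplies the factor by 0.9 after `patience` epochs without improvement
    factor = decay.get_last_lr()[0]  # stays at 1.0 until the loss plateaus, then 0.9, 0.81, ...
    print(epoch, factor)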
168 | 181 | if ddp: |
169 | 182 | travel_time = DDP(travel_time) |
170 | 183 | raw_travel_time = travel_time.module if ddp else travel_time |
171 | 184 |
|
172 | 185 | if ddp_local_rank == 0: |
173 | 186 | print(f"Dataset: {len(events)} events, {len(stations)} stations, {len(data_loader)} batches") |
174 | 187 |
|
175 | | - ## invert loss |
176 | | - ###################################################################################################### |
177 | | - optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1) |
178 | | - valid_index = np.ones(len(pairs), dtype=bool) |
179 | | - EPOCHS = 100 |
| 188 | +     NUM_PAIRS = len(data_loader)  # number of batches, used to normalize the per-batch loss
180 | 189 | for epoch in range(EPOCHS): |
181 | 190 | loss = 0 |
182 | 191 | optimizer.zero_grad() |
183 | 192 | # for meta in tqdm(phase_dataset, desc=f"Epoch {i}"): |
184 | 193 | for meta in data_loader: |
| 194 | + if meta is None: |
| 195 | + continue |
| 196 | + |
185 | 197 | out = travel_time( |
186 | 198 | meta["idx_sta"], |
187 | 199 | meta["idx_eve"], |
188 | 200 | meta["phase_type"], |
189 | 201 | meta["phase_time"], |
190 | 202 | meta["phase_weight"], |
191 | 203 | ) |
192 | | - pred_, loss_ = out["phase_time"], out["loss"] |
| 204 | + if out is None: |
| 205 | + continue |
| 206 | + |
| 207 | + pred_, loss_ = out["phase_time"], out["loss"] / NUM_PAIRS |
193 | 208 |
|
194 | 209 | loss_.backward() |
195 | 210 |
|
|
201 | 216 |
|
202 | 217 | # torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0) |
203 | 218 | optimizer.step() |
| 219 | + scheduler.step() |
| 220 | +         scaler.step(loss)  # shrink the plateau decay factor when the epoch loss stops improving
204 | 221 | with torch.no_grad(): |
205 | 222 | raw_travel_time.event_loc.weight.data[:, 2].clamp_( |
206 | 223 | min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1 |
207 | 224 | ) |
208 | | - raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0 |
| 225 | + # raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0 |
209 | 226 | if ddp_local_rank == 0: |
210 | | - print(f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}") |
| 227 | + print( |
| 228 | + f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}, lr {scheduler.get_last_lr()[0]:.5f}" |
| 229 | + ) |
211 | 230 |
|
212 | 231 | ### filtering |
213 | 232 | pred_time = [] |
| 233 | + weight = [] |
214 | 234 | phase_dataset.valid_index = np.ones(len(pairs), dtype=bool) |
215 | 235 | for meta in phase_dataset: |
| 236 | + weight.append(meta["phase_weight"].detach().numpy()) |
216 | 237 | meta = travel_time( |
217 | 238 | meta["idx_sta"], |
218 | 239 | meta["idx_eve"], |
|
223 | 244 | pred_time.append(meta["phase_time"].detach().numpy()) |
224 | 245 |
|
225 | 246 | pred_time = np.concatenate(pred_time) |
226 | | - valid_index = ( |
227 | | - np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * 3.0 |
228 | | - ) # * (np.cos(epoch * np.pi / EPOCHS) + 2.0) # 3std -> 1std |
| 247 | + weight = np.concatenate(weight) |
| 248 | + # threshold_time = 6.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 2.0 # s |
| 249 | + threshold_time = 6.0 * (EPOCHS - 1 - epoch) / EPOCHS + 2.0 # s |
| 250 | + print(f"Scaler: {scaler.get_last_lr()[0]}") |
| 251 | + threshold_time *= scaler.get_last_lr()[0] |
| 252 | + # valid_index = np.abs(pred_time - pairs["dt"]) < np.std((pred_time - pairs["dt"])[valid_index]) * threshold_time |
| 253 | + # weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])[valid_index]) ** 2, weights=weight[valid_index])) |
| 254 | + weighted_std = np.sqrt(np.average(((pred_time - pairs["dt"])) ** 2, weights=weight)) |
| 255 | + valid_index = np.abs(pred_time - pairs["dt"]) < weighted_std * threshold_time |
229 | 256 |
|
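The filtering above replaces the earlier fixed 3-sigma cut: residuals between predicted and observed differential times are compared against a weighted standard deviation, with a threshold that anneals linearly over the epochs and is further shrunk by the plateau factor. A small self-contained sketch of the weighted-std cut (array values are made up):

import numpy as np

residual = np.array([0.1, -0.2, 3.0])   # pred_time - dt for three pairs
weight = np.array([1.0, 1.0, 0.5])      # phase weights
weighted_std = np.sqrt(np.average(residual**2, weights=weight))
threshold_time = 2.0                    # final value of the annealed threshold
valid = np.abs(residual) < weighted_std * threshold_time
print(weighted_std, valid)              # keeps the two small residuals, drops the 3.0 s outlier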
230 | 257 | pairs_df = pd.DataFrame( |
231 | 258 | { |
|
234 | 261 | "station_index": pairs["idx_sta"], |
235 | 262 | } |
236 | 263 | ) |
| 264 | + num_picks = len(pairs_df) |
237 | 265 | pairs_df = pairs_df[valid_index] |
| 266 | + print(f"Filter by time: {num_picks} -> {len(pairs_df)} using threshold {threshold_time:.2f}") |
| 267 | + |
| 268 | + event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy() |
| 269 | + event_loc = pd.DataFrame( |
| 270 | + { |
| 271 | + "x_km": event_loc[:, 0], |
| 272 | + "y_km": event_loc[:, 1], |
| 273 | + "z_km": event_loc[:, 2], |
| 274 | + } |
| 275 | + ) |
| 276 | + pairs_df = pairs_df.merge(event_loc[["x_km", "y_km", "z_km"]], left_on="event_index1", right_index=True) |
| 277 | + pairs_df.rename(columns={"x_km": "x_km_1", "y_km": "y_km_1", "z_km": "z_km_1"}, inplace=True) |
| 278 | + pairs_df = pairs_df.merge(event_loc[["x_km", "y_km", "z_km"]], left_on="event_index2", right_index=True) |
| 279 | + pairs_df.rename(columns={"x_km": "x_km_2", "y_km": "y_km_2", "z_km": "z_km_2"}, inplace=True) |
| 280 | + pairs_df["dist_km"] = np.sqrt( |
| 281 | + (pairs_df["x_km_1"] - pairs_df["x_km_2"]) ** 2 |
| 282 | + + (pairs_df["y_km_1"] - pairs_df["y_km_2"]) ** 2 |
| 283 | + + (pairs_df["z_km_1"] - pairs_df["z_km_2"]) ** 2 |
| 284 | + ) |
| 285 | + # threshold_space = 9.0 * (np.cos(epoch * np.pi / EPOCHS) + 1.0) / 2.0 + 1.0 # km |
| 286 | + threshold_space = 9.0 * (EPOCHS - 1 - epoch) / EPOCHS + 1.0 # km |
| 287 | + threshold_space *= scaler.get_last_lr()[0] |
| 288 | + num_picks = len(pairs_df) |
| 289 | + pairs_df = pairs_df[pairs_df["dist_km"] < threshold_space] |
| 290 | + print(f"Filter by space: {num_picks} -> {len(pairs_df)} using threshold {threshold_space:.2f}") |
| 291 | + |
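The new distance filter merges the current inverted event locations onto both members of each pair so an inter-event separation can be computed and thresholded. A compact sketch of the same double merge on a toy DataFrame (column names follow the code above; the data are made up):

import numpy as np
import pandas as pd

event_loc = pd.DataFrame({"x_km": [0.0, 3.0], "y_km": [0.0, 4.0], "z_km": [5.0, 5.0]})
pairs_df = pd.DataFrame({"event_index1": [0], "event_index2": [1]})

pairs_df = pairs_df.merge(event_loc, left_on="event_index1", right_index=True).rename(
    columns={"x_km": "x_km_1", "y_km": "y_km_1", "z_km": "z_km_1"}
)
pairs_df = pairs_df.merge(event_loc, left_on="event_index2", right_index=True).rename(
    columns={"x_km": "x_km_2", "y_km": "y_km_2", "z_km": "z_km_2"}
)
pairs_df["dist_km"] = np.sqrt(
    (pairs_df["x_km_1"] - pairs_df["x_km_2"]) ** 2
    + (pairs_df["y_km_1"] - pairs_df["y_km_2"]) ** 2
    + (pairs_df["z_km_1"] - pairs_df["z_km_2"]) ** 2
)
print(pairs_df["dist_km"].iloc[0])  # 5.0 km for this toy pair: sqrt(3**2 + 4**2 + 0**2)

Keeping the original index through the merges matters, because `valid_index[pairs_df.index] = True` later maps the surviving rows back onto the full pair array.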
238 | 292 | config["MIN_OBS"] = 8 |
| 293 | + num_picks = len(pairs_df) |
239 | 294 | pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter( |
240 | 295 | lambda x: len(x) >= config["MIN_OBS"] |
241 | 296 | ) |
| 297 | + print(f"Filter by MIN_OBS: {num_picks} -> {len(pairs_df)} using threshold {config['MIN_OBS']:d}") |
242 | 298 | valid_index = np.zeros(len(pairs), dtype=bool) |
243 | 299 | valid_index[pairs_df.index] = True |
244 | 300 |
|
|
252 | 308 | ) |
253 | 309 | valid_event_index = np.sort(np.unique(valid_event_index)) |
254 | 310 |
|
| 311 | + print( |
| 312 | + f"{invert_event_time.shape = }, {invert_event_time.min() = }, {invert_event_time.max() = }, {np.median(invert_event_time) = }" |
| 313 | + ) |
| 314 | + |
| 315 | + # # ## correct events time |
| 316 | + # pairs_df = pd.DataFrame( |
| 317 | + # { |
| 318 | + # "event_index1": pairs["idx_eve1"], |
| 319 | + # "event_index2": pairs["idx_eve2"], |
| 320 | + #         "residual": pred_time - pairs["dt"],
| 321 | + # } |
| 322 | + # ) |
| 323 | + # # pair_df = pairs_df[valid_index] |
| 324 | + # res1 = pairs_df.groupby("event_index1")["residual"].median()
| 325 | + # res2 = pairs_df.groupby("event_index2")["residual"].median()
| 326 | + # res = pd.Series(np.zeros(len(events_init)), index=events_init.index) |
| 327 | + # res = res.add(-res1, fill_value=0) |
| 328 | + # res = res.add(res2, fill_value=0) |
| 329 | + # print(f"{res.describe() = }") |
| 330 | + # raw_travel_time.event_time.weight.data = torch.tensor(res.values[:, np.newaxis] / 2.0, dtype=torch.float32) |
| 331 | + |
255 | 332 | if ddp_local_rank == 0 and (epoch % 10 == 0): |
256 | 333 | events = events_init.copy() |
257 | 334 | events["time"] = events["time"] + pd.to_timedelta(np.squeeze(invert_event_time), unit="s") |
|
271 | 348 | ) |
272 | 349 | plotting_dd(events, stations, config, figure_path, events_init, suffix=f"_ddcc_{epoch//10}") |
273 | 350 |
|
274 | | - # ###################################################################################################### |
275 | | - # optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=10, line_search_fn="strong_wolfe") |
| 351 | + # # ###################################################################################################### |
| 352 | + # optimizer = optim.LBFGS(params=raw_travel_time.parameters(), max_iter=50, line_search_fn="strong_wolfe") |
276 | 353 |
|
277 | 354 | # def closure(): |
278 | 355 | # optimizer.zero_grad() |
|
284 | 361 | # if ddp_local_rank == 0: |
285 | 362 | # print(".", end="") |
286 | 363 |
|
287 | | - # loss_ = travel_time( |
288 | | - # meta["idx_sta"], |
289 | | - # meta["idx_eve"], |
290 | | - # meta["phase_type"], |
291 | | - # meta["phase_time"], |
292 | | - # meta["phase_weight"], |
293 | | - # )["loss"] |
| 364 | + # loss_ = ( |
| 365 | + # travel_time( |
| 366 | + # meta["idx_sta"], |
| 367 | + # meta["idx_eve"], |
| 368 | + # meta["phase_type"], |
| 369 | + # meta["phase_time"], |
| 370 | + # meta["phase_weight"], |
| 371 | + # )["loss"] |
| 372 | + # / NUM_PAIRS |
| 373 | + # ) |
294 | 374 | # loss_.backward() |
295 | 375 |
|
296 | 376 | # if ddp: |
|
305 | 385 | # return loss |
306 | 386 |
|
307 | 387 | # optimizer.step(closure) |
308 | | - # ###################################################################################################### |
| 388 | + # # ###################################################################################################### |
309 | 389 |
|
310 | 390 | # %% |
311 | 391 | if ddp_local_rank == 0: |
|