Commit b7d2a38

update phasenet plus
1 parent 099f1d8 commit b7d2a38

File tree: 6 files changed, +171 / -56 lines


scripts/download_waveform_v2.py

Lines changed: 3 additions & 0 deletions
@@ -211,10 +211,13 @@ def download_waveform(
     client = obspy.clients.fdsn.Client(provider)

     DELTATIME = "1H"  # 1H or 1D
+    # DELTATIME = "1D"
     if DELTATIME == "1H":
         start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%dT%H")
     elif DELTATIME == "1D":
         start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%d")
+    else:
+        raise ValueError("Invalid interval")
     starttimes = pd.date_range(start, config["endtime"], freq=DELTATIME, tz="UTC", inclusive="left").to_list()
     starttimes = np.array_split(starttimes, num_nodes)[rank]
     print(f"rank {rank}: {len(starttimes) = }, {starttimes[0]}, {starttimes[-1]}")

scripts/download_waveform_v3.py

Lines changed: 2 additions & 0 deletions
@@ -228,6 +228,8 @@ def download_waveform(
         start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%dT%H")
     elif DELTATIME == "1D":
         start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%d")
+    else:
+        raise ValueError("Invalid interval")
     starttimes = pd.date_range(start, config["endtime"], freq=DELTATIME, tz="UTC", inclusive="left").to_list()
     starttimes = np.array_split(starttimes, num_nodes)[node_rank]
     print(f"rank {node_rank}/{num_nodes}: {len(starttimes) = }, {starttimes[0]}, {starttimes[-1]}")

scripts/merge_phasenet_picks.py

Lines changed: 7 additions & 3 deletions
@@ -4,18 +4,18 @@
 import os
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
+from glob import glob
 from threading import Lock, Thread

 import fsspec
 import numpy as np
 import pandas as pd
 import pyproj
+from args import parse_args
 from obspy import read_inventory
 from obspy.clients.fdsn import Client
 from sklearn.cluster import DBSCAN
 from tqdm import tqdm
-from args import parse_args
-from glob import glob


 def scan_csv(year, root_path, region, model, fs=None, bucket=None, protocol="file"):
@@ -31,6 +31,7 @@ def scan_csv(year, root_path, region, model, fs=None, bucket=None, protocol="file"):
             csvs = fs.glob(f"{jday}/??/*.csv")
         else:
             csvs = glob(f"{root_path}/{region}/{model}/picks/{year}/{jday}/??/*.csv")
+            # csvs = glob(f"{root_path}/{region}/{model}/picks/{year}/{jday}/*.csv")

         csv_list.extend([[year, jday, csv] for csv in csvs])

@@ -89,7 +90,7 @@ def read_csv(rows, region, model, year, jday, root_path, fs=None, bucket=None):

 # %%
 # years = os.listdir(f"{root_path}/{region}/{model}/picks_{model}")
-years = glob(f"{root_path}/{region}/{model}/picks_{model}/????/")
+years = glob(f"{root_path}/{region}/{model}/picks/????/")
 years = [year.rstrip("/").split("/")[-1] for year in years]
 print(f"Years: {years}")

@@ -132,6 +133,9 @@ def read_csv(rows, region, model, year, jday, root_path, fs=None, bucket=None):
 for csv in tqdm(csvs, desc="Merge csv files"):
     picks.append(pd.read_csv(csv, dtype=str))
 picks = pd.concat(picks, ignore_index=True)
+print(f"Number of picks: {len(picks):,}")
+print(f"Number of P picks: {len(picks[picks['phase_type'] == 'P']):,}")
+print(f"Number of S picks: {len(picks[picks['phase_type'] == 'S']):,}")
 picks.to_csv(f"{root_path}/{region}/{model}/{model}_picks.csv", index=False)

 # %%
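The new summary lines break the merged picks down by phase type. A toy check of the same counting on a hand-made frame (the values are illustrative only):

import pandas as pd

picks = pd.DataFrame({"phase_type": ["P", "P", "S", "P", "S"]})
print(f"Number of picks: {len(picks):,}")
print(f"Number of P picks: {len(picks[picks['phase_type'] == 'P']):,}")
print(f"Number of S picks: {len(picks[picks['phase_type'] == 'S']):,}")
# Prints 5, 3, and 2 respectively.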

scripts/run_event_association.py

Lines changed: 128 additions & 34 deletions
@@ -18,6 +18,79 @@
 from tqdm import tqdm


+def plotting_debug(xt, hist, topk_index, topk_score, picks, events, stations, config):
+
+    # timestamp0 = config["timestamp0"]
+    # events_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_gamma/ransac_events.csv")
+    # picks_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_gamma/ransac_picks.csv")
+    # # events_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_plus2/ransac_events_sst_0.csv")
+    # # picks_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_plus2/ransac_picks_sst_0.csv")
+    # events_compare["time"] = pd.to_datetime(events_compare["time"])
+    # events_compare["timestamp"] = events_compare["time"].apply(lambda x: (x - timestamp0).total_seconds())
+    # picks_compare["phase_time"] = pd.to_datetime(picks_compare["phase_time"])
+    # picks_compare["timestamp"] = picks_compare["phase_time"].apply(lambda x: (x - timestamp0).total_seconds())
+
+    DT = config["DT"]
+    MIN_STATION = config["MIN_STATION"]
+
+    # map station_id to int
+    stations["xy"] = stations["longitude"] - stations["latitude"]
+    stations.sort_values(by="xy", inplace=True)
+    mapping_id = {v: i for i, v in enumerate(stations["station_id"])}
+    mapping_color = {v: f"C{i}" if v != -1 else "k" for i, v in enumerate(events["event_index"].unique())}
+
+    NX = 100
+    for i in tqdm(range(0, len(hist), NX)):
+        bins = np.arange(i, i + NX, DT)
+
+        fig, ax = plt.subplots(2, 1, figsize=(15, 10), sharex=True)
+
+        # plot hist
+        idx = (xt > i) & (xt < i + NX)
+        ax[0].bar(xt[idx], hist[idx], width=DT)
+
+        ylim = ax[0].get_ylim()
+        idx = (xt[topk_index] > i) & (xt[topk_index] < i + NX)
+        ax[0].vlines(xt[topk_index][idx], ylim[0], ylim[1], color="k", linewidth=1)
+
+        # idx = (events_compare["timestamp"] > i) & (events_compare["timestamp"] < i + NX)
+        # ax[0].vlines(events_compare["timestamp"][idx], ylim[0], ylim[1], color="r", linewidth=1, linestyle="--")
+
+        # plot picks-events match
+        idx = (events["timestamp"] > i) & (events["timestamp"] < i + NX)
+        ax[1].scatter(
+            events["timestamp"][idx],
+            events["station_id"][idx].map(mapping_id),
+            c=events["event_index"][idx].map(mapping_color),
+            marker=".",
+            s=30,
+        )
+
+        idx = (picks["timestamp"] > i) & (picks["timestamp"] < i + NX)
+        ax[1].scatter(
+            picks["timestamp"][idx],
+            picks["station_id"][idx].map(mapping_id),
+            c=picks["event_index"][idx].map(mapping_color),
+            marker="x",
+            linewidth=0.5,
+            s=10,
+        )
+
+        # idx = (picks_compare["timestamp"] > i) & (picks_compare["timestamp"] < i + NX)
+        # ax[1].scatter(
+        #     picks_compare["timestamp"][idx],
+        #     picks_compare["station_id"][idx].map(mapping_id),
+        #     facecolors="none",
+        #     edgecolors="r",
+        #     linewidths=0.1,
+        #     s=30,
+        # )
+
+        if not os.path.exists(f"figures"):
+            os.makedirs(f"figures")
+        plt.savefig(f"figures/debug_{i:04d}.png", dpi=300, bbox_inches="tight")
+
+
 def associate(
     picks: pd.DataFrame,
     events: pd.DataFrame,
@@ -27,63 +100,68 @@ def associate(

     VPVS_RATIO = config["VPVS_RATIO"]
     VP = config["VP"]
-    DT = 1.0  # seconds
+    DT = 2.0  # seconds
     MIN_STATION = 3

     # %%
-    t0 = min(events["event_time"].min(), picks["phase_time"].min())
-    events["timestamp"] = events["event_time"].apply(lambda x: (x - t0).total_seconds())
-    events["timestamp_center"] = events["center_time"].apply(lambda x: (x - t0).total_seconds())
-    picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - t0).total_seconds())
+    timestamp0 = min(events["event_time"].min(), picks["phase_time"].min())

-    # proj = Proj(proj="merc", datum="WGS84", units="km")
-    # stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
+    events["timestamp"] = events["event_time"].apply(lambda x: (x - timestamp0).total_seconds())
+    events["timestamp_center"] = events["center_time"].apply(lambda x: (x - timestamp0).total_seconds())
+    picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - timestamp0).total_seconds())

-    # dist_matrix = squareform(pdist(stations[["x_km", "y_km"]].values))
-    # mst = minimum_spanning_tree(dist_matrix)
-    # dx = np.median(mst.data[mst.data > 0])
-    # print(f"dx: {dx:.3f}")
-    # eps_t = dx / VP * 2.0
-    # eps_t = 6.0
-    # eps_xy = eps_t * VP * 2 / (1.0 + VPVS_RATIO)
-    # print(f"eps_t: {eps_t:.3f}, eps_xy: {eps_xy:.3f}")
-    # eps_xy = 30.0
-    # print(f"eps_xy: {eps_xy:.3f}")
+    t0 = min(events["timestamp"].min(), picks["timestamp"].min())
+    t1 = max(events["timestamp"].max(), picks["timestamp"].max())

     # %% Using DBSCAN to cluster events
+    # proj = Proj(proj="merc", datum="WGS84", units="km")
+    # stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
     # events = events.merge(stations[["station_id", "x_km", "y_km"]], on="station_id", how="left")
-
     # scaling = np.array([1.0, 1.0 / eps_xy, 1.0 / eps_xy])
     # clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp", "x_km", "y_km"]] * scaling)
     # # clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp"]])
-    # # clustering = DBSCAN(eps=3.0, min_samples=3).fit(events[["timestamp"]])
-    # # clustering = DBSCAN(eps=1.0, min_samples=3).fit(events[["timestamp"]])
     # events["event_index"] = clustering.labels_

     ## Using histogram to cluster events
     events["event_index"] = -1
-    t = np.arange(events["timestamp"].min(), events["timestamp"].max(), DT)
-    hist, _ = np.histogram(events["timestamp"], bins=t)
-    # retrieve picks using max_pool of kernel size 5 seconds
+    t = np.arange(t0, t1, DT)
+    hist, edge = np.histogram(events["timestamp"], bins=t, weights=events["event_score"])
+    xt = (edge[:-1] + edge[1:]) / 2  # center of the bin
+    # hist_numpy = hist.copy()
+
     hist = torch.from_numpy(hist).float().unsqueeze(0).unsqueeze(0)
-    hist_pool = F.max_pool1d(hist, kernel_size=5, padding=2, stride=1)
-    # find the index of the maximum value in hist_pool
+    hist_pool = F.max_pool1d(hist, kernel_size=3, padding=1, stride=1)
     mask = hist_pool == hist
     hist = hist * mask
-    K = int((t[-1] - t[0]) / 10)  # assume max 1 event per 10 seconds on average
+    hist = hist.squeeze(0).squeeze(0)
+    K = int((t[-1] - t[0]) / 5)  # assume max 1 event per 10 seconds on average
     topk_score, topk_index = torch.topk(hist, k=K)
-    topk_index = topk_index[topk_score > MIN_STATION]  # min 3 stations
-    topk_index = topk_index.squeeze().numpy()
+    topk_index = topk_index[topk_score >= MIN_STATION]  # min 3 stations
+    topk_score = topk_score[topk_score >= MIN_STATION]
+    topk_index = topk_index.numpy()
+    topk_score = topk_score.numpy()
     num_events = len(topk_index)
-    # assign timestamp to events based on the topk_index within 2 DT
-    t0 = (topk_index - 2) * DT
-    t1 = (topk_index + 2) * DT
+    t00 = xt[topk_index - 1]
+    t11 = xt[topk_index + 1]
     timestamp = events["timestamp"].values
     for i in tqdm(range(num_events), desc="Assigning event index"):
-        mask = (timestamp >= t0[i]) & (timestamp <= t1[i])
+        mask = (timestamp >= t00[i]) & (timestamp <= t11[i])
         events.loc[mask, "event_index"] = i
-
-    print(f"Number of associated events: {len(events['event_index'].unique())}")
+    events["num_picks"] = events.groupby("event_index").size()
+
+    # # refine event index using DBSCAN
+    # events["group_index"] = -1
+    # for group_id, event in tqdm(events.groupby("event_index"), desc="DBSCAN clustering"):
+    #     if len(event) < MIN_STATION:
+    #         events.loc[event.index, "event_index"] = -1
+    #     clustering = DBSCAN(eps=20, min_samples=MIN_STATION).fit(event[["x_km", "y_km"]])
+    #     events.loc[event.index, "group_index"] = clustering.labels_
+    # events["dummy_index"] = events["event_index"].astype(str) + "." + events["group_index"].astype(str)
+    # mapping = {v: i for i, v in enumerate(events["dummy_index"].unique())}
+    # events["dummy_index"] = events["dummy_index"].map(mapping)
+    # events.loc[(events["event_index"] == -1) | (events["group_index"] == -1), "dummy_index"] = -1
+    # events["event_index"] = events["dummy_index"]
+    # events.drop(columns=["dummy_index"], inplace=True)

     # %% link picks to events
     picks["event_index"] = -1
@@ -92,6 +170,8 @@ def associate(
     for group_id, event in tqdm(events.groupby("station_id"), desc="Linking picks to events"):
         # travel time tt = (tp + ts) / 2 = (1 + ps_ratio)/2 * tp => tp = tt * 2 / (1 + ps_ratio)
         # (ts - tp) = (ps_ratio - 1) tp = tt * 2 * (ps_ratio - 1) / (ps_ratio + 1)
+
+        event = event.sort_values(by="num_picks", ascending=True)
         ps_delta = event["travel_time_s"].values * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)
         t1 = event["timestamp_center"].values - ps_delta * 1.2
         t2 = event["timestamp_center"].values + ps_delta * 1.2
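The in-loop comments carry the timing algebra behind the pick search window: with ts = VPVS_RATIO * tp and tt = (tp + ts) / 2, it follows that tp = 2 * tt / (1 + VPVS_RATIO) and ts - tp = 2 * tt * (VPVS_RATIO - 1) / (VPVS_RATIO + 1). A quick numeric check, using the VPVS_RATIO of 1.73 set elsewhere in this commit and an assumed mean travel time of 10 s:

VPVS_RATIO = 1.73
tt = 10.0  # assumed (tp + ts) / 2 travel time, seconds
tp = tt * 2 / (1 + VPVS_RATIO)                           # P travel time, about 7.33 s
ps_delta = tt * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)  # S minus P time, about 5.35 s
print(f"tp = {tp:.2f} s, ts - tp = {ps_delta:.2f} s")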
@@ -107,6 +187,17 @@ def associate(

     picks.reset_index(inplace=True)

+    # plotting_debug(
+    #     xt,
+    #     hist_numpy,
+    #     topk_index,
+    #     topk_score,
+    #     picks,
+    #     events,
+    #     stations,
+    #     {"DT": DT, "MIN_STATION": MIN_STATION, "timestamp0": timestamp0},
+    # )
+
     picks.drop(columns=["timestamp"], inplace=True)
     events.drop(columns=["timestamp", "timestamp_center"], inplace=True)

@@ -127,6 +218,9 @@ def associate(
     # drop event index -1
     events = events[events["event_index"] != -1]

+    print(f"Number of associated events: {len(events['event_index'].unique()):,}")
+    print(f"Number of associated picks: {len(picks[picks['event_index'] != -1]):,} / {len(picks):,}")
+
     return events, picks


scripts/run_phasenet_plus.py

Lines changed: 14 additions & 10 deletions
@@ -1,17 +1,16 @@
 # %%
-from typing import Dict, List
 import json
 import os
 import sys
-from args import parse_args
-import os
+from collections import defaultdict
 from glob import glob
+from typing import Dict, List

 import fsspec
-import torch
-from collections import defaultdict
 import numpy as np
 import pandas as pd
+import torch
+from args import parse_args
 from run_event_association import associate


@@ -31,6 +30,7 @@ def run_phasenet(
     # %%
     if data_type == "continuous":
         subdir = 3
+        # subdir = 2
     elif data_type == "event":
         subdir = 1

@@ -50,6 +50,7 @@ def run_phasenet(

     if data_type == "continuous":
         mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/????/???/??/*.mseed"))
+        # mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/????/???/*.mseed"))
     elif data_type == "event":
         mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/*.mseed"))
     else:
@@ -59,16 +60,17 @@ def run_phasenet(
     mseed_3c = defaultdict(list)
     for mseed in mseed_list:
         # key = mseed.replace(f"{root_path}/{waveform_dir}/", "").replace(".mseed", "")
-        key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir:])
+        key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir - 1 :])
         if data_type == "continuous":
             key = key[:-1]
         mseed_3c[key].append(mseed)
     print(f"Number of mseed files: {len(mseed_3c)}")

     # %% skip processed files
     if not overwrite:
-        processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/??/*.csv"))
-        processed = ["/".join(f.replace(".csv", "").split("/")[-subdir:]) for f in processed]
+        processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/*.csv"))
+        # processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/*.csv"))
+        processed = ["/".join(f.replace(".csv", "").split("/")[-subdir - 1 :]) for f in processed]
         processed = [p[:-1] for p in processed]  ## remove the channel suffix
         print(f"Number of processed files: {len(processed)}")
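The widened slice ([-subdir - 1 :]) keeps the filename in the grouping key now that the channel code sits one level deeper, and the trailing key[:-1] drops the component letter so the three components of one instrument group together. A toy illustration with fabricated paths, assuming subdir = 3 as in the continuous case:

# Illustration only: the paths below are made up.
from collections import defaultdict

subdir = 3  # year/jday/hour levels under the waveform directory
mseed_list = [
    "root/waveforms/2023/001/00/CI.CCC..HHE.mseed",
    "root/waveforms/2023/001/00/CI.CCC..HHN.mseed",
    "root/waveforms/2023/001/00/CI.CCC..HHZ.mseed",
]

mseed_3c = defaultdict(list)
for mseed in mseed_list:
    key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir - 1 :])
    key = key[:-1]  # strip the component letter (E/N/Z)
    mseed_3c[key].append(mseed)

print(dict(mseed_3c))  # one key "2023/001/00/CI.CCC..HH" holding all three files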

@@ -93,6 +95,7 @@ def run_phasenet(
     num_gpu = torch.cuda.device_count()
     print(f"num_gpu = {num_gpu}")
     base_cmd = f"../EQNet/predict.py --model phasenet_plus --add_polarity --add_event --format mseed --data_list={root_path}/{result_path}/mseed_list_{node_rank:03d}_{num_nodes:03d}.csv --response_path={root_path}/{response_path} --result_path={root_path}/{result_path} --batch_size 1 --workers 1 --subdir_level {subdir}"
+    # base_cmd += " --resume ../../QuakeFlow/EQNet/model_phasenet_plus_0630/model_99.pth"
     if num_gpu == 0:
         cmd = f"python {base_cmd} --device=cpu"
     elif num_gpu == 1:
@@ -116,6 +119,9 @@ def run_phasenet(

     run_phasenet(root_path=root_path, region=region, config=config)

+    if num_nodes == 1:
+        os.system(f"python merge_phasenet_plus_picks.py --region {region}")
+
     if num_nodes == 1:
         config.update({"VPVS_RATIO": 1.73, "VP": 6.0})
         stations = pd.read_json(f"{root_path}/{region}/obspy/stations.json", orient="index")
@@ -125,8 +131,6 @@ def run_phasenet(
         )
         picks = pd.read_csv(f"{root_path}/{region}/phasenet_plus/picks_phasenet_plus.csv", parse_dates=["phase_time"])
        events, picks = associate(picks, events, stations, config)
-        print(f"Number of picks: {len(picks):,}")
-        print(f"Number of associated events: {len(events['event_index'].unique()):,}")
         events.to_csv(f"{root_path}/{region}/phasenet_plus/phasenet_plus_events_associated.csv", index=False)
         picks.to_csv(f"{root_path}/{region}/phasenet_plus/phasenet_plus_picks_associated.csv", index=False)
