
Commit 099f1d8

update clustering
1 parent b4fec1f commit 099f1d8

File tree

1 file changed (+59 −30)

scripts/run_event_association.py

Lines changed: 59 additions & 30 deletions
@@ -8,6 +8,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import torch
+import torch.nn.functional as F
 from args import parse_args
 from pyproj import Proj
 from scipy.sparse.csgraph import minimum_spanning_tree
@@ -25,9 +27,17 @@ def associate(
 
 VPVS_RATIO = config["VPVS_RATIO"]
 VP = config["VP"]
+DT = 1.0  # seconds
+MIN_STATION = 3
 
-proj = Proj(proj="merc", datum="WGS84", units="km")
-stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
+# %%
+t0 = min(events["event_time"].min(), picks["phase_time"].min())
+events["timestamp"] = events["event_time"].apply(lambda x: (x - t0).total_seconds())
+events["timestamp_center"] = events["center_time"].apply(lambda x: (x - t0).total_seconds())
+picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - t0).total_seconds())
+
+# proj = Proj(proj="merc", datum="WGS84", units="km")
+# stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
 
 # dist_matrix = squareform(pdist(stations[["x_km", "y_km"]].values))
 # mst = minimum_spanning_tree(dist_matrix)
@@ -37,43 +47,62 @@ def associate(
 # eps_t = 6.0
 # eps_xy = eps_t * VP * 2 / (1.0 + VPVS_RATIO)
 # print(f"eps_t: {eps_t:.3f}, eps_xy: {eps_xy:.3f}")
-eps_xy = 30.0
-print(f"eps_xy: {eps_xy:.3f}")
+# eps_xy = 30.0
+# print(f"eps_xy: {eps_xy:.3f}")
+
+# %% Using DBSCAN to cluster events
+# events = events.merge(stations[["station_id", "x_km", "y_km"]], on="station_id", how="left")
+
+# scaling = np.array([1.0, 1.0 / eps_xy, 1.0 / eps_xy])
+# clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp", "x_km", "y_km"]] * scaling)
+# # clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp"]])
+# # clustering = DBSCAN(eps=3.0, min_samples=3).fit(events[["timestamp"]])
+# # clustering = DBSCAN(eps=1.0, min_samples=3).fit(events[["timestamp"]])
+# events["event_index"] = clustering.labels_
+
+## Using histogram to cluster events
+events["event_index"] = -1
+t = np.arange(events["timestamp"].min(), events["timestamp"].max(), DT)
+hist, _ = np.histogram(events["timestamp"], bins=t)
+# retrieve picks using max_pool of kernel size 5 seconds
+hist = torch.from_numpy(hist).float().unsqueeze(0).unsqueeze(0)
+hist_pool = F.max_pool1d(hist, kernel_size=5, padding=2, stride=1)
+# find the index of the maximum value in hist_pool
+mask = hist_pool == hist
+hist = hist * mask
+K = int((t[-1] - t[0]) / 10)  # assume max 1 event per 10 seconds on average
+topk_score, topk_index = torch.topk(hist, k=K)
+topk_index = topk_index[topk_score > MIN_STATION]  # min 3 stations
+topk_index = topk_index.squeeze().numpy()
+num_events = len(topk_index)
+# assign timestamp to events based on the topk_index within 2 DT
+t0 = (topk_index - 2) * DT
+t1 = (topk_index + 2) * DT
+timestamp = events["timestamp"].values
+for i in tqdm(range(num_events), desc="Assigning event index"):
+    mask = (timestamp >= t0[i]) & (timestamp <= t1[i])
+    events.loc[mask, "event_index"] = i
 
-# %%
-t0 = min(events["event_time"].min(), picks["phase_time"].min())
-events["timestamp"] = events["event_time"].apply(lambda x: (x - t0).total_seconds())
-events["timestamp_center"] = events["center_time"].apply(lambda x: (x - t0).total_seconds())
-picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - t0).total_seconds())
-
-# %%
-events = events.merge(stations[["station_id", "x_km", "y_km"]], on="station_id", how="left")
-
-scaling = np.array([1.0, 1.0 / eps_xy, 1.0 / eps_xy])
-clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp", "x_km", "y_km"]] * scaling)
-# clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp"]])
-# clustering = DBSCAN(eps=3.0, min_samples=3).fit(events[["timestamp"]])
-# clustering = DBSCAN(eps=1.0, min_samples=3).fit(events[["timestamp"]])
-events["event_index"] = clustering.labels_
 print(f"Number of associated events: {len(events['event_index'].unique())}")
 
 # %% link picks to events
 picks["event_index"] = -1
 picks.set_index("station_id", inplace=True)
 
 for group_id, event in tqdm(events.groupby("station_id"), desc="Linking picks to events"):
-    # travel time tt = (tp + ts) / 2 = (ps_ratio + 1)/2 * tp,
-    # (ts - tp) = (ps_ratio - 1) tp = tt * (ps_ratio + 1) * 2 * (ps_ratio - 1)
-    ps_delta = event["travel_time_s"] / (VPVS_RATIO + 1) * 2 * (VPVS_RATIO - 1)
-    t1 = event["timestamp_center"] - ps_delta * 1.2
-    t2 = event["timestamp_center"] + ps_delta * 1.2
-    index = event["event_index"]
-
-    mask = (picks.loc[group_id, "timestamp"].values[None, :] >= t1.values[:, None]) & (
-        picks.loc[group_id, "timestamp"].values[None, :] <= t2.values[:, None]
-    )
+    # travel time tt = (tp + ts) / 2 = (1 + ps_ratio)/2 * tp => tp = tt * 2 / (1 + ps_ratio)
+    # (ts - tp) = (ps_ratio - 1) tp = tt * 2 * (ps_ratio - 1) / (ps_ratio + 1)
+    ps_delta = event["travel_time_s"].values * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)
+    t1 = event["timestamp_center"].values - ps_delta * 1.2
+    t2 = event["timestamp_center"].values + ps_delta * 1.2
+
+    picks_ = picks.loc[group_id, "timestamp"].values  # (Npk, )
+    mask = (picks_[None, :] >= t1[:, None]) & (picks_[None, :] <= t2[:, None])  # (Nev, Npk)
+    # picks.loc[group_id, "event_index"] = np.where(
+    #     mask.any(axis=0), index.values[mask.argmax(axis=0)], picks.loc[group_id, "event_index"]
+    # )
     picks.loc[group_id, "event_index"] = np.where(
-        mask.any(axis=0), index.values[mask.argmax(axis=0)], picks.loc[group_id, "event_index"]
+        mask.any(axis=0), event["event_index"].values[mask.argmax(axis=0)], -1
     )
 
 picks.reset_index(inplace=True)
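
Note: the commit replaces DBSCAN clustering with histogram peak-picking. Event timestamps are binned at DT = 1 s resolution, local maxima are isolated by comparing the histogram against a max-pooled copy of itself (non-maximum suppression over a 5 s window), and each surviving peak supported by more than MIN_STATION detections becomes one candidate event. The sketch below is a minimal standalone illustration of that idea on synthetic timestamps; it thresholds peaks directly rather than capping their count with torch.topk as the committed code does, and all values in it are illustrative, not taken from the repository.

import numpy as np
import torch
import torch.nn.functional as F

np.random.seed(0)
# synthetic detection times (seconds): two bursts plus scattered noise
timestamps = np.concatenate([
    np.random.normal(100.0, 0.5, 8),   # 8 detections near t = 100 s
    np.random.normal(300.0, 0.5, 5),   # 5 detections near t = 300 s
    np.random.uniform(0.0, 400.0, 3),  # isolated false detections
])

DT = 1.0         # histogram bin width in seconds
MIN_STATION = 3  # minimum detections for a peak to count as an event

bins = np.arange(timestamps.min(), timestamps.max(), DT)
hist, _ = np.histogram(timestamps, bins=bins)

# non-maximum suppression: a bin survives only if it equals the local
# maximum within the 5-bin (5-second) window centered on it
x = torch.from_numpy(hist).float().unsqueeze(0).unsqueeze(0)  # (1, 1, T)
pooled = F.max_pool1d(x, kernel_size=5, padding=2, stride=1)
peaks = (x == pooled) & (x > MIN_STATION)

peak_bins = peaks.squeeze().nonzero().squeeze(-1).numpy()
print("candidate event times (s):", bins[peak_bins])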

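Note: the rewritten pick-linking comment follows from ts = ps_ratio * tp. Since the stored travel time is the P/S average, tt = (tp + ts) / 2 = (1 + ps_ratio) / 2 * tp, we get tp = 2 * tt / (1 + ps_ratio) and hence ts - tp = (ps_ratio - 1) * tp = 2 * tt * (ps_ratio - 1) / (ps_ratio + 1), which is the ps_delta used to build the ±1.2 * ps_delta search window around each event's center time; the broadcast (Nev, Npk) mask then assigns each pick to the first event whose window contains it (mask.argmax(axis=0)), or -1 if none does. A quick numeric check of the formula, using illustrative values (tt = 10 s, Vp/Vs = 1.73) that are not taken from the commit's config:

VPVS_RATIO = 1.73  # assumed Vp/Vs ratio, for illustration only
tt = 10.0          # stored travel time (tp + ts) / 2, in seconds

tp = tt * 2 / (1 + VPVS_RATIO)  # P travel time, ~7.33 s
ts = VPVS_RATIO * tp            # S travel time, ~12.67 s
ps_delta = tt * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)  # ~5.35 s

assert abs((ts - tp) - ps_delta) < 1e-9
print(f"tp = {tp:.2f} s, ts = {ts:.2f} s, ps_delta = {ps_delta:.2f} s")
# search window: timestamp_center +/- 1.2 * ps_delta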