Skip to content

Commit 20fad8a

Browse files
committed
small fix
1 parent 0c613b1 commit 20fad8a

File tree

1 file changed

+45
-33
lines changed

1 file changed

+45
-33
lines changed

examples/california/cut_templates_cc.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@
1111
import numpy as np
1212
import obspy
1313
import pandas as pd
14-
from adloc.eikonal2d import calc_traveltime, init_eikonal2d
14+
from args import parse_args
15+
from cut_templates_merge import generate_pairs
1516
from pyproj import Proj
1617
from sklearn.neighbors import NearestNeighbors
1718
from tqdm import tqdm
18-
from args import parse_args
19-
from cut_templates_merge import generate_pairs
19+
20+
from adloc.eikonal2d import calc_traveltime, init_eikonal2d
2021

2122
np.random.seed(42)
2223

2324

2425
# %%
2526
def fillin_missing_picks(picks, events, stations, config):
26-
2727
# reference_t0 = config["reference_t0"]
2828
reference_t0 = pd.Timestamp(config["reference_t0"])
2929
vp_vs_ratio = config["vp_vs_ratio"]
@@ -70,9 +70,13 @@ def fillin_missing_picks(picks, events, stations, config):
7070

7171
## add provider
7272
if "provider" in picks.columns:
73-
picks_ps = picks_ps.merge(picks[["event_index", "station_id", "provider"]].drop_duplicates(), on=["event_index", "station_id"])
73+
picks_ps = picks_ps.merge(
74+
picks[["event_index", "station_id", "provider"]].drop_duplicates(), on=["event_index", "station_id"]
75+
)
7476
else:
75-
picks_ps = picks_ps.merge(picks[["event_index", "station_id"]].drop_duplicates(), on=["event_index", "station_id"])
77+
picks_ps = picks_ps.merge(
78+
picks[["event_index", "station_id"]].drop_duplicates(), on=["event_index", "station_id"]
79+
)
7680

7781
print(f"Original picks: {len(picks)}, Filled picks: {len(picks_ps)}")
7882
print(picks_ps.iloc[:10])
@@ -83,7 +87,6 @@ def fillin_missing_picks(picks, events, stations, config):
8387

8488
# %%
8589
def predict_full_picks(picks, events, stations, config):
86-
8790
vp_vs_ratio = config["vp_vs_ratio"]
8891
# reference_t0 = config["reference_t0"]
8992
reference_t0 = pd.Timestamp(config["reference_t0"])
@@ -158,12 +161,13 @@ def extract_template_numpy(
158161
config,
159162
lock,
160163
):
161-
162164
# reference_t0 = config["reference_t0"]
163165
reference_t0 = pd.Timestamp(config["reference_t0"])
164166

165167
template_array = np.memmap(template_fname, dtype=np.float32, mode="r+", shape=tuple(config["template_shape"]))
166-
traveltime_array = np.memmap(traveltime_fname, dtype=np.float32, mode="r+", shape=tuple(config["traveltime_shape"]))
168+
traveltime_array = np.memmap(
169+
traveltime_fname, dtype=np.float32, mode="r+", shape=tuple(config["traveltime_shape"])
170+
)
167171
traveltime_index_array = np.memmap(
168172
traveltime_index_fname, dtype=np.int32, mode="r+", shape=tuple(config["traveltime_shape"])
169173
)
@@ -175,14 +179,12 @@ def extract_template_numpy(
175179
# traveltime_mask = np.zeros(tuple(config["traveltime_shape"]), dtype=bool)
176180

177181
for picks in picks_group:
178-
179182
waveforms_dict = {}
180183
picks = picks.set_index(["idx_eve", "idx_sta", "phase_type"])
181184
picks_index = list(picks.index.unique())
182185

183186
## Cut templates
184187
for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():
185-
186188
idx_pick = pick["idx_pick"]
187189
phase_timestamp = pick["phase_timestamp"]
188190

@@ -223,9 +225,10 @@ def extract_template_numpy(
223225
end_time = phase_timestamp - trace_starttime + config[f"time_after_{phase_type.lower()}"]
224226

225227
if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks_index):
226-
227228
s_begin_time = (
228-
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
229+
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"]
230+
- trace_starttime
231+
- config[f"time_before_s"]
229232
)
230233
if config["no_overlapping"]:
231234
end_time = min(end_time, s_begin_time)
@@ -234,10 +237,14 @@ def extract_template_numpy(
234237
end_time_index = int(round(end_time * config["sampling_rate"]))
235238

236239
if begin_time_index < 0:
237-
print(f"Warning: {begin_time = } < 0, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_before_{phase_type.lower()}'] = }")
240+
print(
241+
f"Warning: {begin_time = } < 0, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_before_{phase_type.lower()}'] = }"
242+
)
238243
continue
239244
if end_time_index > len(trace.data):
240-
print(f"Warning: {end_time = } > {len(trace.data)}, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_after_{phase_type.lower()}'] = }")
245+
print(
246+
f"Warning: {end_time = } > {len(trace.data)}, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_after_{phase_type.lower()}'] = }"
247+
)
241248
continue
242249

243250
# begin_time_index = max(0, int(round(begin_time * config["sampling_rate"])))
@@ -315,7 +322,6 @@ def extract_template_numpy(
315322

316323
# %%
317324
def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
318-
319325
# %%
320326
fs = fsspec.filesystem(protocol, token=token)
321327

@@ -452,7 +458,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
452458
stations = pd.read_csv(fp)
453459
stations.sort_values(by=["latitude", "longitude"], inplace=True)
454460

455-
456461
# stations = stations[stations["network"] == "7D"]
457462
print(f"{len(stations) = }")
458463
print(stations.head())
@@ -472,7 +477,8 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
472477
if protocol == "file":
473478
events = pd.read_csv(
474479
f"{root_path}/{data_path}/{year:04d}/adloc_events_{jday:03d}.csv", parse_dates=["time"]
475-
# f"{root_path}/{data_path}/{year:04d}/ransac_events_{jday:03d}.csv", parse_dates=["time"]
480+
# f"{root_path}/{data_path}/{year:04d}/ransac_events_{jday:03d}.csv",
481+
parse_dates=["time"],
476482
)
477483
else:
478484
with fs.open(f"{bucket}/{data_path}/{year:04d}/adloc_events_{jday:03d}.csv", "r") as fp:
@@ -572,7 +578,9 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
572578
)
573579
if "dist_km" not in picks:
574580
picks = picks.merge(stations[["station_id", "x_km", "y_km", "z_km"]], on="station_id")
575-
picks.rename(columns={"x_km": "station_x_km", "y_km": "station_y_km", "z_km": "station_z_km"}, inplace=True)
581+
picks.rename(
582+
columns={"x_km": "station_x_km", "y_km": "station_y_km", "z_km": "station_z_km"}, inplace=True
583+
)
576584
picks = picks.merge(events[["event_index", "x_km", "y_km", "z_km"]], on="event_index")
577585
picks.rename(columns={"x_km": "event_x_km", "y_km": "event_y_km", "z_km": "event_z_km"}, inplace=True)
578586
picks["dist_km"] = np.linalg.norm(
@@ -655,7 +663,16 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
655663
config["reference_t0"] = reference_t0.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
656664
events = events[["idx_eve", "x_km", "y_km", "z_km", "event_index", "event_time", "event_timestamp"]]
657665
stations = stations[["idx_sta", "x_km", "y_km", "z_km", "station_id", "component", "network", "station"]]
658-
columns = ["idx_eve", "idx_sta", "phase_type", "phase_score", "phase_time", "phase_timestamp", "phase_source", "station_id"]
666+
columns = [
667+
"idx_eve",
668+
"idx_sta",
669+
"phase_type",
670+
"phase_score",
671+
"phase_time",
672+
"phase_timestamp",
673+
"phase_source",
674+
"station_id"
675+
]
659676
if "provider" in picks.columns:
660677
columns.append("provider")
661678
picks = picks[columns]
@@ -698,8 +715,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
698715
mseeds["location"] = mseeds["fname"].apply(lambda x: x[10:12].strip("_"))
699716
mseeds["year"] = mseeds["fname"].apply(lambda x: x[13:17])
700717
mseeds["jday"] = mseeds["fname"].apply(lambda x: x[17:20])
701-
if "provider" not in picks.columns:
702-
mseeds["provider"] = "SC"
718+
mseeds["provider"] = "SC"
703719
elif folder == "NC":
704720
mseeds["fname"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-1])
705721
mseeds["network"] = mseeds["fname"].apply(lambda x: x.split(".")[1])
@@ -708,8 +724,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
708724
mseeds["location"] = mseeds["fname"].apply(lambda x: x.split(".")[3])
709725
mseeds["year"] = mseeds["fname"].apply(lambda x: x.split(".")[5])
710726
mseeds["jday"] = mseeds["fname"].apply(lambda x: x.split(".")[6])
711-
if "provider" not in picks.columns:
712-
mseeds["provider"] = "NC"
727+
mseeds["provider"] = "NC"
713728
elif folder == "IRIS":
714729
mseeds["fname"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-1])
715730
mseeds["jday"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-2])
@@ -718,8 +733,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
718733
mseeds["station"] = mseeds["fname"].apply(lambda x: x.split(".")[1])
719734
mseeds["location"] = mseeds["fname"].apply(lambda x: x.split(".")[2])
720735
mseeds["instrument"] = mseeds["fname"].apply(lambda x: x.split(".")[3][:2])
721-
if "provider" not in picks.columns:
722-
mseeds["provider"] = "IRIS"
736+
mseeds["provider"] = "IRIS"
723737
else:
724738
raise ValueError(f"Unknown folder: {folder}")
725739
mseeds_df.append(mseeds)
@@ -739,7 +753,9 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
739753
picks = picks[(picks["year"].astype(int) == year) & (picks["jday"].astype(int) == jday)]
740754

741755
if "provider" in picks.columns:
742-
picks = picks.merge(mseeds_df, on=["network", "station", "location", "instrument", "year", "jday", "provider"])
756+
picks = picks.merge(
757+
mseeds_df, on=["network", "station", "location", "instrument", "year", "jday", "provider"]
758+
)
743759
else:
744760
picks = picks.merge(mseeds_df, on=["network", "station", "location", "instrument", "year", "jday"])
745761
picks.drop(columns=["fname", "station_id", "network", "location", "instrument", "year", "jday"], inplace=True)
@@ -748,8 +764,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
748764
print(f"No picks found for {year:04d}/{jday:03d}")
749765
continue
750766

751-
752-
753767
# ####
754768
# out = picks.drop(columns=["ENZ"])
755769
# out.to_csv(f"{root_path}/{result_path}/{year:04d}/cctorch_picks_{jday:03d}.csv", index=False)
@@ -770,7 +784,8 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
770784
print(f"Using {ncpu} cores")
771785

772786
pbar = tqdm(total=nsplit, desc="Cutting templates")
773-
ctx = mp.get_context("fork")
787+
# ctx = mp.get_context("fork")
788+
ctx = mp.get_context("spawn")
774789

775790
with ctx.Manager() as manager:
776791
lock = manager.Lock()
@@ -781,7 +796,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
781796
picks_group_chunk = [[picks_group.get_group(g) for g in group] for group in group_chunk]
782797

783798
for picks_group in picks_group_chunk:
784-
785799
job = pool.apply_async(
786800
extract_template_numpy,
787801
(
@@ -857,7 +871,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
857871

858872
# %%
859873
if __name__ == "__main__":
860-
861874
# %%
862875
protocol = "gs"
863876
token_json = f"application_default_credentials.json"
@@ -929,7 +942,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
929942
# num_jday = 366 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365
930943
# jdays.extend([f"{year}.{i:03d}" for i in range(1, num_jday + 1)])
931944

932-
933945
num_jday = 366 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365
934946
# jdays = [f"{year}.{i:03d}" for i in range(1, num_jday + 1)]
935947
jdays = range(1, num_jday + 1)

0 commit comments

Comments
 (0)