Skip to content

Commit a7c28df

Browse files
committed
cut based on file not station
1 parent cf2d2ee commit a7c28df

File tree

1 file changed

+113
-113
lines changed

1 file changed

+113
-113
lines changed

scripts/cut_templates_cc.py

Lines changed: 113 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import multiprocessing as mp
44
import os
55
import sys
6+
from collections import defaultdict
67
from glob import glob
78

9+
import fsspec
810
import matplotlib.pyplot as plt
911
import numpy as np
1012
import obspy
@@ -141,10 +143,8 @@ def extract_template_numpy(
141143
traveltime_fname,
142144
traveltime_index_fname,
143145
traveltime_mask_fname,
144-
mseed_path,
146+
picks_group,
145147
events,
146-
picks,
147-
stations,
148148
config,
149149
lock,
150150
):
@@ -158,80 +158,73 @@ def extract_template_numpy(
158158
)
159159
traveltime_mask = np.memmap(traveltime_mask_fname, dtype=bool, mode="r+", shape=tuple(config["traveltime_shape"]))
160160

161-
## Load waveforms
162161
waveforms_dict = {}
163-
for i, station in stations.iterrows():
164-
station_id = station["station_id"]
165-
# for c in station["component"]:
166-
for c in ["E", "N", "Z", "1", "2", "3"]:
167-
mseed_name = f"{mseed_path}/{station_id}{c}.mseed"
168-
if os.path.exists(mseed_name):
169-
try:
170-
stream = obspy.read(mseed_name)
171-
stream.merge(fill_value="latest")
172-
if len(stream) > 1:
173-
print(f"More than one trace: {stream}")
174-
trace = stream[0]
175-
if trace.stats.sampling_rate != config["sampling_rate"]:
176-
if trace.stats.sampling_rate % config["sampling_rate"] == 0:
177-
trace.decimate(int(trace.stats.sampling_rate / config["sampling_rate"]))
178-
else:
179-
trace.resample(config["sampling_rate"])
180-
# trace.detrend("linear")
181-
# trace.taper(max_percentage=0.05, type="cosine")
182-
trace.filter("bandpass", freqmin=2.0, freqmax=12.0, corners=4, zerophase=True)
183-
waveforms_dict[f"{station_id}{c}"] = trace
184-
except Exception as e:
185-
print(e)
186-
continue
187-
188-
## Cut templates
189-
for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():
190-
191-
idx_pick = pick["idx_pick"]
192-
phase_timestamp = pick["phase_timestamp"]
193-
194-
station = stations.loc[idx_sta]
195-
station_id = station["station_id"]
196-
event = events.loc[idx_eve]
162+
for picks in picks_group:
163+
164+
# waveforms_dict = {}
165+
picks = picks.set_index(["idx_eve", "idx_sta", "phase_type"])
166+
picks_index = list(picks.index.unique())
167+
168+
## Cut templates
169+
for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():
170+
171+
idx_pick = pick["idx_pick"]
172+
phase_timestamp = pick["phase_timestamp"]
173+
174+
event = events.loc[idx_eve]
175+
ENZ = pick["ENZ"].split(",")
176+
177+
for c in ENZ:
178+
if c not in waveforms_dict:
179+
with fsspec.open(c, "rb", anon=True) as f:
180+
stream = obspy.read(f)
181+
stream.merge(fill_value="latest")
182+
if len(stream) > 1:
183+
print(f"More than one trace: {stream}")
184+
trace = stream[0]
185+
if trace.stats.sampling_rate != config["sampling_rate"]:
186+
if trace.stats.sampling_rate % config["sampling_rate"] == 0:
187+
trace.decimate(int(trace.stats.sampling_rate / config["sampling_rate"]))
188+
else:
189+
trace.resample(config["sampling_rate"])
190+
# trace.detrend("linear")
191+
# trace.taper(max_percentage=0.05, type="cosine")
192+
trace.filter("bandpass", freqmin=2.0, freqmax=12.0, corners=4, zerophase=True)
193+
waveforms_dict[c] = trace
194+
else:
195+
trace = waveforms_dict[c]
196+
197+
ic = config["component_mapping"][trace.stats.channel[-1]]
197198

198-
# for c in station["component"]:
199-
for c in ["E", "N", "Z", "1", "2", "3"]:
200-
ic = config["component_mapping"][c] # 012 for P, 345 for S
201-
202-
if f"{station_id}{c}" in waveforms_dict:
203-
trace = waveforms_dict[f"{station_id}{c}"]
204199
trace_starttime = (
205200
pd.to_datetime(trace.stats.starttime.datetime, utc=True) - reference_t0
206201
).total_seconds()
207-
else:
208-
continue
209202

210-
begin_time = phase_timestamp - trace_starttime - config[f"time_before_{phase_type.lower()}"]
211-
end_time = phase_timestamp - trace_starttime + config[f"time_after_{phase_type.lower()}"]
203+
begin_time = phase_timestamp - trace_starttime - config[f"time_before_{phase_type.lower()}"]
204+
end_time = phase_timestamp - trace_starttime + config[f"time_after_{phase_type.lower()}"]
212205

213-
if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks.index):
214-
s_begin_time = (
215-
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
216-
)
217-
if config["no_overlapping"]:
218-
end_time = min(end_time, s_begin_time)
206+
if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks.index):
207+
s_begin_time = (
208+
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
209+
)
210+
if config["no_overlapping"]:
211+
end_time = min(end_time, s_begin_time)
219212

220-
begin_time_index = max(0, int(round(begin_time * config["sampling_rate"])))
221-
end_time_index = max(0, int(round(end_time * config["sampling_rate"])))
213+
begin_time_index = max(0, int(round(begin_time * config["sampling_rate"])))
214+
end_time_index = max(0, int(round(end_time * config["sampling_rate"])))
222215

223-
## define traveltime at the exact data point of event origin time
224-
traveltime_array[idx_pick, ic, 0] = begin_time_index / config["sampling_rate"] - (
225-
event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"]
226-
)
227-
traveltime_index_array[idx_pick, ic, 0] = begin_time_index - int(
228-
(event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"])
229-
* config["sampling_rate"]
230-
)
231-
traveltime_mask[idx_pick, ic, 0] = True
216+
## define traveltime at the exact data point of event origin time
217+
traveltime_array[idx_pick, ic, 0] = begin_time_index / config["sampling_rate"] - (
218+
event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"]
219+
)
220+
traveltime_index_array[idx_pick, ic, 0] = begin_time_index - int(
221+
(event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"])
222+
* config["sampling_rate"]
223+
)
224+
traveltime_mask[idx_pick, ic, 0] = True
232225

233-
trace_data = trace.data[begin_time_index:end_time_index].astype(np.float32)
234-
template_array[idx_pick, ic, 0, : len(trace_data)] = trace_data
226+
trace_data = trace.data[begin_time_index:end_time_index].astype(np.float32)
227+
template_array[idx_pick, ic, 0, : len(trace_data)] = trace_data
235228

236229
if lock is not None:
237230
with lock:
@@ -240,7 +233,7 @@ def extract_template_numpy(
240233
traveltime_index_array.flush()
241234
traveltime_mask.flush()
242235

243-
return mseed_path
236+
return
244237

245238

246239
# %%
@@ -508,65 +501,74 @@ def cut_templates(root_path, region, config):
508501
config["reference_t0"] = reference_t0
509502
events = events[["idx_eve", "x_km", "y_km", "z_km", "event_index", "event_time", "event_timestamp"]]
510503
stations = stations[["idx_sta", "x_km", "y_km", "z_km", "station_id", "component", "network", "station"]]
511-
picks = picks[["idx_eve", "idx_sta", "phase_type", "phase_score", "phase_time", "phase_timestamp", "phase_source"]]
504+
picks = picks[
505+
[
506+
"idx_eve",
507+
"idx_sta",
508+
"phase_type",
509+
"phase_score",
510+
"phase_time",
511+
"phase_timestamp",
512+
"phase_source",
513+
"station_id",
514+
]
515+
]
512516
events.set_index("idx_eve", inplace=True)
513517
stations.set_index("idx_sta", inplace=True)
514518
picks.sort_values(by=["idx_eve", "idx_sta", "phase_type"], inplace=True)
515519
picks["idx_pick"] = np.arange(len(picks))
516520

517521
picks.to_csv(f"{root_path}/{result_path}/cctorch_picks.csv", index=False)
518522

519-
## By hour
520-
# dirs = sorted(glob(f"{root_path}/{region}/waveforms/????/???/??"), reverse=True)
521-
## By day
522-
dirs = sorted(glob(f"{root_path}/{region}/waveforms/????/???"), reverse=True)
523+
## Find mseed files
524+
mseed_list = sorted(glob(f"{root_path}/{region}/waveforms/????/???/*.mseed"))
525+
subdir = 2
526+
527+
mseed_3c = defaultdict(list)
528+
for mseed in mseed_list:
529+
key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir - 1 :])
530+
key = key[:-1] ## remove the channel suffix
531+
mseed_3c[key].append(mseed)
532+
print(f"Number of mseed files: {len(mseed_3c)}")
533+
534+
def parse_key(key):
535+
year, jday, name = key.split("/")
536+
network, station, location, instrument = name.split(".")
537+
return [year, jday, network, station, location, instrument]
538+
539+
mseeds = [parse_key(k) + [",".join(sorted(mseed_3c[k]))] for k in mseed_3c]
540+
mseeds = pd.DataFrame(mseeds, columns=["year", "jday", "network", "station", "location", "instrument", "ENZ"])
541+
542+
## Match picks with mseed files
543+
picks["network"] = picks["station_id"].apply(lambda x: x.split(".")[0])
544+
picks["station"] = picks["station_id"].apply(lambda x: x.split(".")[1])
545+
picks["location"] = picks["station_id"].apply(lambda x: x.split(".")[2])
546+
picks["instrument"] = picks["station_id"].apply(lambda x: x.split(".")[3])
547+
picks["year"] = picks["phase_time"].dt.strftime("%Y")
548+
picks["jday"] = picks["phase_time"].dt.strftime("%j")
549+
picks = picks.merge(mseeds, on=["network", "station", "location", "instrument", "year", "jday"])
550+
picks.drop(columns=["station_id", "network", "location", "instrument", "year", "jday"], inplace=True)
551+
552+
picks_group = picks.copy()
553+
picks_group = picks_group.groupby("ENZ")
523554

524555
ncpu = min(16, mp.cpu_count())
556+
nsplit = min(ncpu * 2, len(picks_group))
525557
print(f"Using {ncpu} cores")
526558

527-
pbar = tqdm(total=len(dirs), desc="Cutting templates")
528-
529-
def pbar_update(x):
530-
"""
531-
x: the return value of extract_template_numpy
532-
"""
533-
pbar.update()
534-
pbar.set_description(f"Cutting templates: {'/'.join(x.split('/')[-3:])}")
559+
pbar = tqdm(total=nsplit, desc="Cutting templates")
535560

536561
ctx = mp.get_context("spawn")
537-
picks_group = picks.copy()
538-
## By hour
539-
# picks_group["year_jday_hour"] = picks_group["phase_time"].dt.strftime("%Y-%jT%H")
540-
# picks_group = picks_group.groupby("year_jday_hour")
541-
## By day
542-
picks_group["year_jday"] = picks_group["phase_time"].dt.strftime("%Y-%j")
543-
picks_group = picks_group.groupby("year_jday")
544562

545563
with ctx.Manager() as manager:
546564
lock = manager.Lock()
547565
with ctx.Pool(ncpu) as pool:
548566
jobs = []
549-
for d in dirs:
550-
551-
tmp = d.split("/")
552-
## By hour
553-
# year, jday, hour = tmp[-3:]
554-
## By day
555-
year, jday = tmp[-2:]
556-
557-
## By hour
558-
# if f"{year}-{jday}T{hour}" not in picks_group.groups:
559-
## By day
560-
if f"{year}-{jday}" not in picks_group.groups:
561-
pbar_update(d)
562-
continue
563567

564-
## By hour
565-
# picks_ = picks_group.get_group(f"{year}-{jday}T{hour}")
566-
## By day
567-
picks_ = picks_group.get_group(f"{year}-{jday}")
568-
events_ = events.loc[picks_["idx_eve"].unique()]
569-
picks_ = picks_.set_index(["idx_eve", "idx_sta", "phase_type"])
568+
group_chunk = np.array_split(list(picks_group.groups.keys()), nsplit)
569+
picks_group_chunk = [[picks_group.get_group(g) for g in group] for group in group_chunk]
570+
571+
for picks_group in picks_group_chunk:
570572

571573
job = pool.apply_async(
572574
extract_template_numpy,
@@ -575,14 +577,12 @@ def pbar_update(x):
575577
traveltime_fname,
576578
traveltime_index_fname,
577579
traveltime_mask_fname,
578-
d,
579-
events_,
580-
picks_,
581-
stations,
580+
picks_group,
581+
events,
582582
config,
583583
lock,
584584
),
585-
callback=pbar_update,
585+
callback=lambda x: pbar.update(),
586586
)
587587
jobs.append(job)
588588
pool.close()

0 commit comments

Comments
 (0)