 import numpy as np
 import obspy
 import pandas as pd
-from adloc.eikonal2d import calc_traveltime, init_eikonal2d
+from args import parse_args
+from cut_templates_merge import generate_pairs
 from pyproj import Proj
 from sklearn.neighbors import NearestNeighbors
 from tqdm import tqdm
-from args import parse_args
-from cut_templates_merge import generate_pairs
+
+from adloc.eikonal2d import calc_traveltime, init_eikonal2d
 
 np.random.seed(42)
 
 
 # %%
 def fillin_missing_picks(picks, events, stations, config):
-
     # reference_t0 = config["reference_t0"]
     reference_t0 = pd.Timestamp(config["reference_t0"])
     vp_vs_ratio = config["vp_vs_ratio"]
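A note on the two patterns above: the first-party `adloc` import moves below the third-party block (standard isort-style grouping), and `reference_t0` is stored as an ISO-8601 string in `config` and re-parsed with `pd.Timestamp` wherever it is needed. A minimal sketch of the timestamp bookkeeping, using a made-up epoch (the arithmetic with `trace_starttime` later in the script suggests pick times are kept as float seconds relative to this reference):

```python
import pandas as pd

# Assumed epoch; config["reference_t0"] holds the ISO-8601 string form.
reference_t0 = pd.Timestamp("2019-01-01T00:00:00.000000Z")

# A pick time reduced to float seconds relative to the reference,
# matching the "phase_timestamp" convention used in the script.
phase_time = pd.Timestamp("2019-01-01T00:10:30.250000Z")
phase_timestamp = (phase_time - reference_t0).total_seconds()
print(phase_timestamp)  # 630.25
```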
@@ -70,9 +70,13 @@ def fillin_missing_picks(picks, events, stations, config):
 
     ## add provider
     if "provider" in picks.columns:
-        picks_ps = picks_ps.merge(picks[["event_index", "station_id", "provider"]].drop_duplicates(), on=["event_index", "station_id"])
+        picks_ps = picks_ps.merge(
+            picks[["event_index", "station_id", "provider"]].drop_duplicates(), on=["event_index", "station_id"]
+        )
     else:
-        picks_ps = picks_ps.merge(picks[["event_index", "station_id"]].drop_duplicates(), on=["event_index", "station_id"])
+        picks_ps = picks_ps.merge(
+            picks[["event_index", "station_id"]].drop_duplicates(), on=["event_index", "station_id"]
+        )
 
     print(f"Original picks: {len(picks)}, Filled picks: {len(picks_ps)}")
     print(picks_ps.iloc[:10])
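The `drop_duplicates()` ahead of `merge` keeps this join one-to-one: deduplicating the right-hand key columns prevents the filled-in picks from being multiplied when the same (event, station) key occurs on several rows. A self-contained sketch with invented values:

```python
import pandas as pd

# Hypothetical picks: the (event_index, station_id) key repeats, one provider each.
picks = pd.DataFrame(
    {
        "event_index": [1, 1, 2],
        "station_id": ["NC.AAA", "NC.AAA", "NC.BBB"],
        "provider": ["NC", "NC", "SC"],
    }
)
picks_ps = pd.DataFrame({"event_index": [1, 2], "station_id": ["NC.AAA", "NC.BBB"]})

# Without drop_duplicates() the first key would match twice and duplicate rows.
picks_ps = picks_ps.merge(
    picks[["event_index", "station_id", "provider"]].drop_duplicates(),
    on=["event_index", "station_id"],
)
print(picks_ps)  # two rows, one provider attached to each
```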
@@ -83,7 +87,6 @@ def fillin_missing_picks(picks, events, stations, config):
 
 # %%
 def predict_full_picks(picks, events, stations, config):
-
     vp_vs_ratio = config["vp_vs_ratio"]
     # reference_t0 = config["reference_t0"]
     reference_t0 = pd.Timestamp(config["reference_t0"])
@@ -158,12 +161,13 @@ def extract_template_numpy(
     config,
     lock,
 ):
-
     # reference_t0 = config["reference_t0"]
     reference_t0 = pd.Timestamp(config["reference_t0"])
 
     template_array = np.memmap(template_fname, dtype=np.float32, mode="r+", shape=tuple(config["template_shape"]))
-    traveltime_array = np.memmap(traveltime_fname, dtype=np.float32, mode="r+", shape=tuple(config["traveltime_shape"]))
+    traveltime_array = np.memmap(
+        traveltime_fname, dtype=np.float32, mode="r+", shape=tuple(config["traveltime_shape"])
+    )
     traveltime_index_array = np.memmap(
         traveltime_index_fname, dtype=np.int32, mode="r+", shape=tuple(config["traveltime_shape"])
     )
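`np.memmap` with `mode="r+"` maps an already-created file, so each worker process can write its own slice of the shared template and traveltime arrays without holding them in memory. A minimal sketch, with an assumed file name and shape:

```python
import numpy as np

# Assumed shape: (n_picks, n_components, n_samples); the real shapes come from config.
shape = (100, 3, 400)

# "w+" creates and zero-fills the backing file once, up front.
arr = np.memmap("template.dat", dtype=np.float32, mode="w+", shape=shape)
arr.flush()

# "r+" re-opens the existing file read-write; slice assignments go to disk-backed memory.
view = np.memmap("template.dat", dtype=np.float32, mode="r+", shape=shape)
view[0, :, :] = 1.0
view.flush()
```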
@@ -175,14 +179,12 @@ def extract_template_numpy(
     # traveltime_mask = np.zeros(tuple(config["traveltime_shape"]), dtype=bool)
 
     for picks in picks_group:
-
         waveforms_dict = {}
         picks = picks.set_index(["idx_eve", "idx_sta", "phase_type"])
         picks_index = list(picks.index.unique())
 
         ## Cut templates
         for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():
-
             idx_pick = pick["idx_pick"]
             phase_timestamp = pick["phase_timestamp"]
 
@@ -223,9 +225,10 @@ def extract_template_numpy(
             end_time = phase_timestamp - trace_starttime + config[f"time_after_{phase_type.lower()}"]
 
             if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks_index):
-
                 s_begin_time = (
-                    picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
+                    picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"]
+                    - trace_starttime
+                    - config[f"time_before_s"]
                 )
                 if config["no_overlapping"]:
                     end_time = min(end_time, s_begin_time)
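The `no_overlapping` branch trims a P template so it ends before the S window of the same event and station begins. A worked example of the arithmetic, with assumed values:

```python
# Assumed values for illustration; the real ones come from config and the picks.
config = {"time_after_p": 2.5, "time_before_s": 0.3, "no_overlapping": True}
trace_starttime = 10.0
p_timestamp = 12.0  # P pick, seconds relative to the reference time
s_timestamp = 13.0  # S pick on the same (event, station)

end_time = p_timestamp - trace_starttime + config["time_after_p"]        # 4.5
s_begin_time = s_timestamp - trace_starttime - config["time_before_s"]   # 2.7
if config["no_overlapping"]:
    end_time = min(end_time, s_begin_time)
print(end_time)  # 2.7: the P template stops where the S window starts
```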
@@ -234,10 +237,14 @@ def extract_template_numpy(
             end_time_index = int(round(end_time * config["sampling_rate"]))
 
             if begin_time_index < 0:
-                print(f"Warning: {begin_time = } < 0, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_before_{phase_type.lower()}'] = }")
+                print(
+                    f"Warning: {begin_time = } < 0, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_before_{phase_type.lower()}'] = }"
+                )
                 continue
             if end_time_index > len(trace.data):
-                print(f"Warning: {end_time = } > {len(trace.data)}, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_after_{phase_type.lower()}'] = }")
+                print(
+                    f"Warning: {end_time = } > {len(trace.data)}, {trace_starttime = }, {event['event_timestamp'] = }, {config[f'time_after_{phase_type.lower()}'] = }"
+                )
                 continue
 
             # begin_time_index = max(0, int(round(begin_time * config["sampling_rate"])))
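These warnings use the f-string `=` specifier (Python 3.8+), which echoes the expression together with its value; spaces around the `=` are preserved in the output:

```python
begin_time = -0.3
trace_starttime = 10.0
# Prints: "Warning: begin_time = -0.3 < 0, trace_starttime = 10.0"
print(f"Warning: {begin_time = } < 0, {trace_starttime = }")
```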
@@ -315,7 +322,6 @@ def extract_template_numpy(
 
 # %%
 def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
-
     # %%
     fs = fsspec.filesystem(protocol, token=token)
 
@@ -452,7 +458,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
         stations = pd.read_csv(fp)
     stations.sort_values(by=["latitude", "longitude"], inplace=True)
 
-
     # stations = stations[stations["network"] == "7D"]
     print(f"{len(stations) = }")
     print(stations.head())
@@ -472,7 +477,8 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     if protocol == "file":
         events = pd.read_csv(
474479 f"{ root_path } /{ data_path } /{ year :04d} /adloc_events_{ jday :03d} .csv" , parse_dates = ["time" ]
475- # f"{root_path}/{data_path}/{year:04d}/ransac_events_{jday:03d}.csv", parse_dates=["time"]
480+ # f"{root_path}/{data_path}/{year:04d}/ransac_events_{jday:03d}.csv",
481+ parse_dates = ["time" ],
         )
     else:
         with fs.open(f"{bucket}/{data_path}/{year:04d}/adloc_events_{jday:03d}.csv", "r") as fp:
@@ -572,7 +578,9 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     )
     if "dist_km" not in picks:
         picks = picks.merge(stations[["station_id", "x_km", "y_km", "z_km"]], on="station_id")
-        picks.rename(columns={"x_km": "station_x_km", "y_km": "station_y_km", "z_km": "station_z_km"}, inplace=True)
+        picks.rename(
+            columns={"x_km": "station_x_km", "y_km": "station_y_km", "z_km": "station_z_km"}, inplace=True
+        )
         picks = picks.merge(events[["event_index", "x_km", "y_km", "z_km"]], on="event_index")
         picks.rename(columns={"x_km": "event_x_km", "y_km": "event_y_km", "z_km": "event_z_km"}, inplace=True)
         picks["dist_km"] = np.linalg.norm(
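`dist_km` is the 3-D Euclidean distance between each event and its recording station, computed row-wise with `np.linalg.norm`. A small check with invented coordinates:

```python
import numpy as np
import pandas as pd

picks = pd.DataFrame(
    {
        "event_x_km": [0.0, 10.0],
        "event_y_km": [0.0, 0.0],
        "event_z_km": [5.0, 8.0],
        "station_x_km": [3.0, 10.0],
        "station_y_km": [4.0, 0.0],
        "station_z_km": [5.0, 0.0],
    }
)
picks["dist_km"] = np.linalg.norm(
    picks[["event_x_km", "event_y_km", "event_z_km"]].values
    - picks[["station_x_km", "station_y_km", "station_z_km"]].values,
    axis=-1,
)
print(picks["dist_km"].tolist())  # [5.0, 8.0]
```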
@@ -655,7 +663,16 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     config["reference_t0"] = reference_t0.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
     events = events[["idx_eve", "x_km", "y_km", "z_km", "event_index", "event_time", "event_timestamp"]]
     stations = stations[["idx_sta", "x_km", "y_km", "z_km", "station_id", "component", "network", "station"]]
-    columns = ["idx_eve", "idx_sta", "phase_type", "phase_score", "phase_time", "phase_timestamp", "phase_source", "station_id"]
+    columns = [
+        "idx_eve",
+        "idx_sta",
+        "phase_type",
+        "phase_score",
+        "phase_time",
+        "phase_timestamp",
+        "phase_source",
+        "station_id",
+    ]
     if "provider" in picks.columns:
         columns.append("provider")
     picks = picks[columns]
@@ -698,8 +715,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
             mseeds["location"] = mseeds["fname"].apply(lambda x: x[10:12].strip("_"))
             mseeds["year"] = mseeds["fname"].apply(lambda x: x[13:17])
             mseeds["jday"] = mseeds["fname"].apply(lambda x: x[17:20])
-            if "provider" not in picks.columns:
-                mseeds["provider"] = "SC"
+            mseeds["provider"] = "SC"
         elif folder == "NC":
             mseeds["fname"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-1])
             mseeds["network"] = mseeds["fname"].apply(lambda x: x.split(".")[1])
@@ -708,8 +724,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
             mseeds["location"] = mseeds["fname"].apply(lambda x: x.split(".")[3])
             mseeds["year"] = mseeds["fname"].apply(lambda x: x.split(".")[5])
             mseeds["jday"] = mseeds["fname"].apply(lambda x: x.split(".")[6])
-            if "provider" not in picks.columns:
-                mseeds["provider"] = "NC"
+            mseeds["provider"] = "NC"
         elif folder == "IRIS":
             mseeds["fname"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-1])
             mseeds["jday"] = mseeds["ENZ"].apply(lambda x: x.split("/")[-2])
@@ -718,8 +733,7 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
             mseeds["station"] = mseeds["fname"].apply(lambda x: x.split(".")[1])
             mseeds["location"] = mseeds["fname"].apply(lambda x: x.split(".")[2])
             mseeds["instrument"] = mseeds["fname"].apply(lambda x: x.split(".")[3][:2])
-            if "provider" not in picks.columns:
-                mseeds["provider"] = "IRIS"
+            mseeds["provider"] = "IRIS"
         else:
             raise ValueError(f"Unknown folder: {folder}")
         mseeds_df.append(mseeds)
@@ -739,7 +753,9 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     picks = picks[(picks["year"].astype(int) == year) & (picks["jday"].astype(int) == jday)]
 
     if "provider" in picks.columns:
-        picks = picks.merge(mseeds_df, on=["network", "station", "location", "instrument", "year", "jday", "provider"])
+        picks = picks.merge(
+            mseeds_df, on=["network", "station", "location", "instrument", "year", "jday", "provider"]
+        )
     else:
         picks = picks.merge(mseeds_df, on=["network", "station", "location", "instrument", "year", "jday"])
     picks.drop(columns=["fname", "station_id", "network", "location", "instrument", "year", "jday"], inplace=True)
@@ -748,8 +764,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
         print(f"No picks found for {year:04d}/{jday:03d}")
         continue
 
-
-
     # ####
     # out = picks.drop(columns=["ENZ"])
     # out.to_csv(f"{root_path}/{result_path}/{year:04d}/cctorch_picks_{jday:03d}.csv", index=False)
@@ -770,7 +784,8 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     print(f"Using {ncpu} cores")
 
     pbar = tqdm(total=nsplit, desc="Cutting templates")
-    ctx = mp.get_context("fork")
+    # ctx = mp.get_context("fork")
+    ctx = mp.get_context("spawn")
 
     with ctx.Manager() as manager:
         lock = manager.Lock()
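Switching the start method from `fork` to `spawn` launches workers in fresh interpreters. `fork` can deadlock when the parent already holds threads or open connections (plausible here given the fsspec/GCS usage, though the commit does not say); the trade-off is that everything sent to a `spawn` worker must be picklable and importable. A minimal sketch of the same Manager/Lock/Pool pattern:

```python
import multiprocessing as mp


def worker(i, lock):
    # The lock is a Manager proxy, so it pickles cleanly across "spawn".
    with lock:
        print(f"worker {i}")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Manager() as manager:
        lock = manager.Lock()
        with ctx.Pool(2) as pool:
            jobs = [pool.apply_async(worker, (i, lock)) for i in range(4)]
            pool.close()
            pool.join()
            for job in jobs:
                job.get()  # re-raises any worker exception
```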
@@ -781,7 +796,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
             picks_group_chunk = [[picks_group.get_group(g) for g in group] for group in group_chunk]
 
             for picks_group in picks_group_chunk:
-
                 job = pool.apply_async(
                     extract_template_numpy,
                     (
@@ -857,7 +871,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
 
 # %%
 if __name__ == "__main__":
-
     # %%
     protocol = "gs"
     token_json = f"application_default_credentials.json"
@@ -929,7 +942,6 @@ def cut_templates(jdays, root_path, region, config, bucket, protocol, token):
     # num_jday = 366 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365
     # jdays.extend([f"{year}.{i:03d}" for i in range(1, num_jday + 1)])
 
-
     num_jday = 366 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365
     # jdays = [f"{year}.{i:03d}" for i in range(1, num_jday + 1)]
     jdays = range(1, num_jday + 1)
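The ternary reproduces the Gregorian leap-year rule: 366 days when the year is divisible by 4 but not by 100, or divisible by 400. A quick cross-check against `calendar.isleap`:

```python
import calendar

for year in (1900, 2000, 2019, 2020):
    num_jday = 366 if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0 else 365
    assert (num_jday == 366) == calendar.isleap(year)
    print(year, num_jday)  # 1900 365, 2000 366, 2019 365, 2020 366
```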