-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbank_add_duration.py
More file actions
76 lines (65 loc) · 2.52 KB
/
bank_add_duration.py
File metadata and controls
76 lines (65 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import h5py
from argparse import ArgumentParser
from pyseobnr.generate_waveform import GenerateWaveform
from tqdm import tqdm
import multiprocessing
import pandas as pd
def wf_wrapper(p):
    """Generate one SEOBNRv5EHM waveform and return its duration.

    Parameters
    ----------
    p : dict
        Template parameters (the caller passes e.g. mass1, mass2, spin1z,
        spin2z, eccentricity, rel_anomaly) plus an ``'index'`` key that only
        identifies the template in the returned tuple. ``p`` is not modified.

    Returns
    -------
    tuple
        ``(index, duration)`` where ``duration`` is ``abs(float(hp.epoch))``
        of the generated plus polarization, or ``None`` if waveform
        generation raised an exception (the error is printed).
    """
    index = p['index']
    # Fixed waveform settings; they override any same-named keys in ``p``.
    # (f22_start is presumably in Hz -- TODO confirm against pyseobnr docs.)
    settings = {
        "approximant": "SEOBNRv5EHM",
        "ModeArray": [(2, 2)],
        "f22_start": 20,
        "lmax_nyquist": 1,
    }
    # Build a fresh dict so the caller's mapping is left untouched, and drop
    # the bookkeeping 'index' key, which is not a waveform parameter.
    params = {key: value for key, value in p.items() if key != 'index'}
    params.update(settings)
    try:
        wf = GenerateWaveform(params)
        hp, _ = wf.generate_td_polarizations()
        # |epoch| is used as the template duration; presumably the epoch is
        # the (negative) start time relative to the waveform's t=0 near
        # merger -- TODO confirm against pyseobnr conventions.
        return index, abs(float(hp.epoch))
    except Exception as e:
        # Best-effort: a failed template is reported and skipped rather than
        # aborting the whole bank.
        print(f"Waveform generation failed for template {index}: {e}")
        return index, None
def main():
    """Add template durations to a template bank.

    Reads the bank given by ``--bank`` (``.csv``, ``.hdf5`` or ``.h5``),
    generates one waveform per template with ``wf_wrapper`` across
    ``--nprocesses`` worker processes, and writes a copy of the bank with a
    ``template_duration`` column/dataset to ``--output``. Templates whose
    waveform generation failed get NaN as their duration.
    """
    parser = ArgumentParser()
    parser.add_argument('--bank', type=str, required=True,
                        help="Template bank")
    parser.add_argument('--output', type=str, required=True,
                        help="Path to output bank with durations.")
    parser.add_argument('--nprocesses', type=int, default=1,
                        help="Number of processes to use for waveform generation parallelization.")
    args = parser.parse_args()

    if args.bank.endswith('.csv'):
        _add_durations_csv(args.bank, args.output, args.nprocesses)
    elif args.bank.endswith(('.hdf5', '.h5')):
        _add_durations_hdf5(args.bank, args.output, args.nprocesses)
    else:
        # Fail loudly instead of exiting without writing any output file.
        parser.error(f"unsupported bank format: {args.bank!r} "
                     "(expected .csv, .hdf5 or .h5)")


# Per-template parameters forwarded to the waveform generator.
_TEMPLATE_PARAMS = ('mass1', 'mass2', 'spin1z', 'spin2z',
                    'eccentricity', 'rel_anomaly')


def _compute_durations(templates, nprocesses):
    """Return a float array of durations, one per template dict (NaN on failure)."""
    durations = np.full(len(templates), np.nan)
    # Tag each task with its position so unordered results can be placed back.
    tasks = (dict(params, index=i) for i, params in enumerate(templates))
    with multiprocessing.Pool(nprocesses) as pool:
        # Wrap the result iterator, not the task generator: the pool consumes
        # its input eagerly, so only results reflect real progress.
        results = pool.imap_unordered(wf_wrapper, tasks)
        for i, duration in tqdm(results, total=len(templates)):
            if duration is not None:
                durations[i] = duration
    return durations


def _add_durations_csv(bank_path, output_path, nprocesses):
    """Write a copy of a CSV bank with a ``template_duration`` column."""
    df = pd.read_csv(bank_path)
    # Positional bookkeeping avoids a temporary 'index' column, which would
    # overwrite (and then drop) a real column of that name.
    templates = df[list(_TEMPLATE_PARAMS)].to_dict('records')
    df['template_duration'] = _compute_durations(templates, nprocesses)
    df.to_csv(output_path, index=False)


def _add_durations_hdf5(bank_path, output_path, nprocesses):
    """Write a copy of an HDF5 bank with a ``template_duration`` dataset."""
    with h5py.File(bank_path, 'r') as f_bank:
        # Parameters absent from the bank are simply not passed to the
        # generator (presumably pyseobnr defaults then apply -- confirm).
        columns = {k: f_bank[k][()].tolist()
                   for k in _TEMPLATE_PARAMS if k in f_bank}
    ntemplates = len(columns['mass1'])
    templates = [{k: values[i] for k, values in columns.items()}
                 for i in range(ntemplates)]
    durations = _compute_durations(templates, nprocesses)

    # Open the bank before the output so that output == bank raises instead
    # of truncating the input.
    with h5py.File(bank_path, 'r') as f_bank, \
            h5py.File(output_path, 'w') as f_write:
        for k in f_bank.keys():
            # Skip any existing duration: h5py cannot create a dataset under a
            # name that already exists.
            if k != 'template_duration':
                f_write[k] = f_bank[k][()]
        # https://github.com/h5py/h5py/issues/1329
        f_write['template_duration'] = durations
# Entry-point guard: with multiprocessing's "spawn" start method every worker
# re-imports this module, and it must not re-run main() when it does.
if __name__ == '__main__':
    main()