Skip to content

Commit 63c0303

Browse files
authored
523 speed up post processing (#526)
* Allow post-processing with .npy for faster loading of the run. Test run now takes less than a minute to load versus hours. Still need to speed-up post-processing with using nestcheck which is now the bottleneck. * Simplified code by restructuring it so that less if-statements and double definition are used. * Improved readability * Small comment change
1 parent 2af6ad1 commit 63c0303

3 files changed

Lines changed: 107 additions & 54 deletions

File tree

xpsi/PostProcessing/_backends.py

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
_warning('Cannot create a GetDist sample backend.')
88

99
try:
10-
from nestcheck.data_processing import process_multinest_run
10+
from ._nestcheck_modifications import process_multinest_run
1111
from nestcheck.data_processing import process_polychord_run
1212
except ImportError:
1313
_warning('Cannot use nestcheck sample backend.')
@@ -39,25 +39,36 @@ def __init__(self, root, base_dir, use_nestcheck, transform=None,
3939
filerootpath =_os.path.join(base_dir, root)
4040
_filerootpath = filerootpath
4141

42+
# check whether to use .npy files or .txt files
43+
if _os.path.isfile(filerootpath+'.npy'):
44+
loader = _np.load
45+
saver = _np.save
46+
filetype = ".npy"
47+
else:
48+
loader = _np.loadtxt
49+
saver = _np.savetxt
50+
filetype = ".txt"
51+
print("Change output files to .npy to speed up post-processing")
52+
4253
if transform is not None:
43-
samples = _np.loadtxt(filerootpath+'.txt')
54+
samples = loader(filerootpath + filetype)
4455
ndims = samples.shape[1] - 2
4556
temp = transform(samples[0,2:], old_API=True)
4657
ntransform = len(temp) - ndims
4758

48-
_exists = _os.path.isfile(filerootpath+'_transformed.txt')
59+
_exists = _os.path.isfile(filerootpath+'_transformed'+filetype)
4960
if not _exists or overwrite_transformed:
5061
transformed = _np.zeros((samples.shape[0],
5162
samples.shape[1] + ntransform))
5263
transformed[:,:2] = samples[:,:2]
5364
for i in range(samples.shape[0]):
5465
transformed[i,2:] = transform(samples[i,2:], old_API=True)
55-
_np.savetxt(filerootpath+'_transformed.txt', transformed)
66+
saver(f'{filerootpath}_transformed{filetype}', transformed)
5667

5768
filerootpath += '_transformed'
5869
root += '_transformed'
5970

60-
super(NestedBackend, self).__init__(filepath=filerootpath+'.txt',**kwargs)
71+
super(NestedBackend, self).__init__(filepath=filerootpath+filetype,**kwargs)
6172

6273
if getdist is not None:
6374
# getdist backend
@@ -71,57 +82,50 @@ def __init__(self, root, base_dir, use_nestcheck, transform=None,
7182

7283
self.use_nestcheck = use_nestcheck
7384

74-
if self.use_nestcheck: # nestcheck backend
75-
if transform is not None:
76-
for ext in ['dead-birth.txt', 'phys_live-birth.txt']:
77-
_exists = _os.path.isfile(filerootpath + ext)
78-
if not _exists or overwrite_transformed:
79-
samples = _np.loadtxt(_filerootpath + ext)
80-
transformed = _np.zeros((samples.shape[0],
81-
samples.shape[1] + ntransform))
82-
transformed[:,ndims+ntransform:] = samples[:,ndims:]
83-
for i in range(samples.shape[0]):
84-
transformed[i,:ndims+ntransform] =\
85-
transform(samples[i,:ndims],
86-
old_API=True)
87-
88-
_np.savetxt(filerootpath + "-" + ext, transformed)
89-
90-
# .stats file with same root needed, but do not need to modify
91-
# the .stats file contents
92-
if not _os.path.isfile(filerootpath + '.stats'):
93-
if _os.path.isfile(_filerootpath + '.stats'):
94-
try:
95-
from shutil import copyfile as _copyfile
96-
except ImportError:
97-
pass
98-
else:
99-
_copyfile(_filerootpath + '.stats',
100-
filerootpath + '.stats')
101-
try:
102-
kwargs['implementation']
103-
except KeyError:
104-
print('Root %r sampling implementation not specified... '
105-
'assuming MultiNest for nestcheck...')
85+
if self.use_nestcheck and transform is not None:
86+
for ext in [f'dead-birth{filetype}', f'phys_live-birth{filetype}']:
87+
# save dead and live points for later use in process_multinest_run():
88+
if 'dead-birth' in ext:
89+
dead_points = samples
90+
if 'phys_live-birth' in ext:
91+
live_points = samples
92+
93+
if not _os.path.isfile(filerootpath + ext) or overwrite_transformed:
94+
# transform samples
95+
transformed = _np.zeros((samples.shape[0],
96+
samples.shape[1] + ntransform))
97+
transformed[:,ndims+ntransform:] = samples[:,ndims:]
98+
for i in range(samples.shape[0]):
99+
transformed[i,:ndims+ntransform] =\
100+
transform(samples[i,:ndims],
101+
old_API=True)
102+
# save transformed part
103+
saver(filerootpath + "-" + ext, transformed)
104+
105+
# .stats file with same root needed, but do not need to modify
106+
# the .stats file contents
107+
if not _os.path.isfile(filerootpath + '.stats'):
108+
if _os.path.isfile(_filerootpath + '.stats'):
109+
try:
110+
from shutil import copyfile as _copyfile
111+
except ImportError:
112+
pass
113+
else:
114+
_copyfile(_filerootpath + '.stats',
115+
filerootpath + '.stats')
116+
117+
# assuming multinest for nestcheck if not specified
118+
implementation = kwargs.get('implementation', 'multinest')
119+
120+
if implementation == 'multinest':
106121
try:
107-
self._nc_bcknd = process_multinest_run(root,
108-
base_dir=base_dir)
122+
self._nc_bcknd = process_multinest_run(dead_points, live_points, root, base_dir=base_dir)
109123
except FileNotFoundError:
110-
self._nc_bcknd = process_multinest_run(root+"-",
111-
base_dir=base_dir)
124+
self._nc_bcknd = process_multinest_run(dead_points, live_points, root + "-", base_dir=base_dir)
125+
elif implementation == 'polychord':
126+
self._nc_bcknd = process_polychord_run(root, base_dir=base_dir)
112127
else:
113-
if kwargs['implementation'] == 'multinest':
114-
try:
115-
self._nc_bcknd = process_multinest_run(root,
116-
base_dir=base_dir)
117-
except FileNotFoundError:
118-
self._nc_bcknd = process_multinest_run(root+"-",
119-
base_dir=base_dir)
120-
elif kwargs['implementation'] == 'polychord':
121-
self._nc_bcknd = process_polychord_run(root,
122-
base_dir=base_dir)
123-
else:
124-
raise ValueError('Cannot process with nestcheck.')
128+
raise ValueError('Cannot process with nestcheck.')
125129

126130
@property
127131
def getdist_backend(self):

xpsi/PostProcessing/_nestcheck_modifications.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import getdist
22
import getdist.plots
33
getdist.chains.print_load_details = False
4+
from ._global_imports import *
5+
from nestcheck.data_processing import process_samples_array
46

57
def getdist_kde(x, samples, weights, **kwargs):
68
"""
@@ -32,3 +34,47 @@ def getdist_kde(x, samples, weights, **kwargs):
3234
bcknd.get1DDensity('x').normalize(by='integral',
3335
in_place=True)
3436
return bcknd.get1DDensity('x').Prob(x)
37+
38+
def process_multinest_run(dead_points, live_points, file_root, base_dir, **kwargs):
39+
"""Modified version of nestcheck.data_processing.process_multinest_run().
40+
Loads data (both .txt and .npy) from a MultiNest run into the nestcheck
41+
dictionary format for analysis.
42+
43+
N.B. producing required output file containing information about the
44+
iso-likelihood contours within which points were sampled (where they were
45+
"born") requies MultiNest version 3.11 or later.
46+
47+
Parameters
48+
----------
49+
dead_points: ndarray
50+
Dead points
51+
live_points: ndarray
52+
Live points
53+
file_root: str
54+
Root name for output files. When running MultiNest, this is determined
55+
by the nest_root parameter.
56+
base_dir: str
57+
Directory containing output files. When running MultiNest, this is
58+
determined by the nest_root parameter.
59+
kwargs: dict, optional
60+
Passed to ns_run_utils.check_ns_run (via process_samples_array)
61+
62+
Returns
63+
-------
64+
ns_run: dict
65+
Nested sampling run dict (see the module docstring for more details).
66+
"""
67+
# Remove unnecessary final columns
68+
dead_points = dead_points[:, :-2]
69+
live_points = live_points[:, :-1]
70+
assert dead_points[:, -2].max() < live_points[:, -2].min(), (
71+
'final live points should have greater logls than any dead point!',
72+
dead_points, live_points)
73+
ns_run = process_samples_array(_np.vstack((dead_points, live_points)), **kwargs)
74+
assert _np.all(ns_run['thread_min_max'][:, 0] == -_np.inf), (
75+
'As MultiNest does not currently perform dynamic nested sampling, all '
76+
'threads should start by sampling the whole prior.')
77+
ns_run['output'] = {}
78+
ns_run['output']['file_root'] = file_root
79+
ns_run['output']['base_dir'] = base_dir
80+
return ns_run

xpsi/PostProcessing/_run.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def samples(self, filepath):
2525
if _os.path.isfile(filepath):
2626
# we should be able to load samples straightforwardly into
2727
# a NumPy array for the most basic type of sample lookup
28-
self._samples = _np.loadtxt(filepath)
28+
if filepath.endswith(".npy"):
29+
self._samples = _np.load(filepath)
30+
elif filepath.endswith(".txt"):
31+
self._samples = _np.loadtxt(filepath)
2932
else:
3033
raise ValueError('File %s does not exist.' % filepath)
3134

0 commit comments

Comments
 (0)