Skip to content

Commit 696bf03

Browse files
committed
Add code to compute the insertion index. TBD: whether it should also return the nlive array for dynamic sampling.
1 parent 8812f7e commit 696bf03

1 file changed

Lines changed: 189 additions & 1 deletion

File tree

py/dynesty/utils.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
__all__ = [
3636
"unitcheck", "resample_equal", "mean_and_cov", "quantile", "jitter_run",
3737
"resample_run", "reweight_run", "unravel_run", "merge_runs", "kld_error",
38-
"LoglOutput", "LogLikelihood", "RunRecord", "DelayTimer"
38+
"compute_insertion_indices", "LoglOutput", "LogLikelihood", "RunRecord",
39+
"DelayTimer"
3940
]
4041

4142
SQRTEPS = math.sqrt(float(np.finfo(np.float64).eps))
@@ -1136,6 +1137,193 @@ def _get_nsamps_samples_n(res):
11361137
return nsamps, samples_n
11371138

11381139

1140+
def _compute_insertion_indices_static(res, strict=True):
    """
    Compute insertion indices for a static nested-sampling run.

    The insertion index at iteration ``k`` is the number of points in the
    live set *before* the replacement (i.e. still including the point that
    dies at iteration ``k``) whose log-likelihood is strictly below that of
    the point inserted at iteration ``k``.

    The reconstruction uses two saved per-sample fields:

    - ``samples_id``: live-point lineage/slot label. The reconstruction
      relies on the replacement of a dead point with ID ``j`` carrying the
      same ID ``j``.
    - ``samples_it``: proposal iteration of each point.
      ``samples_it == 0`` are treated as initial live points;
      ``samples_it == k`` is treated as the point inserted at iteration
      ``k``.

    Parameters
    ----------
    res : :class:`~dynesty.results.Results`
        Results object from a static nested-sampling run. Only
        ``isdynamic()``, ``niter``, ``nlive``, ``samples_id``,
        ``samples_it`` and ``logl`` are read.
    strict : bool, optional
        If ``True`` (default), raise an error when exact reconstruction is
        impossible. For static runs this typically requires final live
        points to be present (i.e. ``len(res.logl) == res.niter +
        res.nlive``). If ``False``, the entries that cannot be
        reconstructed are returned as ``NaN``.

    Returns
    -------
    insertion : `~numpy.ndarray`
        Float array of insertion indices of length ``res.niter``.

    Raises
    ------
    ValueError
        If ``res`` is a dynamic run, if the per-sample arrays have
        inconsistent lengths, if duplicated ``(samples_id, samples_it)``
        records disagree on ``logl``, or (only when ``strict`` is true) if
        the live set or a replacement point cannot be recovered.

    """
    if res.isdynamic():
        raise ValueError('Insertion-index recovery here only supports static '
                         'nested sampling results.')

    niter = int(res.niter)
    nlive = int(res.nlive)
    ids = np.asarray(res.samples_id, dtype=int)
    its = np.asarray(res.samples_it, dtype=int)
    logl = np.asarray(res.logl, dtype=float)

    nsamps = len(logl)
    if nsamps not in (niter, niter + nlive):
        raise ValueError('Inconsistent static results: len(logl) must equal '
                         'niter or niter+nlive.')
    # zip() below would silently truncate on a length mismatch, so check
    # explicitly instead of producing a misleading downstream error.
    if len(ids) != nsamps or len(its) != nsamps:
        raise ValueError('Inconsistent static results: samples_id, '
                         'samples_it and logl must have the same length.')

    if nsamps == niter and strict:
        raise ValueError('Exact insertion-index reconstruction requires final '
                         'live points to be included in results.')

    # Build map (live-point id, proposal iteration) -> sample logl.
    proposed_logl = {}
    for pid, pit, ll in zip(ids.tolist(), its.tolist(), logl.tolist()):
        key = (pid, pit)
        if key in proposed_logl and not np.isclose(proposed_logl[key], ll):
            raise ValueError('Duplicate (samples_id, samples_it) entries have '
                             'inconsistent logl values.')
        proposed_logl[key] = ll

    # Live set before iteration 1 is points proposed at iteration 0.
    active = {}
    init_mask = its == 0
    for pid, ll in zip(ids[init_mask].tolist(), logl[init_mask].tolist()):
        if pid in active and not np.isclose(active[pid], ll):
            raise ValueError('Inconsistent initial live-point records for a '
                             'shared samples_id.')
        active[pid] = ll

    if len(active) != nlive:
        if strict:
            raise ValueError(
                f'Only {len(active)} initial live points recovered but '
                f'nlive={nlive}. Full reconstruction typically needs '
                'final live points in results.')
        return np.full(niter, np.nan)

    # Keep the live log-likelihoods in one array that is updated in place,
    # instead of rebuilding it from the dict at every iteration.
    slot_of = {pid: slot for slot, pid in enumerate(active)}
    live_logl = np.fromiter(active.values(), dtype=float, count=nlive)

    # Entries that cannot be reconstructed stay NaN (non-strict mode).
    insertion = np.full(niter, np.nan)
    for k, dead_id in enumerate(ids[:niter].tolist(), start=1):
        new_logl = proposed_logl.get((dead_id, k))
        if new_logl is None:
            if strict:
                raise ValueError('Missing replacement point for iteration '
                                 f'{k} (samples_id={dead_id}, '
                                 f'samples_it={k}).')
            break

        slot = slot_of.get(dead_id)
        if slot is None:
            # A dead point that was never live means the records are
            # inconsistent; adding it would grow the live set past nlive
            # and make every later rank meaningless.
            if strict:
                raise ValueError(f'Dead point at iteration {k} has '
                                 f'samples_id={dead_id}, which is not in '
                                 'the reconstructed live set.')
            break

        # Insertion index convention: strict rank among current live points.
        insertion[k - 1] = np.count_nonzero(live_logl < new_logl)

        # The new point takes over the slot of the point that just died.
        live_logl[slot] = new_logl

    return insertion
1239+
1240+
1241+
def compute_insertion_indices(res, strict=True):
    """
    Return insertion indices laid out like ``res.logl``.

    The output always has ``len(res.logl)`` entries and is pre-filled with
    ``NaN``; only the positions for which an insertion index is computed
    are overwritten.

    Static runs
        The per-iteration indices from the static reconstruction are
        written to the leading entries; whatever follows (the final live
        points, if they were added) stays ``NaN``.

    Dynamic runs
        Every distinct value of ``res.samples_batch`` is handled as its own
        static-like run. The batch's samples are selected, its number of
        live points is read from ``res.batch_nlive`` and its number of
        iterations is taken as the batch sample count minus that value.
        ``samples_it`` is shifted so the smallest value in the batch is 0
        before the static reconstruction is applied, and the resulting
        indices are written back to the positions of that batch's leading
        samples.

    Parameters
    ----------
    res : :class:`~dynesty.results.Results`
        Results object from a static or dynamic run.
    strict : bool, optional
        Passed through to the static reconstruction. If ``True``, raise
        when exact reconstruction is impossible. If ``False``, missing
        portions are returned as ``NaN``.

    Returns
    -------
    insertion_all : `~numpy.ndarray`
        Array of insertion indices aligned with ``res.logl``.

    Raises
    ------
    ValueError
        If a batch of a dynamic run contains fewer samples than its number
        of live points, or if the static reconstruction raises.

    """
    insertion_all = np.full(len(res.logl), np.nan)

    # Static run: one reconstruction fills the leading entries.
    if not res.isdynamic():
        static_ins = _compute_insertion_indices_static(res, strict=strict)
        insertion_all[:len(static_ins)] = static_ins
        return insertion_all

    # Dynamic run: reconstruct each batch on its own.
    for batch_id in np.unique(res.samples_batch):
        in_batch = (res.samples_batch == batch_id)
        positions = np.nonzero(in_batch)[0]
        n_in_batch = len(positions)
        batch_nlive = int(res.batch_nlive[int(batch_id)])
        batch_niter = n_in_batch - batch_nlive
        if batch_niter < 0:
            raise ValueError(f'Batch {batch_id} has fewer samples '
                             f'({n_in_batch}) than nlive ({batch_nlive}).')

        # Shift the iteration counter so this batch starts counting at 0.
        local_it = np.asarray(res.samples_it[in_batch], dtype=int)
        if len(local_it) > 0:
            local_it = local_it - local_it.min()

        # Minimal static-looking view of this batch.
        batch_fields = {
            'nlive': batch_nlive,
            'niter': batch_niter,
            'ncall': res.ncall[in_batch],
            'eff': float(100. * n_in_batch / np.sum(res.ncall[in_batch])),
            'samples': res.samples[in_batch],
            'samples_id': res.samples_id[in_batch],
            'samples_it': local_it,
            'samples_u': res.samples_u[in_batch],
            'blob': res.blob[in_batch],
            'logwt': res.logwt[in_batch],
            'logl': res.logl[in_batch],
            'logvol': res.logvol[in_batch],
            'logz': res.logz[in_batch],
            'logzerr': res.logzerr[in_batch],
            'information': res.information[in_batch],
        }
        batch_ins = _compute_insertion_indices_static(Results(batch_fields),
                                                      strict=strict)

        # Only the leading samples of the batch receive an index.
        insertion_all[positions[:len(batch_ins)]] = batch_ins

    return insertion_all
1325+
1326+
11391327
def _find_decrease(samples_n):
11401328
"""
11411329
Find all instances where the number of live points is either constant

0 commit comments

Comments
 (0)