|
35 | 35 | __all__ = [ |
36 | 36 | "unitcheck", "resample_equal", "mean_and_cov", "quantile", "jitter_run", |
37 | 37 | "resample_run", "reweight_run", "unravel_run", "merge_runs", "kld_error", |
38 | | - "LoglOutput", "LogLikelihood", "RunRecord", "DelayTimer" |
| 38 | + "compute_insertion_indices", "LoglOutput", "LogLikelihood", "RunRecord", |
| 39 | + "DelayTimer" |
39 | 40 | ] |
40 | 41 |
|
41 | 42 | SQRTEPS = math.sqrt(float(np.finfo(np.float64).eps)) |
@@ -1136,6 +1137,193 @@ def _get_nsamps_samples_n(res): |
1136 | 1137 | return nsamps, samples_n |
1137 | 1138 |
|
1138 | 1139 |
|
| 1140 | +def _compute_insertion_indices_static(res, strict=True): |
| 1141 | + """ |
| 1142 | + Compute insertion indices for a static nested-sampling run. |
| 1143 | +
|
| 1144 | + The insertion index at iteration ``k`` is the number of live points with |
| 1145 | + likelihood below the newly accepted point inserted at that iteration. |
| 1146 | +
|
| 1147 | + The reconstruction uses two saved per-sample fields: |
| 1148 | +
|
| 1149 | + - ``samples_id``: live-point lineage/slot label. |
| 1150 | + In static runs, when a point with ID ``j`` dies, the replacement point |
| 1151 | + keeps that same ID ``j``. |
| 1152 | + - ``samples_it``: proposal iteration of each point. |
| 1153 | + ``samples_it == 0`` are initial live points; ``samples_it == k`` means |
| 1154 | + the point was inserted at iteration ``k``. |
| 1155 | +
|
| 1156 | + Parameters |
| 1157 | + ---------- |
| 1158 | + res : :class:`~dynesty.results.Results` |
| 1159 | + Results object from a static nested-sampling run. |
| 1160 | + strict : bool, optional |
| 1161 | + If ``True``, raise an error when exact reconstruction is impossible. |
| 1162 | + For static runs this typically requires final live points to be |
| 1163 | + present (i.e. ``len(res.logl) == res.niter + res.nlive``). |
| 1164 | +
|
| 1165 | + Returns |
| 1166 | + ------- |
| 1167 | + insertion : `~numpy.ndarray` |
| 1168 | + Array of insertion indices of length ``res.niter``. |
| 1169 | +
|
| 1170 | + """ |
| 1171 | + if res.isdynamic(): |
| 1172 | + raise ValueError('Insertion-index recovery here only supports static ' |
| 1173 | + 'nested sampling results.') |
| 1174 | + |
| 1175 | + niter = int(res.niter) |
| 1176 | + nlive = int(res.nlive) |
| 1177 | + ids = np.asarray(res.samples_id, dtype=int) |
| 1178 | + its = np.asarray(res.samples_it, dtype=int) |
| 1179 | + logl = np.asarray(res.logl, dtype=float) |
| 1180 | + |
| 1181 | + nsamps = len(logl) |
| 1182 | + if nsamps not in [niter, niter + nlive]: |
| 1183 | + raise ValueError('Inconsistent static results: len(logl) must equal ' |
| 1184 | + 'niter or niter+nlive.') |
| 1185 | + |
| 1186 | + if nsamps == niter and strict: |
| 1187 | + raise ValueError('Exact insertion-index reconstruction requires final ' |
| 1188 | + 'live points to be included in results.') |
| 1189 | + |
| 1190 | + # Build map (live-point id, proposal iteration) -> sample logl. |
| 1191 | + proposed_logl = {} |
| 1192 | + for pid, pit, ll in zip(ids, its, logl): |
| 1193 | + key = (int(pid), int(pit)) |
| 1194 | + if key in proposed_logl and not np.isclose(proposed_logl[key], ll): |
| 1195 | + raise ValueError('Duplicate (samples_id, samples_it) entries have ' |
| 1196 | + 'inconsistent logl values.') |
| 1197 | + proposed_logl[key] = float(ll) |
| 1198 | + |
| 1199 | + # Live set before iteration 1 is points proposed at iteration 0. |
| 1200 | + active = {} |
| 1201 | + init_mask = its == 0 |
| 1202 | + for pid, ll in zip(ids[init_mask], logl[init_mask]): |
| 1203 | + pid = int(pid) |
| 1204 | + if pid in active and not np.isclose(active[pid], ll): |
| 1205 | + raise ValueError('Inconsistent initial live-point records for a ' |
| 1206 | + 'shared samples_id.') |
| 1207 | + active[pid] = float(ll) |
| 1208 | + |
| 1209 | + if len(active) != nlive: |
| 1210 | + msg = (f'Only {len(active)} initial live points recovered but ' |
| 1211 | + f'nlive={nlive}. Full reconstruction typically needs ' |
| 1212 | + 'final live points in results.') |
| 1213 | + if strict: |
| 1214 | + raise ValueError(msg) |
| 1215 | + out = np.full(niter, np.nan) |
| 1216 | + return out |
| 1217 | + |
| 1218 | + dead_ids = ids[:niter] |
| 1219 | + insertion = np.empty(niter, dtype=float) |
| 1220 | + for k in range(1, niter + 1): |
| 1221 | + dead_id = int(dead_ids[k - 1]) |
| 1222 | + key_new = (dead_id, k) |
| 1223 | + if key_new not in proposed_logl: |
| 1224 | + if strict: |
| 1225 | + raise ValueError('Missing replacement point for iteration ' |
| 1226 | + f'{k} (samples_id={dead_id}, ' |
| 1227 | + f'samples_it={k}).') |
| 1228 | + insertion[k - 1:] = np.nan |
| 1229 | + break |
| 1230 | + |
| 1231 | + new_logl = proposed_logl[key_new] |
| 1232 | + live_logl = np.fromiter(active.values(), dtype=float) |
| 1233 | + # Insertion index convention: strict rank among current live points. |
| 1234 | + insertion[k - 1] = np.sum(live_logl < new_logl) |
| 1235 | + |
| 1236 | + active[dead_id] = new_logl |
| 1237 | + |
| 1238 | + return insertion |
| 1239 | + |
| 1240 | + |
| 1241 | +def compute_insertion_indices(res, strict=True): |
| 1242 | + """ |
| 1243 | + Compute insertion indices aligned with all samples in ``res``. |
| 1244 | +
|
| 1245 | + For static runs, returns an array of length ``len(res.logl)`` where the |
| 1246 | + first ``res.niter`` entries correspond to insertion indices for dead |
| 1247 | + points and any remaining entries (final added live points) are ``NaN``. |
| 1248 | +
|
| 1249 | + For dynamic runs, computes insertion indices separately for each batch |
| 1250 | + (including the base run as batch 0) and writes them back into a single |
| 1251 | + array aligned with ``res.logl``. Within each batch, indices are computed |
| 1252 | + in the batch-local static sense. |
| 1253 | +
|
| 1254 | + Notes on key fields: |
| 1255 | +
|
| 1256 | + - ``samples_id`` tracks live-point lineage within each static-style run. |
| 1257 | + In dynamic results, IDs are unique globally, but each batch is still |
| 1258 | + internally static-like. |
| 1259 | + - ``samples_it`` stores proposal iteration. In dynamic runs this is a |
| 1260 | + global iteration counter, so each batch is rebased to start at 0 before |
| 1261 | + static reconstruction. |
| 1262 | +
|
| 1263 | + Parameters |
| 1264 | + ---------- |
| 1265 | + res : :class:`~dynesty.results.Results` |
| 1266 | + Results object from a static or dynamic run. |
| 1267 | + strict : bool, optional |
| 1268 | + If ``True``, raise when exact reconstruction is impossible. |
| 1269 | + If ``False``, missing portions are returned as ``NaN``. |
| 1270 | +
|
| 1271 | + Returns |
| 1272 | + ------- |
| 1273 | + insertion_all : `~numpy.ndarray` |
| 1274 | + Array of insertion indices aligned with ``res.logl``. |
| 1275 | +
|
| 1276 | + """ |
| 1277 | + nsamps = len(res.logl) |
| 1278 | + insertion_all = np.full(nsamps, np.nan) |
| 1279 | + |
| 1280 | + if not res.isdynamic(): |
| 1281 | + ins = _compute_insertion_indices_static(res, strict=strict) |
| 1282 | + insertion_all[:len(ins)] = ins |
| 1283 | + return insertion_all |
| 1284 | + |
| 1285 | + # Dynamic case: process each batch independently as a static-like run. |
| 1286 | + batches = np.unique(res.samples_batch) |
| 1287 | + for batch_id in batches: |
| 1288 | + sel = (res.samples_batch == batch_id) |
| 1289 | + idx = np.nonzero(sel)[0] |
| 1290 | + nsamp_batch = len(idx) |
| 1291 | + nlive_batch = int(res.batch_nlive[int(batch_id)]) |
| 1292 | + niter_batch = nsamp_batch - nlive_batch |
| 1293 | + if niter_batch < 0: |
| 1294 | + raise ValueError( |
| 1295 | + f'Batch {batch_id} has fewer samples ({nsamp_batch}) than ' |
| 1296 | + f'nlive ({nlive_batch}).') |
| 1297 | + |
| 1298 | + # Rebase iterations so the batch-local run starts at 0. |
| 1299 | + samples_it_local = np.asarray(res.samples_it[sel], dtype=int) |
| 1300 | + if len(samples_it_local) > 0: |
| 1301 | + samples_it_local = samples_it_local - samples_it_local.min() |
| 1302 | + |
| 1303 | + static_like = Results( |
| 1304 | + dict(nlive=nlive_batch, |
| 1305 | + niter=niter_batch, |
| 1306 | + ncall=res.ncall[sel], |
| 1307 | + eff=float(100. * nsamp_batch / np.sum(res.ncall[sel])), |
| 1308 | + samples=res.samples[sel], |
| 1309 | + samples_id=res.samples_id[sel], |
| 1310 | + samples_it=samples_it_local, |
| 1311 | + samples_u=res.samples_u[sel], |
| 1312 | + blob=res.blob[sel], |
| 1313 | + logwt=res.logwt[sel], |
| 1314 | + logl=res.logl[sel], |
| 1315 | + logvol=res.logvol[sel], |
| 1316 | + logz=res.logz[sel], |
| 1317 | + logzerr=res.logzerr[sel], |
| 1318 | + information=res.information[sel])) |
| 1319 | + |
| 1320 | + ins_batch = _compute_insertion_indices_static(static_like, |
| 1321 | + strict=strict) |
| 1322 | + insertion_all[idx[:len(ins_batch)]] = ins_batch |
| 1323 | + |
| 1324 | + return insertion_all |
| 1325 | + |
| 1326 | + |
1139 | 1327 | def _find_decrease(samples_n): |
1140 | 1328 | """ |
1141 | 1329 | Find all instances where the number of live points is either constant |
|
0 commit comments