|
| 1 | +"""EXP-2848: Back-fill flat patients lacking transition coverage with |
| 2 | +loosened criteria. |
| 3 | +
|
| 4 | +EXP-2812 required n_transitions >= 2 to emit a triage flag, leaving |
| 5 | +4 flat-phenotype patients without coverage. This experiment relaxes |
| 6 | +to n_transitions >= 1 (shorter recovery windows for single-transition |
| 7 | +patients were considered but are not applied in this script: every patient is scored on the 3-week recovery fraction), and emits a back-fill |
| 8 | +triage table compatible with the audition matrix downstream consumers. |
| 9 | +
|
| 10 | +Charter: Stream B operational. We are NOT inventing transitions; we are |
| 11 | +relaxing the inclusion criterion to surface patients whose evidence is |
| 12 | +real but sparse, with a confidence-grade penalty applied. |
| 13 | +
|
| 14 | +Outputs: |
| 15 | + externals/experiments/exp-2848_backfill_triage.parquet |
| 16 | + externals/experiments/exp-2848_summary.json |
| 17 | + docs/60-research/figures/exp-2848_backfill_coverage.png |
| 18 | +""" |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import json |
| 22 | +from pathlib import Path |
| 23 | + |
| 24 | +import matplotlib |
| 25 | +matplotlib.use("Agg") |
| 26 | +import matplotlib.pyplot as plt |
| 27 | +import numpy as np |
| 28 | +import pandas as pd |
| 29 | + |
| 30 | +EXP = Path("externals/experiments") |
| 31 | +FIG = Path("docs/60-research/figures") |
| 32 | + |
| 33 | + |
# Outcome thresholds shared by the original (grade B) and back-fill (grade C)
# triage rules; only the minimum transition count differs between the two.
_RECOVERY_MAX = 0.4             # median recovery_fraction_3w must be below this
_POST_HIGH_MIN = 30             # median post_pct_high must exceed this
_MIN_ORIGINAL_TRANSITIONS = 2   # n_transitions needed for an "original" flag

# Column order of a triage record. Passed explicitly to the DataFrame
# constructor so that an empty record list still yields a frame that has a
# ``patient_id`` column to merge on.
_TRIAGE_COLUMNS = [
    "patient_id", "n", "recovery", "post_high", "source", "confidence_grade",
]


def _flag_patients(pp: pd.DataFrame) -> tuple[list[dict], list[dict]]:
    """Split flagged patients into original and back-fill triage records.

    A patient is flagged when the median ``recovery_fraction_3w`` of their
    transitions is below ``_RECOVERY_MAX`` and the median ``post_pct_high``
    is above ``_POST_HIGH_MIN``. Flagged patients with at least
    ``_MIN_ORIGINAL_TRANSITIONS`` rows become ``source="original"`` (grade B);
    the remaining flagged patients become ``source="backfill"`` (grade C).

    Returns:
        ``(orig_records, bf_records)``, each a list of dicts keyed by
        ``_TRIAGE_COLUMNS``, in ``groupby("patient_id")`` iteration order.
    """
    orig_records: list[dict] = []
    bf_records: list[dict] = []
    for pid, grp in pp.groupby("patient_id"):
        n = len(grp)
        med_rec = grp["recovery_fraction_3w"].median()
        med_post = grp["post_pct_high"].median()
        # A NaN median makes both comparisons False, so such patients are
        # never flagged.
        if not (med_rec < _RECOVERY_MAX and med_post > _POST_HIGH_MIN):
            continue
        if n >= _MIN_ORIGINAL_TRANSITIONS:
            target, source, grade = orig_records, "original", "B"
        else:
            # Every group has n >= 1, so the relaxed criterion admits all
            # flagged patients the original rule left out.
            target, source, grade = bf_records, "backfill", "C"
        target.append(dict(
            patient_id=pid, n=n, recovery=med_rec,
            post_high=med_post, source=source,
            confidence_grade=grade,
        ))
    return orig_records, bf_records


def _plot_coverage(triage: pd.DataFrame, flat_in_orig: int, flat_in_bf: int,
                   flat_total: int, out: Path) -> None:
    """Draw the coverage pie and triage scatter, then save the figure to *out*.

    The parent directory of *out* is created if it does not exist.
    """
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))
    fig.suptitle(
        "EXP-2848 — Back-fill triage coverage (looser n_trans ≥ 1)\n"
        "Stream B; demoted to confidence C; original flags untouched",
        fontsize=11,
    )

    # Coverage pie: zero-sized wedges are dropped so they do not clutter
    # the chart with empty labels.
    ax = axes[0]
    parts = [
        ("Flat: covered (n≥2)", flat_in_orig, "#2ca02c"),
        ("Flat: back-filled (n=1)", flat_in_bf, "#ff7f0e"),
        ("Flat: still uncovered", flat_total - flat_in_orig - flat_in_bf,
         "#bbbbbb"),
    ]
    parts = [p for p in parts if p[1] > 0]
    if parts:
        wedge_sizes = [size for _, size, _ in parts]
        wedge_total = sum(wedge_sizes)

        def _as_count(pct: float) -> str:
            # matplotlib hands autopct the wedge *percentage*; convert it
            # back to the patient count the "N=" title refers to.
            return str(int(round(float(pct) * wedge_total / 100.0)))

        ax.pie(
            wedge_sizes, labels=[label for label, _, _ in parts],
            colors=[color for _, _, color in parts],
            autopct=_as_count, startangle=90,
            wedgeprops=dict(edgecolor="white", linewidth=1.5),
        )
    ax.set_title(f"Flat-phenotype coverage (N={flat_total})")

    # Triage scatter: one series per flag source, each point annotated with
    # its patient id.
    ax = axes[1]
    if not triage.empty:
        for src, color, marker in [("original", "#2ca02c", "o"),
                                   ("backfill", "#ff7f0e", "s")]:
            sub = triage[triage["source"] == src]
            if sub.empty:
                continue
            grade = sub["confidence_grade"].iat[0]
            ax.scatter(sub["n"], sub["recovery"], s=140, c=color,
                       marker=marker, alpha=0.8, edgecolor="white",
                       linewidth=1.2, label=f"{src} (grade {grade})")
            for _, row in sub.iterrows():
                ax.annotate(
                    str(row["patient_id"]),
                    (row["n"], row["recovery"]),
                    fontsize=8, alpha=0.85,
                    xytext=(5, 4), textcoords="offset points",
                )
        ax.axhline(_RECOVERY_MAX, color="k", lw=0.5, ls="--", alpha=0.5,
                   label="recovery threshold")
        ax.set_xlabel("N transitions observed")
        ax.set_ylabel("Median recovery fraction (3w)")
        ax.set_title("Triage flags by transition count + recovery")
        ax.legend(loc="best", fontsize=8)
    else:
        ax.text(0.5, 0.5, "No triage flags", ha="center",
                transform=ax.transAxes)

    fig.tight_layout(rect=(0, 0, 1, 0.93))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=120, bbox_inches="tight")
    plt.close(fig)


def main() -> dict:
    """Run EXP-2848 and return its summary.

    Reads the EXP-2812 pre/post transition table and the EXP-2844 phenotype
    table from ``EXP``, flags patients with the original and the relaxed
    (back-fill) criteria, and writes the triage parquet, the summary JSON and
    the coverage figure.

    Returns:
        The summary dict that is also written to ``exp-2848_summary.json``.
    """
    pp = pd.read_parquet(EXP / "exp-2812_pre_post_transitions.parquet")
    pheno = pd.read_parquet(EXP / "exp-2844_phenotype_table.parquet")

    print(f"Loaded {len(pp)} pre-post transitions, "
          f"{pp['patient_id'].nunique()} patients")

    orig_records, bf_records = _flag_patients(pp)

    # Explicit columns keep the merge key present even when nothing is
    # flagged; without them an empty frame makes the merge raise KeyError.
    triage = pd.DataFrame(orig_records + bf_records, columns=_TRIAGE_COLUMNS)
    triage = triage.merge(
        pheno[["patient_id", "controller", "phenotype",
               "median_recovery_fraction"]],
        on="patient_id", how="left",
    )

    print(f"\nTriage flags: original={len(orig_records)}, "
          f"backfill={len(bf_records)}")
    print(triage.to_string(index=False))

    # Coverage analysis: how many flat patients gained coverage?
    flat_pids = set(pheno.loc[pheno["phenotype"] == "flat", "patient_id"])
    flat_in_orig = sum(1 for r in orig_records if r["patient_id"] in flat_pids)
    flat_in_bf = sum(1 for r in bf_records if r["patient_id"] in flat_pids)
    flat_total = len(flat_pids)

    summary = {
        "experiment": "EXP-2848",
        "title": "Back-fill flat-patient triage with loosened n_trans criterion",
        "stream": "B",
        "n_orig_flags": len(orig_records),
        "n_backfill_flags": len(bf_records),
        "n_flat_total": flat_total,
        "flat_covered_orig": flat_in_orig,
        "flat_covered_backfill": flat_in_bf,
        "flat_uncovered": flat_total - flat_in_orig - flat_in_bf,
        "checks": {
            # NOTE: asserted rather than computed; every record is derived
            # from rows already present in the transitions table.
            "PASS_no_invented_transitions": True,
            "PASS_confidence_grade_demoted": all(
                r["confidence_grade"] == "C" for r in bf_records
            ),
            "PASS_at_least_one_backfill": len(bf_records) >= 1,
        },
    }
    summary["checks_passed"] = sum(summary["checks"].values())

    triage.to_parquet(EXP / "exp-2848_backfill_triage.parquet", index=False)
    (EXP / "exp-2848_summary.json").write_text(
        json.dumps(summary, indent=2, default=str)
    )

    # Visualization (Charter V8: paired chart for the back-fill line)
    out = FIG / "exp-2848_backfill_coverage.png"
    _plot_coverage(triage, flat_in_orig, flat_in_bf, flat_total, out)
    print(f"\nWrote {out}")
    print(json.dumps(summary, indent=2, default=str))
    return summary
| 172 | + |
| 173 | + |
if __name__ == "__main__":
    # Script entry point: run the experiment. The returned summary is not
    # used here.
    main()
0 commit comments