Skip to content

Commit b503ec0

Browse files
dr-david and pawel-czyz authored
Subsampling (#65)
* added subsampling to workflow
* small fix

---------

Co-authored-by: Paweł Czyż <pczyz@protonmail.com>
1 parent b345548 commit b503ec0

File tree

7 files changed

+1047
-20
lines changed

7 files changed

+1047
-20
lines changed

examples/analyze_comparison.py

Lines changed: 235 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@
200200
# plt.show()
201201

202202

203-
# %% jupyter={"outputs_hidden": true}
203+
# %%
204204
pairwise_res = np.load(
205205
"../workflows/compare_clinical/results/config_ba1ba2/consolidated_pairwise_wastewater_results.npz"
206206
)
@@ -421,7 +421,7 @@ def make_plot(
421421
# )
422422

423423
ax.set_ylabel("rel. fitness / week")
424-
ax.set_title("Wastewater-Derived Selection advantage")
424+
ax.set_title("Wastewater-Derived Fitness advantage")
425425
ax.set_ylim(0, 0.2 * 7 * 100)
426426

427427
## clinical solutions
@@ -457,7 +457,7 @@ def make_plot(
457457
)
458458

459459
ax.set_ylabel("rel. fitness / week")
460-
ax.set_title(f"Clinical-Derived Selection Advantage")
460+
ax.set_title(f"Clinical-Derived Fitness Advantage")
461461
ax.set_ylim(0, 0.2 * 7 * 100)
462462

463463
# Ensure the 'date' column is in datetime format
@@ -504,15 +504,15 @@ def make_plot(
504504
clinical_totals,
505505
width=width,
506506
label="Clinical Samples",
507-
color="blue",
507+
color="orange",
508508
alpha=0.7,
509509
)
510510
ax.bar(
511511
ww_totals.index + pd.Timedelta(days=2), # Shift right for alignment
512512
ww_totals,
513513
width=width,
514514
label="Wastewater Samples",
515-
color="orange",
515+
color="blue",
516516
alpha=0.7,
517517
)
518518

@@ -590,6 +590,226 @@ def make_plot(
590590
for ax in axes:
591591
ax.set_xlim(x_min, x_max)
592592

593+
clinical_totals = clinical_totals[clinical_totals.index >= x_min]
594+
clinical_totals = clinical_totals[clinical_totals.index <= x_max]
595+
596+
# Print summary statistics
597+
print("Clinical Totals Summary:")
598+
print(f" Min: {clinical_totals.min()}")
599+
print(f" Max: {clinical_totals.max()}")
600+
print(f" Mean: {clinical_totals.mean():.2f}")
601+
print(f" Median: {clinical_totals.median()}")
602+
603+
print("\nWastewater Totals Summary:")
604+
print(f" Min: {ww_totals.min()}")
605+
print(f" Max: {ww_totals.max()}")
606+
print(f" Mean: {ww_totals.mean():.2f}")
607+
print(f" Median: {ww_totals.median()}")
608+
609+
610+
# %%
611+
# Plot
612+
fig, axes = plt.subplots(4, 2, figsize=(10, 8), sharey="none")
613+
614+
615+
variants = ["BA.2.86*", "JN.1*"]
616+
divisions = ["Zürich", "Geneva", "Ticino", "Graubünden", "Bern", "Sankt Gallen"]
617+
variants_evaluated = ["BA.2.86*", "JN.1*"]
618+
reference_variant = "EG.5*"
619+
folder = "config_jn1"
620+
621+
(
622+
config,
623+
grouped_ww_data,
624+
clin_freq,
625+
wastewater_df,
626+
clinical_df,
627+
grouped_clinical_data,
628+
variants_evaluated_index,
629+
variants_reference_index,
630+
x_min,
631+
x_max,
632+
merged_ww_data,
633+
pairwise_ww_res,
634+
pairwise_clin_res,
635+
) = load_data(folder, divisions, variants_evaluated, reference_variant).values()
636+
637+
make_plot(
638+
axes[:, 1],
639+
config,
640+
grouped_ww_data,
641+
clin_freq,
642+
wastewater_df,
643+
clinical_df,
644+
grouped_clinical_data,
645+
variants_evaluated_index,
646+
variants_reference_index,
647+
x_min,
648+
x_max,
649+
merged_ww_data,
650+
pairwise_ww_res,
651+
pairwise_clin_res,
652+
)
653+
654+
655+
variants = ["BA.1*", "BA.2*"]
656+
divisions = ["Zürich", "Geneva", "Ticino", "Graubünden", "Bern", "Sankt Gallen"]
657+
variants_evaluated = ["BA.2*"]
658+
reference_variant = "BA.1*"
659+
folder = "config_ba1ba2"
660+
661+
(
662+
config,
663+
grouped_ww_data,
664+
clin_freq,
665+
wastewater_df,
666+
clinical_df,
667+
grouped_clinical_data,
668+
variants_evaluated_index,
669+
variants_reference_index,
670+
x_min,
671+
x_max,
672+
merged_ww_data,
673+
pairwise_ww_res,
674+
pairwise_clin_res,
675+
) = load_data(folder, divisions, variants_evaluated, reference_variant).values()
676+
677+
make_plot(
678+
axes[:, 0],
679+
config,
680+
grouped_ww_data,
681+
clin_freq,
682+
wastewater_df,
683+
clinical_df,
684+
grouped_clinical_data,
685+
variants_evaluated_index,
686+
variants_reference_index,
687+
x_min,
688+
x_max,
689+
merged_ww_data,
690+
pairwise_ww_res,
691+
pairwise_clin_res,
692+
)
693+
694+
axes[0, 1].set_xlim([pd.to_datetime("2023-08-01"), pd.to_datetime("2024-01-01")])
695+
axes[1, 1].set_xlim([pd.to_datetime("2023-08-01"), pd.to_datetime("2024-01-01")])
696+
axes[2, 1].set_xlim([pd.to_datetime("2023-08-01"), pd.to_datetime("2024-01-01")])
697+
axes[3, 1].set_xlim([pd.to_datetime("2023-08-01"), pd.to_datetime("2024-01-01")])
698+
axes[0, 0].set_ylim([0, 1])
699+
axes[0, 1].set_ylim([0, 1])
700+
701+
cutoffs = [0.025, 0.05, 0.10] # Cutoff values
702+
break_dates = pd.to_datetime(
703+
[
704+
"2022-01-15 00:00:00",
705+
"2022-01-20 12:00:00",
706+
"2022-01-28 12:00:00",
707+
"2023-10-14 12:00:00",
708+
"2023-10-19 00:00:00",
709+
"2023-10-27 00:00:00",
710+
]
711+
)
712+
713+
# First three break dates: dashed vlines on the left-column panels axes[0,0], axes[1,0] and axes[2,0]
714+
for i in range(3):
715+
for ax in [axes[0, 0], axes[1, 0], axes[2, 0]]:
716+
ax.axvline(
717+
x=break_dates[i],
718+
color="black",
719+
linestyle="dashed",
720+
linewidth=1,
721+
)
722+
ax.text(
723+
break_dates[i],
724+
ax.get_ylim()[1] * 0.9, # Position at 90% of y-axis max
725+
f"{cutoffs[i] * 100}%",
726+
color="black",
727+
fontsize=10,
728+
ha="right",
729+
va="top",
730+
rotation=90,
731+
)
732+
733+
# Next three break dates: dashed vlines on the right-column panels axes[0,1], axes[1,1] and axes[2,1]
734+
for i in range(3, 6):
735+
for ax in [axes[0, 1], axes[1, 1], axes[2, 1]]:
736+
ax.axvline(
737+
x=break_dates[i],
738+
color="black",
739+
linestyle="dashed",
740+
linewidth=1,
741+
)
742+
ax.text(
743+
break_dates[i],
744+
ax.get_ylim()[1] * 0.9, # Position at 90% of y-axis max
745+
f"{cutoffs[i-3] * 100}%", # Using the same labels for both sets
746+
color="black",
747+
fontsize=10,
748+
ha="right",
749+
va="top",
750+
rotation=90,
751+
)
752+
753+
754+
# axes[2,0].set_yscale("log")
755+
# axes[2,1].set_yscale("log")
756+
757+
# Collect legend handles and labels from all axes
758+
handles = []
759+
labels = []
760+
for ax in axes.flat: # Iterate through all axes in the figure
761+
h, l = ax.get_legend_handles_labels()
762+
handles.extend(h)
763+
labels.extend(l)
764+
765+
# Deduplicate legend entries
766+
unique_legend = {}
767+
unique_handles = []
768+
for handle, label in zip(handles, labels):
769+
if label not in unique_legend:
770+
unique_legend[label] = handle
771+
unique_handles.append((handle, label))
772+
773+
# Create custom line handles for Wastewater and Clinical
774+
wastewater_line = mlines.Line2D(
775+
[], [], color="black", linestyle="-", label="Wastewater"
776+
)
777+
clinical_line = mlines.Line2D([], [], color="black", linestyle="--", label="Clinical")
778+
779+
# Add custom handles to the legend
780+
unique_handles.insert(0, (wastewater_line, "Wastewater"))
781+
unique_handles.insert(1, (clinical_line, "Clinical"))
782+
783+
# Apply the deduplicated legend
784+
fig.legend(
785+
[h for h, _ in unique_handles],
786+
[l for _, l in unique_handles],
787+
loc="center left",
788+
bbox_to_anchor=(1, 0.5),
789+
)
790+
791+
import string
792+
793+
# Generate panel labels: a, b, c, d, e, f, g, ...
794+
panel_labels = list(string.ascii_lowercase)
795+
# Loop through all subplots and label them
796+
for i, ax in enumerate(axes.flatten()):
797+
ax.text(
798+
-0.1,
799+
1.2, # Position: slightly above each subplot
800+
panel_labels[i], # Get the next letter
801+
transform=ax.transAxes, # Use subplot-relative coordinates
802+
fontsize=12,
803+
fontweight="bold",
804+
va="top",
805+
ha="right",
806+
)
807+
808+
809+
# fig.legend(loc="center left", bbox_to_anchor=(1, 0.5))
810+
# Ensure layout is correct
811+
plt.tight_layout()
812+
plt.show()
593813

594814
# %%
595815
# Plot
@@ -788,6 +1008,16 @@ def make_plot(
7881008
va="top",
7891009
ha="right",
7901010
)
1011+
axes[3, 0].set_yscale("log")
1012+
axes[3, 1].set_yscale("log")
1013+
1014+
# Get the current y-limits to synchronize them
1015+
y_min = min(axes[3, 0].get_ylim()[0], axes[3, 1].get_ylim()[0])
1016+
y_max = max(axes[3, 0].get_ylim()[1], axes[3, 1].get_ylim()[1])
1017+
1018+
# Apply the same y-limits to both axes
1019+
axes[3, 0].set_ylim(y_min, y_max)
1020+
axes[3, 1].set_ylim(y_min, y_max)
7911021

7921022

7931023
# fig.legend(loc="center left", bbox_to_anchor=(1, 0.5))

0 commit comments

Comments
 (0)