Skip to content

Commit a3ba958

Browse files
authored
added pairwise coeff (#61)
1 parent 6d1acd1 commit a3ba958

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

src/covvfit/_cli/infer.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,63 @@ def pprint(message):
417417
covariance=covariance_scaled,
418418
)
419419

420+
# Output pairwise fitness advantages
421+
422+
def make_relative_growths(theta_star):
423+
relative_growths = (
424+
qm.get_relative_growths(theta_star, n_variants=n_variants_effective)
425+
- time_scaler.t_min
426+
) / (time_scaler.t_max - time_scaler.t_min)
427+
relative_growths = jnp.concat([jnp.array([0]), relative_growths])
428+
relative_growths = relative_growths * DAYS_IN_A_WEEK
429+
430+
pairwise_diff = jnp.expand_dims(relative_growths, axis=1) - jnp.expand_dims(
431+
relative_growths, axis=0
432+
)
433+
434+
return pairwise_diff
435+
436+
pairwise_diffs = make_relative_growths(theta_star)
437+
jacob = jax.jacobian(make_relative_growths)(theta_star)
438+
standerr_relgrowths = qm.get_standard_errors(covariance_scaled, jacob)
439+
relgrowths_confint = qm.get_confidence_intervals(
440+
make_relative_growths(theta_star), standerr_relgrowths, 0.95
441+
)
442+
443+
df_diffs = (
444+
pd.DataFrame(
445+
pairwise_diffs, index=variants_effective, columns=variants_effective
446+
)
447+
.reset_index()
448+
.melt(id_vars="index")
449+
)
450+
df_diffs.columns = ["Variant", "Reference_Variant", "Estimate"]
451+
452+
# Create confidence interval DataFrames
453+
df_lower = (
454+
pd.DataFrame(
455+
relgrowths_confint[0], index=variants_effective, columns=variants_effective
456+
)
457+
.reset_index()
458+
.melt(id_vars="index")
459+
)
460+
df_upper = (
461+
pd.DataFrame(
462+
relgrowths_confint[1], index=variants_effective, columns=variants_effective
463+
)
464+
.reset_index()
465+
.melt(id_vars="index")
466+
)
467+
468+
df_lower.columns = ["Variant", "Reference_Variant", "Lower_CI"]
469+
df_upper.columns = ["Variant", "Reference_Variant", "Upper_CI"]
470+
471+
# Merge all data
472+
df_final = df_diffs.merge(df_lower, on=["Variant", "Reference_Variant"]).merge(
473+
df_upper, on=["Variant", "Reference_Variant"]
474+
)
475+
df_final.to_csv(output / "pairwise_fitnesses.csv")
476+
420477
# Create a plot
421478
colors = [config.plot.variant_colors[var] for var in variants_investigated]
422479

@@ -507,3 +564,4 @@ def remove_0th(arr):
507564
figure_spec.fig.legend(handles=handles, loc="outside center right", frameon=False)
508565

509566
figure_spec.fig.savefig(output / "figure.pdf")
567+
figure_spec.fig.savefig(output / "figure.png")

0 commit comments

Comments
 (0)