@@ -417,6 +417,63 @@ def pprint(message):
417417 covariance = covariance_scaled ,
418418 )
419419
420+ # Output pairwise fitness advantages
421+
422+ def make_relative_growths (theta_star ):
423+ relative_growths = (
424+ qm .get_relative_growths (theta_star , n_variants = n_variants_effective )
425+ - time_scaler .t_min
426+ ) / (time_scaler .t_max - time_scaler .t_min )
427+ relative_growths = jnp .concat ([jnp .array ([0 ]), relative_growths ])
428+ relative_growths = relative_growths * DAYS_IN_A_WEEK
429+
430+ pairwise_diff = jnp .expand_dims (relative_growths , axis = 1 ) - jnp .expand_dims (
431+ relative_growths , axis = 0
432+ )
433+
434+ return pairwise_diff
435+
436+ pairwise_diffs = make_relative_growths (theta_star )
437+ jacob = jax .jacobian (make_relative_growths )(theta_star )
438+ standerr_relgrowths = qm .get_standard_errors (covariance_scaled , jacob )
439+ relgrowths_confint = qm .get_confidence_intervals (
440+ make_relative_growths (theta_star ), standerr_relgrowths , 0.95
441+ )
442+
443+ df_diffs = (
444+ pd .DataFrame (
445+ pairwise_diffs , index = variants_effective , columns = variants_effective
446+ )
447+ .reset_index ()
448+ .melt (id_vars = "index" )
449+ )
450+ df_diffs .columns = ["Variant" , "Reference_Variant" , "Estimate" ]
451+
452+ # Create confidence interval DataFrames
453+ df_lower = (
454+ pd .DataFrame (
455+ relgrowths_confint [0 ], index = variants_effective , columns = variants_effective
456+ )
457+ .reset_index ()
458+ .melt (id_vars = "index" )
459+ )
460+ df_upper = (
461+ pd .DataFrame (
462+ relgrowths_confint [1 ], index = variants_effective , columns = variants_effective
463+ )
464+ .reset_index ()
465+ .melt (id_vars = "index" )
466+ )
467+
468+ df_lower .columns = ["Variant" , "Reference_Variant" , "Lower_CI" ]
469+ df_upper .columns = ["Variant" , "Reference_Variant" , "Upper_CI" ]
470+
471+ # Merge all data
472+ df_final = df_diffs .merge (df_lower , on = ["Variant" , "Reference_Variant" ]).merge (
473+ df_upper , on = ["Variant" , "Reference_Variant" ]
474+ )
475+ df_final .to_csv (output / "pairwise_fitnesses.csv" )
476+
420477 # Create a plot
421478 colors = [config .plot .variant_colors [var ] for var in variants_investigated ]
422479
@@ -507,3 +564,4 @@ def remove_0th(arr):
507564 figure_spec .fig .legend (handles = handles , loc = "outside center right" , frameon = False )
508565
509566 figure_spec .fig .savefig (output / "figure.pdf" )
567+ figure_spec .fig .savefig (output / "figure.png" )
0 commit comments