3737
3838
3939PARAMETER_NAMES_FOR_PLOTTING = {
40- "coef_lasso_shift " : "Lasso Penalty" ,
40+ "scale_coeff_lasso_shift " : "Lasso Penalty" ,
4141}
4242
4343
@@ -350,7 +350,7 @@ def all_mutations(self) -> tuple:
350350 @lru_cache (maxsize = 10 )
351351 def split_apply_combine_muts (
352352 self ,
353- groupby = ("dataset_name" , "coef_lasso_shift " ),
353+ groupby = ("dataset_name" , "scale_coeff_lasso_shift " ),
354354 aggregate_func = "mean" ,
355355 inner_merge_dataset_muts = True ,
356356 query = None ,
@@ -374,7 +374,7 @@ def split_apply_combine_muts(
374374 groupby : str or tuple of str or None, optional
375375 The attributes to group the fits by. If None, then group by all
376376 attributes except for the model, data, and step_loss attributes.
377- The default is ("dataset_name", "coef_lasso_shift ").
377+ The default is ("dataset_name", "scale_coeff_lasso_shift ").
378378 aggregate_func : str or callable, optional
379379 The function to aggregate the mutational dataframes within each group.
380380 The default is "mean".
@@ -524,7 +524,7 @@ def add_validation_loss(self, test_data, overwrite=False):
524524 def get_conditional_loss_df (self , query = None ):
525525 """
526526 Return a long form dataframe with columns
527- "dataset_name", "coef_lasso_shift ",
527+ "dataset_name", "scale_coeff_lasso_shift ",
528528 "split" ("training" or "validation"),
529529 "loss" (actual value), and "condition".
530530
@@ -541,7 +541,7 @@ def get_conditional_loss_df(self, query=None):
541541 if len (queried_fits ) == 0 :
542542 raise ValueError ("invalid query, no fits returned" )
543543
544- id_vars = ["dataset_name" , "coef_lasso_shift " ]
544+ id_vars = ["dataset_name" , "scale_coeff_lasso_shift " ]
545545 value_vars = [
546546 c for c in queried_fits .columns if "loss" in c and c != "step_loss"
547547 ]
@@ -559,7 +559,7 @@ def get_conditional_loss_df(self, query=None):
559559 def convergence_trajectory_df (
560560 self ,
561561 query = None ,
562- id_vars = ("dataset_name" , "coef_lasso_shift " ),
562+ id_vars = ("dataset_name" , "scale_coeff_lasso_shift " ),
563563 ):
564564 """
565565 Combine the converence trajectory dataframes of
@@ -776,7 +776,7 @@ def mut_param_traceplot(
776776 self ,
777777 mutations ,
778778 mut_param = "shift" ,
779- x = "coef_lasso_shift " ,
779+ x = "scale_coeff_lasso_shift " ,
780780 width_scalar = 100 ,
781781 height_scalar = 100 ,
782782 ** kwargs ,
@@ -821,7 +821,7 @@ def mut_param_traceplot(
821821 muts_df = muts_df .query ("mutation.isin(@mutations)" )
822822
823823 # check that we have multiple lasso penalty weights
824- if len (muts_df .coef_lasso_shift .unique ()) <= 1 :
824+ if len (muts_df .scale_coeff_lasso_shift .unique ()) <= 1 :
825825 raise ValueError (
826826 "invalid kwargs, must specify a subset of fits with "
827827 "multiple lasso penalty weights"
@@ -834,7 +834,7 @@ def mut_type(mut):
834834 muts_df = muts_df .assign (mut_type = muts_df .mutation .apply (mut_type ))
835835
836836 # melt conditions and stats cols, beta is already "tall"
837- # id_cols = ["coef_lasso_shift ", "mutation", "is_stop"]
837+ # id_cols = ["scale_coeff_lasso_shift ", "mutation", "is_stop"]
838838 id_cols = ["dataset_name" , x , "mut_type" , "mutation" ]
839839 stat_cols_to_keep = [c for c in muts_df .columns if c .startswith (mut_param )]
840840 if mut_param == "beta" :
@@ -891,7 +891,7 @@ def mut_type(mut):
891891
892892 def shift_sparsity (
893893 self ,
894- x = "coef_lasso_shift " ,
894+ x = "scale_coeff_lasso_shift " ,
895895 width_scalar = 100 ,
896896 height_scalar = 100 ,
897897 return_data = False ,
@@ -952,7 +952,7 @@ def mut_type(mut):
952952 alt .Chart (sparsity_df )
953953 .encode (
954954 x = alt .X (
955- "coef_lasso_shift " ,
955+ "scale_coeff_lasso_shift " ,
956956 type = "nominal" ,
957957 title = (
958958 PARAMETER_NAMES_FOR_PLOTTING [x ]
@@ -967,7 +967,7 @@ def mut_type(mut):
967967 ),
968968 color = alt .Color ("mut_type" , type = "nominal" , title = "Mutation type" ),
969969 tooltip = [
970- "coef_lasso_shift " ,
970+ "scale_coeff_lasso_shift " ,
971971 "sparsity" ,
972972 "mut_type" ,
973973 ],
@@ -995,7 +995,7 @@ def mut_type(mut):
995995
996996 def mut_param_dataset_correlation (
997997 self ,
998- x = "coef_lasso_shift " ,
998+ x = "scale_coeff_lasso_shift " ,
999999 width_scalar = 200 ,
10001000 height = 200 ,
10011001 return_data = False ,
@@ -1012,7 +1012,7 @@ def mut_param_dataset_correlation(
10121012 ----------
10131013 x : str, optional
10141014 The parameter to plot on the x-axis.
1015- The default is "coef_lasso_shift ".
1015+ The default is "scale_coeff_lasso_shift ".
10161016 width_scalar : int, optional
10171017 The width of the chart. The default is 150.
10181018 height : int, optional
0 commit comments