@@ -956,12 +956,12 @@ class ModelSelectParsimony(PipelineOp):
956956 The variable name for model names in the dataset
957957 sample_dim : str
958958 The dimension containing each sample
959- cutoff_threshold : float
959+ cutoff : float
960960 Maximum allowed chi-squared *difference* from the per-sample best fit for a model to still be considered (default: 1.0)
961- model_complexity : dict[str, int] | None
961+ model_priority : dict[str, int] | None
962962 Per-model priority rank, lower = preferred/simpler. NOTE(review): documented as dict[str, int] but the code indexes it as a list aligned with model order — confirm the intended type
963963 model_inputs : list[dict[str, Any]] | None
964- List of model configurations, used to determine complexity if model_complexity is None
964+ List of model configurations, used to determine complexity if model_priority is None
965965 server_id : str | None
966966 Server ID in the format "host:port" for remote execution, or None for local execution
967967 output_prefix : str
@@ -972,8 +972,8 @@ def __init__(
972972 all_chisq_var ,
973973 model_names_var ,
974974 sample_dim ,
975- cutoff_threshold = 1.0 ,
976- model_complexity = None ,
975+ cutoff = 1.0 ,
976+ model_priority = None ,
977977 model_inputs = None , # Added to support local complexity calculation
978978 server_id = None , # Made optional to support local operation
979979 output_prefix = 'Parsimony' ,
@@ -993,8 +993,8 @@ def __init__(
993993
994994 self .sample_dim = sample_dim
995995 self .model_names_var = model_names_var
996- self .cutoff_threshold = cutoff_threshold
997- self .model_complexity = model_complexity
996+ self .cutoff = cutoff
997+ self .model_priority = model_priority
998998 self .model_inputs = model_inputs
999999 self .all_chisq_var = all_chisq_var
10001000 self .server_id = server_id
@@ -1016,67 +1016,77 @@ def calculate(self, dataset):
10161016 """Method for selecting the model based on parsimony given a user defined ChiSq threshold """
10171017
10181018 self .dataset = dataset .copy (deep = True )
1019-
1020- bestChiSq_labels = self .dataset [self .all_chisq_var ].argmin (self .model_names_var )
1021- bestChiSq_label_names = np .array ([self .dataset [self .model_names_var ][i ].values for i in bestChiSq_labels .values ])
1022-
1019+
10231020 ### default behavior is that complexity is determined by number of free parameters.
10241021 ### this is an issue if the number of parameters is the same between models. You bank on them having wildly different ChiSq vals
10251022 ### could use a neighbor approach or some more intelligent selection methods
1026- if self .model_complexity is None :
1027- print ('aggregating complexity' )
1023+
10281024
1029- # Determine model complexity either from server or local model_inputs
1030- if self .server_id is not None :
1031- # Get complexity from server
1032- self .construct_clients ()
1033- aSAS_config = self .AutoSAS_client .get_config ('all' , interactive = True )['return_val' ]
1034- model_inputs = aSAS_config ['model_inputs' ]
1035- elif self .model_inputs is not None :
1036- # Use local model_inputs
1037- model_inputs = self .model_inputs
1038- else :
1039- raise ValueError ("Either server_id or model_inputs must be provided to calculate model complexity" )
1040-
1041- # Calculate complexity based on number of free parameters
1042- order = []
1043- for model in model_inputs :
1044- n_params = 0
1045- for p in model ['fit_params' ]:
1046- if model ['fit_params' ][p ]['bounds' ] != None :
1047- n_params += 1
1048- order .append (n_params )
1049- print (order )
1050- print (np .argsort (order ))
1051- self .model_complexity = np .argsort (order ).tolist ()
1052-
1053- # As written in dev full of jank...
1054- replacement_labels = bestChiSq_labels .copy (deep = True )
1055- all_chisq = self .dataset [self .all_chisq_var ]
1056- sorted_chisq = all_chisq .sortby (self .model_names_var , ascending = False ).values
1057-
1058- min_diff_chisq = np .array ([row [1 ] - row [0 ] for row in sorted_chisq ])
1059- next_best_idx = np .array ([np .argpartition (row ,1 )[1 ] for row in all_chisq ])
1060-
1061- for idx in range (len (replacement_labels )):
1062- chisq_set = all_chisq .min (dim = self .model_names_var ).values
1063-
1064- if (min_diff_chisq [idx ] <= self .cutoff_threshold ):
1065- best_model_index = replacement_labels [idx ]
1066- next_best_index = next_best_idx [idx ]
1067- bm_rank = self .model_complexity .index (best_model_index )
1068- nbm_rank = self .model_complexity .index (next_best_index )
1069-
1070- if (bm_rank > nbm_rank ):
1071- replacement_labels [idx ] = next_best_index
1025+ # Determine model complexity either from server or local model_inputs
1026+ if self .server_id is not None :
1027+ # Get complexity from server
1028+ self .construct_clients ()
1029+ aSAS_config = self .AutoSAS_client .get_config ('all' , interactive = True )['return_val' ]
1030+ model_inputs = aSAS_config ['model_inputs' ]
1031+ elif self .model_inputs is not None :
1032+ # Use local model_inputs
1033+ model_inputs = self .model_inputs
1034+ else :
1035+ raise ValueError ("Either server_id or model_inputs must be provided to calculate model complexity" )
1036+
1037+ # Calculate complexity based on number of free parameters
1038+ model_params = []
1039+ for model in model_inputs :
1040+ n_params = 0
1041+ for p in model ['fit_params' ]:
1042+ if model ['fit_params' ][p ]['bounds' ] != None :
1043+ n_params += 1
1044+ model_params .append (n_params )
1045+
1046+ if self .model_priority is None :
1047+ self .model_priority = np .argsort (model_params ).tolist ()
1048+
1049+ models = self .dataset [self .model_names_var ].values #extract models
1050+
1051+ # Sort models by priority
1052+ priority_order = [m for _ ,m in sorted (zip (self .model_priority ,models ))]
1053+
1054+ # Sort chi-squared and params accordingly
1055+ sorted_chisq = self .dataset [self .all_chisq_var ].sortby (self .model_names_var )
1056+
1057+ # Find best model per sample based on chi-squared
1058+ best_indices = sorted_chisq .argmin (dim = self .model_names_var )
1059+ best_chisq = sorted_chisq .min (dim = self .model_names_var )
1060+
1061+ # Iterate over samples to apply parsimony rule
1062+ selected_indices = []
1063+ for i in range (self .dataset .sizes [self .sample_dim ]):
1064+ chisq_values = sorted_chisq .isel (sample = i ).values
1065+ min_chisq = best_chisq .isel (sample = i ).item ()
1066+
1067+ # Find all models within cutoff
1068+ within_cutoff = np .where (chisq_values - min_chisq <= self .cutoff )[0 ]
10721069
1070+ # Choose the simplest model among them
1071+ simplest_idx = within_cutoff [np .argmin ([self .model_priority [i ] for i in within_cutoff ])]
1072+
1073+ # print(chisq_values)
1074+ # print(chisq_values - min_chisq)
1075+ # print(within_cutoff)
1076+ # print(self.model_priority)
1077+ # print(simplest_idx,'\n')
1078+ selected_indices .append (simplest_idx )
1079+
1080+ selected_indices = np .array (selected_indices )
1081+
1082+
10731083 self .output [self ._prefix_output ("labels" )] = xr .DataArray (
1074- data = replacement_labels ,
1084+ data = selected_indices ,
10751085 dims = [self .sample_dim ]
10761086 )
10771087
10781088 self .output [self ._prefix_output ("label_names" )] = xr .DataArray (
1079- data = [self . dataset [ self . model_names_var ]. values [ i ] for i in replacement_labels ],
1089+ data = [priority_order [ i ] for i in selected_indices ],
10801090 dims = [self .sample_dim ]
10811091 )
10821092 return self
@@ -1153,9 +1163,6 @@ def calculate(self, dataset):
11531163
11541164 self .dataset = dataset .copy (deep = True )
11551165
1156- bestChiSq_labels = self .dataset [self .all_chisq_var ].argmin (self .model_names_var ).values
1157- bestChiSq_label_names = np .array ([self .dataset [self .model_names_var ][i ].values for i in bestChiSq_labels ])
1158-
11591166 # Determine model complexity either from server or local model_inputs
11601167 if self .server_id is not None :
11611168 # Get complexity from server
@@ -1168,18 +1175,18 @@ def calculate(self, dataset):
11681175 else :
11691176 raise ValueError ("Either server_id or model_inputs must be provided to calculate model complexity" )
11701177
1171- # Calculate number of parameters for each model
1172- n = []
1178+ # Calculate number of parameters for each model, d
1179+ d = []
11731180 for model in model_inputs :
11741181 n_params = 0
11751182 for p in model ['fit_params' ]:
11761183 if model ['fit_params' ][p ]['bounds' ] != None :
11771184 n_params += 1
1178- n .append (n_params )
1179- n = np .array (n )
1185+ d .append (n_params )
1186+ d = np .array (d )
11801187
1181- ### chisq + 2*ln(d) = AIC
1182- AIC = np .array ([2 * np . log ( i ) + 2 * n for i in self .dataset [self .all_chisq_var ].values ])
1188+ ### NOTE(review): standard AIC is chisq + 2*d, but the line below computes 2*chisq + 2*d, whose argmin matches chisq + d (half the usual complexity penalty) — confirm the intended penalty weight
1189+ AIC = np .array ([2 * i + 2 * d for i in self .dataset [self .all_chisq_var ].values ])
11831190
11841191 AIC_labels = np .argmin (AIC , axis = 1 )
11851192 AIC_label_names = np .array ([self .dataset [self .model_names_var ][i ].values for i in AIC_labels ])
@@ -1348,13 +1355,84 @@ def calculate(self, dataset: xr.Dataset) -> Self:
13481355
13491356
13501357
class ModelSelectMostProbable(PipelineOp):
    """ModelSelectMostProbable is a pipeline operation for selecting the most probable model.

    Selects, for each sample, the model with the highest probability by
    argmaxing over the dataset's 'probabilities' data variable.

    Attributes
    ----------
    model_names_var : str
        The variable name for model names in the dataset
    sample_dim : str
        The dimension containing each sample
    output_prefix : str
        Prefix to add to output variable names
    """

    def __init__(
        self,
        model_names_var,
        sample_dim,
        output_prefix='MostProbable',
        name="ModelSelection_MostProbable",
        **kwargs
    ):
        output_variables = ["labels", "label_names"]
        super().__init__(
            name=name,
            input_variable=[model_names_var],
            output_variable=[
                output_prefix + "_" + o for o in listify(output_variables)
            ],
            output_prefix=output_prefix,
        )

        self.sample_dim = sample_dim
        self.model_names_var = model_names_var

    def calculate(self, dataset):
        """Select the model with the highest probability for each sample.

        Parameters
        ----------
        dataset : xr.Dataset
            Must contain a 'probabilities' variable with a per-sample
            probability for each model, plus the model-names variable.

        Returns
        -------
        Self
            ``<prefix>_labels`` holds the winning model index per sample;
            ``<prefix>_label_names`` holds the corresponding model names.

        Raises
        ------
        ValueError
            If the 'probabilities' variable is not present in the dataset.
        """
        self.dataset = dataset.copy(deep=True)

        # Fail early with a clear message rather than a KeyError below.
        if 'probabilities' not in self.dataset:
            raise ValueError(
                f"The 'probabilities' variable is required for {self.__class__.__name__}. "
                "Please ensure the dataset contains a 'probabilities' variable with model probabilities."
            )

        probabilities = self.dataset['probabilities']

        # Index of the highest-probability model for each sample.
        most_probable_indices = probabilities.argmax(dim=self.model_names_var)

        # Map the winning indices back to model names.
        model_names = self.dataset[self.model_names_var]
        most_probable_labels = model_names.isel(**{self.model_names_var: most_probable_indices})

        self.output[self._prefix_output("labels")] = xr.DataArray(
            data=most_probable_indices.values,
            dims=[self.sample_dim]
        )

        self.output[self._prefix_output("label_names")] = xr.DataArray(
            data=most_probable_labels.values,
            dims=[self.sample_dim]
        )
        return self
13531438
1354-
1355-
1356-
1357-
1358-
1359-
1360-
0 commit comments