Skip to content

Commit 3700139

Browse files
authored
Merge pull request #40 from usnistgov/39-autosas_model_selection_broke
39 autosas model selection broke
2 parents 0bbb722 + d037fc5 commit 3700139

1 file changed

Lines changed: 152 additions & 74 deletions

File tree

AFL/double_agent/AutoSAS.py

Lines changed: 152 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -956,12 +956,12 @@ class ModelSelectParsimony(PipelineOp):
956956
The variable name for model names in the dataset
957957
sample_dim : str
958958
The dimension containing each sample
959-
cutoff_threshold : float
959+
cutoff : float
960960
The chi-squared threshold for acceptable fits (default: 1.0)
961-
model_complexity : dict[str, int] | None
961+
model_priority : dict[str, int] | None
962962
Dictionary mapping model names to their selection priority (lower value = simpler/preferred model); defaults to ranking by number of free parameters
963963
model_inputs : list[dict[str, Any]] | None
964-
List of model configurations, used to determine complexity if model_complexity is None
964+
List of model configurations, used to determine complexity if model_priority is None
965965
server_id : str | None
966966
Server ID in the format "host:port" for remote execution, or None for local execution
967967
output_prefix : str
@@ -972,8 +972,8 @@ def __init__(
972972
all_chisq_var,
973973
model_names_var,
974974
sample_dim,
975-
cutoff_threshold=1.0,
976-
model_complexity=None,
975+
cutoff=1.0,
976+
model_priority=None,
977977
model_inputs=None, # Added to support local complexity calculation
978978
server_id=None, # Made optional to support local operation
979979
output_prefix='Parsimony',
@@ -993,8 +993,8 @@ def __init__(
993993

994994
self.sample_dim = sample_dim
995995
self.model_names_var = model_names_var
996-
self.cutoff_threshold = cutoff_threshold
997-
self.model_complexity = model_complexity
996+
self.cutoff = cutoff
997+
self.model_priority = model_priority
998998
self.model_inputs = model_inputs
999999
self.all_chisq_var = all_chisq_var
10001000
self.server_id = server_id
@@ -1016,67 +1016,77 @@ def calculate(self, dataset):
10161016
"""Method for selecting the model based on parsimony given a user defined ChiSq threshold """
10171017

10181018
self.dataset = dataset.copy(deep=True)
1019-
1020-
bestChiSq_labels = self.dataset[self.all_chisq_var].argmin(self.model_names_var)
1021-
bestChiSq_label_names = np.array([self.dataset[self.model_names_var][i].values for i in bestChiSq_labels.values])
1022-
1019+
10231020
### default behavior is that complexity is determined by number of free parameters.
10241021
### this is an issue if the number of parameters is the same between models. You bank on them having wildly different ChiSq vals
10251022
### could use a neighbor approach or some more intelligent selection methods
1026-
if self.model_complexity is None:
1027-
print('aggregating complexity')
1023+
10281024

1029-
# Determine model complexity either from server or local model_inputs
1030-
if self.server_id is not None:
1031-
# Get complexity from server
1032-
self.construct_clients()
1033-
aSAS_config = self.AutoSAS_client.get_config('all', interactive=True)['return_val']
1034-
model_inputs = aSAS_config['model_inputs']
1035-
elif self.model_inputs is not None:
1036-
# Use local model_inputs
1037-
model_inputs = self.model_inputs
1038-
else:
1039-
raise ValueError("Either server_id or model_inputs must be provided to calculate model complexity")
1040-
1041-
# Calculate complexity based on number of free parameters
1042-
order = []
1043-
for model in model_inputs:
1044-
n_params = 0
1045-
for p in model['fit_params']:
1046-
if model['fit_params'][p]['bounds'] != None:
1047-
n_params += 1
1048-
order.append(n_params)
1049-
print(order)
1050-
print(np.argsort(order))
1051-
self.model_complexity = np.argsort(order).tolist()
1052-
1053-
# As written in dev full of jank...
1054-
replacement_labels = bestChiSq_labels.copy(deep=True)
1055-
all_chisq = self.dataset[self.all_chisq_var]
1056-
sorted_chisq = all_chisq.sortby(self.model_names_var, ascending=False).values
1057-
1058-
min_diff_chisq = np.array([row[1] - row[0] for row in sorted_chisq])
1059-
next_best_idx = np.array([np.argpartition(row,1)[1] for row in all_chisq])
1060-
1061-
for idx in range(len(replacement_labels)):
1062-
chisq_set = all_chisq.min(dim=self.model_names_var).values
1063-
1064-
if (min_diff_chisq[idx] <= self.cutoff_threshold):
1065-
best_model_index = replacement_labels[idx]
1066-
next_best_index = next_best_idx[idx]
1067-
bm_rank = self.model_complexity.index(best_model_index)
1068-
nbm_rank = self.model_complexity.index(next_best_index)
1069-
1070-
if (bm_rank > nbm_rank):
1071-
replacement_labels[idx] = next_best_index
1025+
# Determine model complexity either from server or local model_inputs
1026+
if self.server_id is not None:
1027+
# Get complexity from server
1028+
self.construct_clients()
1029+
aSAS_config = self.AutoSAS_client.get_config('all', interactive=True)['return_val']
1030+
model_inputs = aSAS_config['model_inputs']
1031+
elif self.model_inputs is not None:
1032+
# Use local model_inputs
1033+
model_inputs = self.model_inputs
1034+
else:
1035+
raise ValueError("Either server_id or model_inputs must be provided to calculate model complexity")
1036+
1037+
# Calculate complexity based on number of free parameters
1038+
model_params = []
1039+
for model in model_inputs:
1040+
n_params = 0
1041+
for p in model['fit_params']:
1042+
if model['fit_params'][p]['bounds'] != None:
1043+
n_params += 1
1044+
model_params.append(n_params)
1045+
1046+
if self.model_priority is None:
1047+
self.model_priority = np.argsort(model_params).tolist()
1048+
1049+
models = self.dataset[self.model_names_var].values #extract models
1050+
1051+
# Sort models by priority
1052+
priority_order = [m for _,m in sorted(zip(self.model_priority,models))]
1053+
1054+
# Sort chi-squared and params accordingly
1055+
sorted_chisq = self.dataset[self.all_chisq_var].sortby(self.model_names_var)
1056+
1057+
# Find best model per sample based on chi-squared
1058+
best_indices = sorted_chisq.argmin(dim=self.model_names_var)
1059+
best_chisq = sorted_chisq.min(dim=self.model_names_var)
1060+
1061+
# Iterate over samples to apply parsimony rule
1062+
selected_indices = []
1063+
for i in range(self.dataset.sizes[self.sample_dim]):
1064+
chisq_values = sorted_chisq.isel(sample=i).values
1065+
min_chisq = best_chisq.isel(sample=i).item()
1066+
1067+
# Find all models within cutoff
1068+
within_cutoff = np.where(chisq_values - min_chisq <= self.cutoff)[0]
10721069

1070+
# Choose the simplest model among them
1071+
simplest_idx = within_cutoff[np.argmin([self.model_priority[i] for i in within_cutoff])]
1072+
1073+
# print(chisq_values)
1074+
# print(chisq_values - min_chisq)
1075+
# print(within_cutoff)
1076+
# print(self.model_priority)
1077+
# print(simplest_idx,'\n')
1078+
selected_indices.append(simplest_idx)
1079+
1080+
selected_indices = np.array(selected_indices)
1081+
1082+
10731083
self.output[self._prefix_output("labels")] = xr.DataArray(
1074-
data=replacement_labels,
1084+
data=selected_indices,
10751085
dims=[self.sample_dim]
10761086
)
10771087

10781088
self.output[self._prefix_output("label_names")] = xr.DataArray(
1079-
data=[self.dataset[self.model_names_var].values[i] for i in replacement_labels],
1089+
data=[priority_order[i] for i in selected_indices],
10801090
dims=[self.sample_dim]
10811091
)
10821092
return self
@@ -1153,9 +1163,6 @@ def calculate(self, dataset):
11531163

11541164
self.dataset = dataset.copy(deep=True)
11551165

1156-
bestChiSq_labels = self.dataset[self.all_chisq_var].argmin(self.model_names_var).values
1157-
bestChiSq_label_names = np.array([self.dataset[self.model_names_var][i].values for i in bestChiSq_labels])
1158-
11591166
# Determine model complexity either from server or local model_inputs
11601167
if self.server_id is not None:
11611168
# Get complexity from server
@@ -1168,18 +1175,18 @@ def calculate(self, dataset):
11681175
else:
11691176
raise ValueError("Either server_id or model_inputs must be provided to calculate model complexity")
11701177

1171-
# Calculate number of parameters for each model
1172-
n = []
1178+
# Calculate number of parameters for each model, d
1179+
d = []
11731180
for model in model_inputs:
11741181
n_params = 0
11751182
for p in model['fit_params']:
11761183
if model['fit_params'][p]['bounds'] != None:
11771184
n_params += 1
1178-
n.append(n_params)
1179-
n = np.array(n)
1185+
d.append(n_params)
1186+
d = np.array(d)
11801187

1181-
### chisq + 2*ln(d) = AIC
1182-
AIC = np.array([2*np.log(i) + 2*n for i in self.dataset[self.all_chisq_var].values])
1188+
### AIC = 2*d + 2*chisq as implemented below (note: the classical AIC form is 2*d + chisq — the extra factor of 2 on chisq should be verified)
1189+
AIC = np.array([2*i + 2*d for i in self.dataset[self.all_chisq_var].values])
11831190

11841191
AIC_labels = np.argmin(AIC, axis=1)
11851192
AIC_label_names = np.array([self.dataset[self.model_names_var][i].values for i in AIC_labels])
@@ -1348,13 +1355,84 @@ def calculate(self, dataset: xr.Dataset) -> Self:
13481355

13491356

13501357

1358+
class ModelSelectMostProbable(PipelineOp):
1359+
"""ModelSelectMostProbable is a pipeline operation for selecting the most probable model.
1360+
1361+
This class selects the model with the highest probability for each sample by
1362+
argmaxing over the 'probabilities' data variable.
1363+
1364+
Attributes
1365+
----------
1366+
model_names_var : str
1367+
The variable name for model names in the dataset
1368+
sample_dim : str
1369+
The dimension containing each sample
1370+
model_inputs : list[dict[str, Any]] | None
1371+
List of model configurations (optional)
1372+
server_id : str | None
1373+
Server ID in the format "host:port" for remote execution, or None for local execution
1374+
output_prefix : str
1375+
Prefix to add to output variable names
1376+
"""
1377+
1378+
def __init__(
1379+
self,
1380+
model_names_var,
1381+
sample_dim,
1382+
output_prefix='MostProbable',
1383+
name="ModelSelection_MostProbable",
1384+
**kwargs
1385+
):
1386+
output_variables = ["labels", "label_names"]
1387+
super().__init__(
1388+
name=name,
1389+
input_variable=[model_names_var],
1390+
output_variable=[
1391+
output_prefix + "_" + o for o in listify(output_variables)
1392+
],
1393+
output_prefix=output_prefix,
1394+
)
1395+
1396+
self.sample_dim = sample_dim
1397+
self.model_names_var = model_names_var
13511398

13521399

1400+
def calculate(self, dataset):
1401+
"""Method for selecting the model with the highest probability for each sample
1402+
1403+
Raises
1404+
------
1405+
ValueError
1406+
If 'probabilities' variable is not present in the dataset
1407+
"""
1408+
1409+
self.dataset = dataset.copy(deep=True)
1410+
1411+
# Check if 'probabilities' variable exists in the dataset
1412+
if 'probabilities' not in self.dataset:
1413+
raise ValueError(
1414+
f"The 'probabilities' variable is required for {self.__class__.__name__}. "
1415+
"Please ensure the dataset contains a 'probabilities' variable with model probabilities."
1416+
)
1417+
1418+
# Determine the most probable model by argmaxing over probabilities
1419+
probabilities = self.dataset['probabilities']
1420+
1421+
# Find the index of the maximum probability for each sample
1422+
most_probable_indices = probabilities.argmax(dim=self.model_names_var)
1423+
print(most_probable_indices)
1424+
# Get the corresponding model names
1425+
model_names = self.dataset[self.model_names_var]
1426+
most_probable_labels = model_names.isel(**{self.model_names_var: most_probable_indices})
1427+
1428+
self.output[self._prefix_output("labels")] = xr.DataArray(
1429+
data=most_probable_indices.values,
1430+
dims=[self.sample_dim]
1431+
)
1432+
1433+
self.output[self._prefix_output("label_names")] = xr.DataArray(
1434+
data=most_probable_labels.values,
1435+
dims=[self.sample_dim]
1436+
)
1437+
return self
13531438

1354-
1355-
1356-
1357-
1358-
1359-
1360-

0 commit comments

Comments
 (0)