Skip to content

Commit fae8500

Browse files
committed
Update parmest.py, DoE meeting work
1 parent 7d0c956 commit fae8500

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

pyomo/contrib/parmest/parmest.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,23 @@ def regularize_term(model, prior_FIM, theta_ref):
246246
247247
Added to SSE objective function
248248
"""
249-
expr = ((theta - theta_ref).transpose() * prior_FIM * (theta - theta_ref) for theta in model.unknown_parameters.items())
249+
# Check if prior_FIM is a square matrix
250+
if prior_FIM.shape[0] != prior_FIM.shape[1]:
251+
raise ValueError("prior_FIM must be a square matrix")
252+
253+
# Check if theta_ref is a vector of the same size as prior_FIM
254+
if len(theta_ref) != prior_FIM.shape[0]:
255+
raise ValueError("theta_ref must be a vector of the same size as prior_FIM")
256+
257+
# (theta - theta_ref).transpose() * prior_FIM * (theta - theta_ref)
258+
expr = np.zeros(len(theta_ref))
259+
260+
for i in range(len(theta_ref)):
261+
if theta_ref[i] is None:
262+
raise ValueError("theta_ref must not contain None values")
263+
expr[i] = (model.unknown_parameters[i] - theta_ref[i]).transpose() * prior_FIM[i] * (model.unknown_parameters[i] - theta_ref[i])
264+
return sum(expr)**2
265+
250266
return expr
251267

252268

@@ -449,10 +465,10 @@ def _create_parmest_model(self, experiment_number):
449465

450466
if self.prior_FIM and self.theta_ref is not None:
451467
# Regularize the objective function
452-
second_stage_rule = SSE + regularize_term(prior_FIM = self.prior_FIM, theta_ref = self.theta_ref)
468+
second_stage_rule = SSE + regularize_term(model = self.model_initialized, prior_FIM = self.prior_FIM, theta_ref = self.theta_ref)
453469
elif self.prior_FIM:
454470
theta_ref = model.unknown_parameters.values()
455-
second_stage_rule = SSE + regularize_term(prior_FIM = self.prior_FIM, theta_ref = self.theta_ref)
471+
second_stage_rule = SSE + regularize_term(prior_FIM = self.prior_FIM, theta_ref = theta_ref)
456472

457473
else:
458474
# Sum of squared errors

0 commit comments

Comments
 (0)