Skip to content

Commit 6de334f

Browse files
authored
Merge pull request #162 from cadenmyers13/init-w-recipe
feat: Initialize a fit recipe with a previous recipe
2 parents 72ed60e + 4b60d8a commit 6de334f

File tree

4 files changed

+164
-12
lines changed

4 files changed

+164
-12
lines changed

news/init-w-recipe.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Added initialize_recipe_from_recipe to ``FitRecipe``.
4+
5+
**Changed:**
6+
7+
* <news item>
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

src/diffpy/srfit/fitbase/fitrecipe.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,50 @@ def getBounds2(self):
11401140
"""
11411141
return self.get_bounds_array()
11421142

1143+
def initialize_recipe_with_recipe(self, recipe_object):
1144+
"""Initialize a FitRecipe with another FitRecipe.
1145+
1146+
This is used to initialize a FitRecipe with the contribution(s),
1147+
parameters, constraints and restraints of another FitRecipe.
1148+
If a duplicate contribution, parameter, constraint, or restraint
1149+
is added to the FitRecipe you are initializing, the value from the
1150+
added object will be used.
1151+
1152+
Parameters
1153+
----------
1154+
recipe_object : FitRecipe
1155+
The FitRecipe to initialize with.
1156+
1157+
Raises
1158+
------
1159+
ValueError
1160+
If the object passed is not a FitRecipe.
1161+
"""
1162+
if not isinstance(recipe_object, FitRecipe):
1163+
raise ValueError(
1164+
"The input recipe_object must be a FitRecipe, "
1165+
f"but got {type(recipe_object)}."
1166+
)
1167+
1168+
for contrib_object in recipe_object._contributions.values():
1169+
if contrib_object not in self._contributions.values():
1170+
self.add_contribution(contrib_object)
1171+
1172+
for param_name, param_object in recipe_object._parameters.items():
1173+
if param_name not in self._parameters:
1174+
self._parameters.update({param_name: param_object})
1175+
1176+
for (
1177+
parameter_object,
1178+
constraint_object,
1179+
) in recipe_object._constraints.items():
1180+
if parameter_object not in self._constraints:
1181+
self._constraints.update({parameter_object: constraint_object})
1182+
1183+
for restraint in recipe_object._restraints:
1184+
if restraint not in self._restraints:
1185+
self._restraints.add(restraint)
1186+
11431187
def set_plot_defaults(self, **kwargs):
11441188
"""Set default plotting options for all future plots.
11451189

tests/conftest.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def _capturestdout(f, *args, **kwargs):
146146
return _capturestdout
147147

148148

149-
@pytest.fixture(scope="session")
149+
@pytest.fixture()
150150
def build_recipe_one_contribution():
151151
"helper to build a simple recipe"
152152
profile = Profile()
@@ -164,32 +164,45 @@ def build_recipe_one_contribution():
164164
return recipe
165165

166166

167-
@pytest.fixture(scope="session")
167+
@pytest.fixture()
168168
def build_recipe_two_contributions():
169-
"helper to build a recipe with two contributions"
169+
"""Helper to build a recipe with two physically related contributions."""
170170
profile1 = Profile()
171-
x = linspace(0, pi, 10)
172-
y1 = sin(x)
171+
x = linspace(0, pi, 50)
172+
y1 = sin(x) # amplitude=1, freq=1
173173
profile1.set_observed_profile(x, y1)
174+
174175
contribution1 = FitContribution("c1")
175176
contribution1.set_profile(profile1)
176177
contribution1.set_equation("A*sin(k*x + c)")
177178

178179
profile2 = Profile()
179-
y2 = 0.5 * sin(2 * x)
180+
y2 = 0.5 * sin(2 * x) # amplitude=0.5, freq=2
180181
profile2.set_observed_profile(x, y2)
182+
181183
contribution2 = FitContribution("c2")
182184
contribution2.set_profile(profile2)
183185
contribution2.set_equation("B*sin(m*x + d)")
186+
184187
recipe = FitRecipe()
185188
recipe.add_contribution(contribution1)
186189
recipe.add_contribution(contribution2)
187-
recipe.add_variable(contribution1.A, 1)
188-
recipe.add_variable(contribution1.k, 1)
189-
recipe.add_variable(contribution1.c, 1)
190-
recipe.add_variable(contribution2.B, 0.5)
191-
recipe.add_variable(contribution2.m, 2)
192-
recipe.add_variable(contribution2.d, 0)
190+
191+
# Add variables with reasonable initial guesses
192+
recipe.add_variable(contribution1.A, 0.8)
193+
recipe.add_variable(contribution1.k, 1.0)
194+
recipe.add_variable(contribution1.c, 0.1)
195+
196+
recipe.add_variable(contribution2.B, 0.4)
197+
recipe.add_variable(contribution2.m, 2.0)
198+
recipe.add_variable(contribution2.d, 0.1)
199+
200+
# ---- Meaningful constraints ----
201+
recipe.constrain(contribution2.m, "2*k")
202+
recipe.constrain(contribution2.d, contribution1.c)
203+
recipe.constrain(contribution2.B, "0.5*A")
204+
recipe.restrain(contribution1.A, 0.5, 1.5)
205+
recipe.restrain(contribution1.k, 0.8, 1.2)
193206

194207
return recipe
195208

tests/test_fitrecipe.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818

1919
import matplotlib
2020
import matplotlib.pyplot as plt
21+
import numpy as np
2122
import pytest
2223
from numpy import array_equal, dot, linspace, pi, sin
2324
from scipy.optimize import leastsq
2425

26+
from diffpy.srfit.fitbase import FitResults
2527
from diffpy.srfit.fitbase.fitcontribution import FitContribution
2628
from diffpy.srfit.fitbase.fitrecipe import FitRecipe
2729
from diffpy.srfit.fitbase.parameter import Parameter
@@ -462,6 +464,76 @@ def optimize_recipe(recipe):
462464
leastsq(residuals, values)
463465

464466

467+
def test_initialize_recipe_from_recipe(build_recipe_two_contributions):
468+
# Case: User initializes a FitRecipe from a previously optimized fit
469+
# expected: recipe is initialized with everything:
470+
# contributions, profiles (contained in contributions),
471+
# variables, restraints, and constraints
472+
recipe1 = build_recipe_two_contributions
473+
optimize_recipe(recipe1)
474+
expected_parameters_dict = recipe1._parameters
475+
expected_constraints_dict = recipe1._constraints
476+
expected_restraints_set = recipe1._restraints
477+
expected_contributions_dict = recipe1._contributions
478+
expected_profiles_list = []
479+
for con_name, contribution in expected_contributions_dict.items():
480+
expected_profile = contribution.profile
481+
expected_profiles_list.append(expected_profile)
482+
483+
recipe2 = FitRecipe()
484+
recipe2.initialize_recipe_with_recipe(recipe1)
485+
actual_parameters_dict = recipe2._parameters
486+
actual_constraints_dict = recipe2._constraints
487+
actual_restraints_set = recipe2._restraints
488+
actual_contributions_dict = recipe2._contributions
489+
actual_profiles_list = []
490+
for con_name, contribution in actual_contributions_dict.items():
491+
actual_profile = contribution.profile
492+
actual_profiles_list.append(actual_profile)
493+
494+
assert expected_parameters_dict == actual_parameters_dict
495+
assert expected_constraints_dict == actual_constraints_dict
496+
assert expected_restraints_set == actual_restraints_set
497+
assert expected_contributions_dict == actual_contributions_dict
498+
assert expected_profiles_list == actual_profiles_list
499+
500+
# Check to see if the refined values and variable names are
501+
# the same in the results objects for each recipe
502+
results1 = FitResults(recipe1)
503+
# round to account for small numerical differences
504+
expected_values = np.round(results1.varvals, 7)
505+
expected_names = results1.varnames
506+
507+
optimize_recipe(recipe2)
508+
results2 = FitResults(recipe2)
509+
# round to account for small numerical differences
510+
actual_values = np.round(results2.varvals, 7)
511+
actual_names = results2.varnames
512+
513+
assert sorted(expected_names) == sorted(actual_names)
514+
assert sorted(list(expected_values)) == sorted(list(actual_values))
515+
516+
517+
def test_initialize_recipe_from_recipe_bad(build_recipe_two_contributions):
518+
# Case: User tries to initialize a FitRecipe from a non recipe object
519+
# expected: raised ValueError with message
520+
recipe_bad = 12345 # not a FitRecipe object
521+
recipe2 = FitRecipe()
522+
msg = (
523+
"The input recipe_object must be a FitRecipe, "
524+
"but got <class 'int'>."
525+
)
526+
with pytest.raises(ValueError, match=msg):
527+
recipe2.initialize_recipe_with_recipe(recipe_bad)
528+
529+
530+
# def test_initialize_recipe_from_results(build_recipe_one_contribution):
531+
# # Case: User initializes a FitRecipe from a FitResults object or
532+
# # results file
533+
# # expected: recipe is initialized with variables from previous fit
534+
# assert False
535+
536+
465537
def get_labels_and_linecount(ax):
466538
"""Helper to get line labels and count from a matplotlib Axes."""
467539
labels = [

0 commit comments

Comments
 (0)