Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions news/init-w-results.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* Added ``initialize_recipe_with_results`` to ``FitRecipe``.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
70 changes: 70 additions & 0 deletions src/diffpy/srfit/fitbase/fitrecipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
__all__ = ["FitRecipe"]

from collections import OrderedDict
from pathlib import Path

import matplotlib.pyplot as plt
from bg_mpl_stylesheets.styles import all_styles
from numpy import array, concatenate, dot, sqrt

import diffpy.srfit.util.inpututils as utils
from diffpy.srfit.fitbase.fithook import PrintFitHook
from diffpy.srfit.fitbase.parameter import ParameterProxy
from diffpy.srfit.fitbase.recipeorganizer import RecipeOrganizer
Expand Down Expand Up @@ -1184,6 +1186,74 @@ def initialize_recipe_with_recipe(self, recipe_object):
if restraint not in self._restraints:
self._restraints.add(restraint)

def _pretty_print_results_dict(self, params_dict):
"""Pretty print a dictionary of parameter names and values."""
sorted_params = sorted(params_dict.items())
width = max(len(name) for name, _ in sorted_params)
for name, value in sorted_params:
if isinstance(value, float):
value_str = f"{value:.6g}"
else:
value_str = str(value)
print(f" {name:<{width}} = {value_str}")

def _set_parameters_from_dict(self, params_dict):
"""Set the parameters of the FitRecipe from a dictionary of
parameter names and values."""
for param_name, param_value in params_dict.items():
if param_name in self._parameters:
self._parameters[param_name].setValue(param_value)
else:
print(
f"Warning: Parameter '{param_name}' from results "
"not found in FitRecipe and will be ignored."
)

def initialize_recipe_with_results(self, results, verbose=True):
    """Initialize a FitRecipe with a FitResults object or a results
    file.

    Note that at least one FitContribution must already exist in the
    FitRecipe.  A warning is always printed for any parameter in the
    results that is not found in the FitRecipe; ``verbose`` only
    controls the summary report printed afterwards.

    Parameters
    ----------
    results : FitResults, pathlib.Path, or str
        The FitResults object or path to results file to initialize with.
    verbose : bool, optional
        If True, print a summary of the parameters found in the results
        and the parameter values set in the FitRecipe.  Default is True.

    Raises
    ------
    ValueError
        If the input results is not a FitResults object or a path to a
        results file.
    """
    # Duck-typed check: anything exposing print_results is treated as a
    # FitResults object; str/Path inputs are treated as file paths.
    if hasattr(results, "print_results"):
        params_dict = utils.get_dict_from_results_object(results)
    elif isinstance(results, (str, Path)):
        params_dict = utils.get_dict_from_results_file(results)
    else:
        raise ValueError(
            "The input results must be a FitResults object or a path to a "
            f"results file, but got {type(results)}."
        )
    self._set_parameters_from_dict(params_dict)
    if verbose:
        print()
        print("Parameters found in Results:")
        print("=" * 30)
        self._pretty_print_results_dict(params_dict)
        print()
        print("Parameters set in FitRecipe:")
        print("=" * 30)
        set_parameters_dict = {
            param.name: param.getValue()
            for param in self._parameters.values()
        }
        self._pretty_print_results_dict(set_parameters_dict)
def set_plot_defaults(self, **kwargs):
"""Set default plotting options for all future plots.

Expand Down
62 changes: 62 additions & 0 deletions src/diffpy/srfit/util/inpututils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__all__ = ["inputToString"]

import os.path
from pathlib import Path


def inputToString(input):
Expand Down Expand Up @@ -51,4 +52,65 @@ def inputToString(input):
return inptstr


def get_dict_from_results_file(
results_filepath: Path | str,
) -> dict[str, float]:
"""Get a dictionary of parameter names and values from a results
file.

The file should have lines in the format:
"parameter_name value +/- uncertainty". Lines that do not match this
format will be ignored.

Parameters
----------
results_filepath : pathlib.Path or str
The path to the results file.

Returns
-------
parsed_results_dict : dict
The dictionary where keys are parameter names and values are the
corresponding parameter values as floats.
"""
with open(results_filepath, "r") as f:
results_string = f.read()
parsed_results_dict = {}
for raw_line in results_string.splitlines():
line = raw_line.strip()
# skip blank lines and lines that are just dashes
if not line or set(line) == {"-"}:
continue
line_items = line.split()
if len(line_items) < 2:
continue
if len(line_items) >= 4 and line_items[2] == "+/-":
try:
parsed_results_dict[line_items[0]] = float(line_items[1])
except ValueError:
pass
return parsed_results_dict


def get_dict_from_results_object(results_object) -> dict[str, float]:
    """Get a dictionary of parameter names and values from a FitResults
    object.

    Parameters
    ----------
    results_object : FitResults
        The FitResults object containing the parameter names and values.

    Returns
    -------
    params_dict : dict
        The dictionary where keys are parameter names and values are the
        corresponding parameter values as floats.
    """
    # Pair each variable name with its refined value.
    return {
        name: value
        for name, value in zip(
            results_object.varnames, results_object.varvals
        )
    }


# End of file
30 changes: 17 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,23 @@ def _capturestdout(f, *args, **kwargs):
@pytest.fixture()
def build_recipe_one_contribution():
    """Return a factory that builds a fresh simple recipe on each call.

    A factory (rather than a recipe instance) is returned so that a
    single test can create several independent recipes; refining one
    recipe does not alter the starting values of the next.
    """

    def _build_recipe():
        # One-contribution recipe fitting a damped-free sine to y = sin(x).
        profile = Profile()
        x = linspace(0, pi, 10)
        y = sin(x)
        profile.set_observed_profile(x, y)
        contribution = FitContribution("c1")
        contribution.set_profile(profile)
        contribution.set_equation(
            "amplitude*sin(wave_number*x + phase_shift)"
        )
        recipe = FitRecipe()
        recipe.add_contribution(contribution)
        # Deliberately off-target starting values so the refinement
        # has visible work to do.
        recipe.add_variable(contribution.amplitude, 4)
        recipe.add_variable(contribution.wave_number, 3)
        recipe.add_variable(contribution.phase_shift, 2)
        return recipe

    return _build_recipe

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to wrap this in a second function because I realized calling the same fixture twice after the first one was refined led to initial values in the second call being the values of the previously refined recipe

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can simply change the scope of the fixture to remove this behavior

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbillinge Unfortunately that doesn't work here. It was already set to scope="function" which is the lowest level so to speak.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, but I am confused. in this case it would reset every time through a pytest.mark.parametrize Are we not initializing something correctly? I am only banging on about this because it maybe suggests something may be wrong with our tests which would not be good.

Copy link
Contributor Author

@cadenmyers13 cadenmyers13 Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbillinge In the current version, the test fixture is not being called twice when you do,

recipe1 = build_recipe_one_contribution
recipe2 = build_recipe_one_contribution

What is happening here is that it is assigning the same fixture value to two variable names, so recipe1 == recipe2 would return True even if one was refined and the other wasn't.

When we wrap it in another function like in the incoming version, each recipe object can be created independently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be indicating a weakness with our code, either the test or the code itself. Is it fixed if you instantiate recipe1 and recipe2 both at the top of the testing function?

Copy link
Contributor Author

@cadenmyers13 cadenmyers13 Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbillinge I tested this without the conftest fixture and built the recipes manually like so,

def test_initialize_recipe_from_results_object():
    # Case: User initializes a FitRecipe from a FitResults object
    # expected: recipe is initialized with variables from previous fit
    profile1 = Profile()
    x = linspace(0, pi, 10)
    y = sin(x)
    profile1.set_observed_profile(x, y)
    contribution1 = FitContribution("c1")
    contribution1.set_profile(profile1)
    contribution1.set_equation("amplitude*sin(wave_number*x + phase_shift)")
    recipe1 = FitRecipe()
    recipe1.add_contribution(contribution1)
    recipe1.add_variable(contribution1.amplitude, 4)
    recipe1.add_variable(contribution1.wave_number, 3)
    recipe1.add_variable(contribution1.phase_shift, 2)
    optimize_recipe(recipe1)
    results1 = FitResults(recipe1)
    expected_values = np.round(results1.varvals, 5)
    expected_names = results1.varnames

    profile2 = Profile()
    x = linspace(0, pi, 10)
    y = sin(x)
    profile2.set_observed_profile(x, y)
    contribution2 = FitContribution("c2")
    contribution2.set_profile(profile2)
    contribution2.set_equation("amplitude*sin(wave_number*x + phase_shift)")
    recipe2 = FitRecipe()
    recipe2.add_contribution(contribution2)
    recipe2.add_variable(contribution2.amplitude, 4)
    recipe2.add_variable(contribution2.wave_number, 3)
    recipe2.add_variable(contribution2.phase_shift, 2)
    recipe2.create_new_variable(
        "extra_var", 5
    )  # should be included in the initialized recipe
    actual_values_before_init = [val for val in recipe2.get_values()]
    actual_names_before_init = recipe2.get_names()
    expected_names_before_init = [
        "amplitude",
        "extra_var",
        "phase_shift",
        "wave_number",
    ]
    expected_values_before_init = [
        4,
        3,
        2,
        5,
    ]  # the three variables + the extra_var

    assert actual_values_before_init == expected_values_before_init
    assert sorted(actual_names_before_init) == sorted(
        expected_names_before_init
    )

    recipe2.initialize_recipe_with_results(results1)
    optimize_recipe(recipe2)
    results2 = FitResults(recipe2)
    actual_values = np.round(results2.varvals, 5)
    actual_names = results2.varnames

    expected_names = expected_names + [
        "extra_var"
    ]  # add the new variable name to expected names
    expected_values = list(expected_values) + [
        5
    ]  # add the value of the new variable to expected values
    assert sorted(expected_names) == sorted(actual_names)
    assert sorted(expected_values) == sorted(list(actual_values))

Doing this passes the test, meaning it's a fixture-related issue and not a code-related one. What we could do is write it like this for this specific test (and I think one other) and revert the conftest fixture back to the original.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it fixed if you instantiate recipe1 and recipe2 both at the top of the testing function?

@sbillinge and no, this doesn't fix it

return _build_recipe


@pytest.fixture()
Expand Down
Loading
Loading