Skip to content

Commit b0b988a

Browse files
committed
Merge branch 'experiment_performance' into master
2 parents 342638d + ca305f5 commit b0b988a

File tree

6 files changed

+65
-43
lines changed

6 files changed

+65
-43
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ __pycache__/
1010
/tutorial notebooks/experiment.dump
1111

1212
/docs/html
13+
/plots
14+
/xrdfit/test.py
15+
/xrdfit.egg-info

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ To install as a Python module, type
88
from the root directory.
99
For developers, you should install in linked .egg mode using
1010

11-
`python -m pip install . develop`
11+
`python setup.py develop`
1212

1313
If you are using a Python virtual environment, you should activate this first before using the above commands.

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
'dill',
1313
'tqdm',
1414
'scipy',
15-
'lmfit'
15+
'lmfit',
16+
'jupyter',
1617
],
1718
extras_require={"documentation_compilation": "sphinx"}
1819
)

xrdfit/plotting.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pathlib
13
from typing import Tuple, List, Union, TYPE_CHECKING
24

35
import lmfit
@@ -83,7 +85,7 @@ def plot_peak_params(peak_params: List["PeakParams"], x_range: Tuple[float, floa
8385

8486

8587
def plot_peak_fit(data: np.ndarray, cake_numbers: List[int], fit_result: lmfit.model.ModelResult,
86-
fit_name: str):
88+
fit_name: str, timestep: str = None, file_name: str = None):
8789
"""Plot the result of a peak fit as well as the raw data."""
8890
plt.figure(figsize=(8, 6))
8991

@@ -100,9 +102,18 @@ def plot_peak_fit(data: np.ndarray, cake_numbers: List[int], fit_result: lmfit.m
100102
plt.xlabel(r'Two Theta ($^\circ$)')
101103
plt.ylabel('Intensity')
102104
plt.legend()
105+
if timestep:
106+
fit_name = f'Peak "{fit_name}" at t = {timestep}'
103107
plt.title(fit_name)
104108
plt.tight_layout()
105-
plt.show()
109+
if file_name:
110+
file_name = pathlib.Path(file_name)
111+
if not file_name.parent.exists():
112+
os.makedirs(file_name.parent)
113+
plt.savefig(file_name)
114+
else:
115+
plt.show()
116+
plt.close()
106117

107118

108119
def plot_parameter(data: np.ndarray, fit_parameter: str, peak_name: str, show_points: bool):

xrdfit/spectrum_fitting.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,32 @@ def __str__(self) -> str:
6161
return f"PeakParams('{self.name}', {self.peak_bounds})"
6262
return f"PeakParams('{self.name}', {self.peak_bounds}, {self.maxima_bounds})"
6363

64+
def adjust_peak_bounds(self, fit_params: lmfit.Parameters):
65+
"""Adjust peak bounds to re-center the peak in the peak bounds."""
66+
centers = []
67+
for name in fit_params:
68+
if "center" in name:
69+
centers.append(fit_params[name].value)
70+
center = sum(centers) / len(centers)
71+
72+
bound_width = self.peak_bounds[1] - self.peak_bounds[0]
73+
self.peak_bounds = (center - (bound_width / 2), center + (bound_width / 2))
74+
75+
def adjust_maxima_bounds(self, fit_params: lmfit.Parameters):
76+
"""Adjust maxima bounds to re-center the maximum in the maximum bounds."""
77+
peak_centers = []
78+
for param in fit_params:
79+
if "center" in param:
80+
peak_centers.append(fit_params[param].value)
81+
82+
new_maxima_bounds = []
83+
for center, maximum_bounds in zip(peak_centers, self.maxima_bounds):
84+
maximum_bound_width = maximum_bounds[1] - maximum_bounds[0]
85+
lower_bound = center - maximum_bound_width / 2
86+
upper_bound = center + maximum_bound_width / 2
87+
new_maxima_bounds.append((lower_bound, upper_bound))
88+
self.maxima_bounds = new_maxima_bounds
89+
6490

6591
class PeakFit:
6692
"""An object containing data on the fit to a peak.
@@ -74,12 +100,13 @@ def __init__(self, name: str):
74100
self.result: Union[None, lmfit.model.ModelResult] = None
75101
self.cake_numbers: List[int] = []
76102

77-
def plot(self):
103+
def plot(self, timestep: str = None, file_name: str = None):
78104
""" Plot the raw spectral data and the fit."""
79105
if self.raw_spectrum is None:
80106
print("Cannot plot fit peak as fitting has not been done yet.")
81107
else:
82-
plotting.plot_peak_fit(self.raw_spectrum, self.cake_numbers, self.result, self.name)
108+
plotting.plot_peak_fit(self.raw_spectrum, self.cake_numbers, self.result, self.name,
109+
timestep, file_name)
83110

84111

85112
class FitSpectrum:
@@ -129,11 +156,14 @@ def plot(self, cakes_to_plot: Union[int, List[int]], x_range: Tuple[float, float
129156
plotting.plot_spectrum(data, cakes_to_plot, merge_cakes, show_points, x_range)
130157
plt.show()
131158

132-
def plot_fit(self, fit_name: str):
159+
def plot_fit(self, fit_name: str, timestep: str = None, file_name: str = None):
133160
"""Plot the result of a fit.
134-
:param fit_name: The name of the fit to plot."""
161+
:param fit_name: The name of the fit to plot.
162+
:param timestep: If provided, the timestep of the fit which will be added to the title.
163+
:param file_name: If provided, the stub of the file name to write the plot to, if not
164+
provided, the plot will be displayed on screen."""
135165
fit = self.get_fit(fit_name)
136-
fit.plot()
166+
fit.plot(timestep, file_name)
137167

138168
def plot_peak_params(self, peak_params: Union[PeakParams, List[PeakParams]],
139169
cakes_to_plot: Union[int, List[int]],
@@ -319,7 +349,8 @@ def run_analysis(self, reuse_fits=False):
319349
for peak_fit, peak_params in zip(spectral_data.fitted_peaks, self.peak_params):
320350
if reuse_fits:
321351
peak_params.set_previous_fit(peak_fit.result.params)
322-
self.adjust_peak_bounds(spectral_data.fitted_peaks)
352+
peak_params.adjust_maxima_bounds(peak_fit.result.params)
353+
peak_params.adjust_peak_bounds(peak_fit.result.params)
323354

324355
print("Analysis complete.")
325356

@@ -365,22 +396,28 @@ def plot_fit_parameter(self, peak_name: str, fit_parameter: str, show_points=Fal
365396
plotting.plot_parameter(data, fit_parameter, peak_name, show_points)
366397

367398
def plot_fits(self, num_timesteps: int = 5, peak_names: Union[List[str], str] = None,
368-
timesteps: List[int] = None):
399+
timesteps: List[int] = None, file_name: str = None):
369400
"""Plot the calculated fits to the data.
370401
:param num_timesteps: The number of timesteps to plot fits for. The function will plot this
371402
many timesteps, evenly spaced over the whole dataset. This value is ignored if `timesteps`
372403
is specified.
373404
:param peak_names: The name of the peak to fit. If not specified, will plot all fitted
374405
peaks.
375-
:param timesteps: A list of timesteps to plot the fits for."""
406+
:param timesteps: A list of timesteps to plot the fits for.
407+
:param file_name: If provided, outputs the plot to an image file with filename as the image
408+
stub."""
376409

377410
if timesteps is None:
378411
timesteps = self._calculate_timesteps(num_timesteps)
379412
if peak_names is None:
380413
peak_names = self.peak_names()
381414
for timestep in timesteps:
382415
for name in peak_names:
383-
self.timesteps[timestep].plot_fit(name)
416+
if file_name:
417+
output_name = f"../plots/{file_name}_{name}_{timestep :04d}.png"
418+
else:
419+
output_name = None
420+
self.timesteps[timestep].plot_fit(name, str(timestep), output_name)
384421

385422
def _calculate_timesteps(self, num_timesteps: int) -> List[int]:
386423
"""Work out which timesteps to plot."""
@@ -397,18 +434,6 @@ def save(self, file_name: str):
397434
dill.dump(self, output_file)
398435
print("Data successfully saved to dump file.")
399436

400-
def adjust_peak_bounds(self, fitted_peaks: List[PeakFit]):
401-
"""Adjust peak bounds to re-center the peak in the peak bounds."""
402-
for peak_fit, peak_param in zip(fitted_peaks, self.peak_params):
403-
centers = []
404-
for name in peak_fit.result.params:
405-
if "center" in name:
406-
centers.append(peak_fit.result.params[name].value)
407-
center = sum(centers) / len(centers)
408-
409-
bound_width = peak_param.peak_bounds[1] - peak_param.peak_bounds[0]
410-
peak_param.peak_bounds = (center - (bound_width / 2), center + (bound_width / 2))
411-
412437

413438
def get_stacked_spectrum(spectrum: np.ndarray) -> np.ndarray:
414439
"""Take an number of observations from N different cakes and stack them vertically into a 2

xrdfit/test.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)