|
2 | 2 | import unittest |
3 | 3 | from util import BurnManTest |
4 | 4 | import numpy as np |
| 5 | +from numpy import random |
| 6 | +import matplotlib.pyplot as plt |
5 | 7 |
|
6 | 8 | import burnman |
| 9 | +from burnman.optimize.eos_fitting import fit_XPTp_data |
7 | 10 | from burnman.optimize.nonlinear_fitting import nonlinear_least_squares_fit |
| 11 | +from burnman.utils.misc import attribute_function, pretty_string_values |
8 | 12 |
|
9 | 13 | path = os.path.dirname(os.path.abspath(__file__)) |
10 | 14 |
|
@@ -141,8 +145,143 @@ def test_fit_bounded_PVT_data(self): |
141 | 145 | fo, params, PTV, bounds=bounds, verbose=False |
142 | 146 | ) |
143 | 147 |
|
| 148 | + cp_bands = burnman.nonlinear_fitting.confidence_prediction_bands( |
| 149 | + model=fitted_eos, |
| 150 | + x_array=PTV, |
| 151 | + confidence_interval=0.95, |
| 152 | + f=attribute_function(fo, "V"), |
| 153 | + flag="V", |
| 154 | + ) |
| 155 | + self.assertEqual(len(cp_bands[0]), len(PTV)) |
| 156 | + self.assertEqual(len(cp_bands), 4) |
| 157 | + |
144 | 158 | self.assertFloatEqual(3.0, fitted_eos.popt[2]) |
145 | 159 |
|
| 160 | + s = pretty_string_values( |
| 161 | + fitted_eos.popt, |
| 162 | + fitted_eos.pcov, |
| 163 | + extra_decimal_places=1, |
| 164 | + combine_value_and_sigma=False, |
| 165 | + ) |
| 166 | + |
| 167 | + self.assertEqual(len(s), 3) |
| 168 | + self.assertEqual(len(s[0]), 3) |
| 169 | + self.assertEqual(len(s[1]), 3) |
| 170 | + self.assertEqual(len(s[2]), 3) |
| 171 | + |
| 172 | + s = pretty_string_values( |
| 173 | + fitted_eos.popt, |
| 174 | + fitted_eos.pcov, |
| 175 | + extra_decimal_places=1, |
| 176 | + combine_value_and_sigma=True, |
| 177 | + ) |
| 178 | + |
| 179 | + self.assertEqual(len(s), 3) |
| 180 | + self.assertEqual(len(s[0]), 3) |
| 181 | + self.assertEqual(len(s[1]), 3) |
| 182 | + self.assertEqual(len(s[2]), 3) |
| 183 | + |
| 184 | + def test_bounded_solution_fitting(self): |
| 185 | + solution = burnman.minerals.SLB_2011.mg_fe_olivine() |
| 186 | + solution.set_state(1.0e5, 300.0) |
| 187 | + fit_params = [["V_0", 0], ["V_0", 1], ["V", 0, 1]] |
| 188 | + |
| 189 | + n_data = 5 |
| 190 | + data = [] |
| 191 | + data_covariances = [] |
| 192 | + flags = [] |
| 193 | + |
| 194 | + f_Verror = 1.0e-3 |
| 195 | + |
| 196 | + # Choose a specific seed so that the test is reproducible. |
| 197 | + random.seed(10) |
| 198 | + for i in range(n_data): |
| 199 | + x_fa = random.random() |
| 200 | + P = random.random() * 1.0e10 |
| 201 | + T = random.random() * 1000.0 + 300.0 |
| 202 | + X = [1.0 - x_fa, x_fa] |
| 203 | + solution.set_composition(X) |
| 204 | + solution.set_state(P, T) |
| 205 | + f = 1.0 + (random.normal() - 0.5) * f_Verror |
| 206 | + V = solution.V * f |
| 207 | + |
| 208 | + data.append([1.0 - x_fa, x_fa, P, T, V]) |
| 209 | + data_covariances.append(np.zeros((5, 5))) |
| 210 | + data_covariances[-1][4, 4] = np.power(solution.V * f_Verror, 2.0) |
| 211 | + |
| 212 | + flags = ["V"] * 5 |
| 213 | + |
| 214 | + n_data = 2 |
| 215 | + f_Vperror = 1.0e-2 |
| 216 | + |
| 217 | + for i in range(n_data): |
| 218 | + x_fa = random.random() |
| 219 | + P = random.random() * 1.0e10 |
| 220 | + T = random.random() * 1000.0 + 300.0 |
| 221 | + X = [1.0 - x_fa, x_fa] |
| 222 | + solution.set_composition(X) |
| 223 | + solution.set_state(P, T) |
| 224 | + f = 1.0 + (random.normal() - 0.5) * f_Vperror |
| 225 | + Vp = solution.p_wave_velocity * f |
| 226 | + |
| 227 | + data.append([1.0 - x_fa, x_fa, P, T, Vp]) |
| 228 | + data_covariances.append(np.zeros((5, 5))) |
| 229 | + data_covariances[-1][4, 4] = np.power( |
| 230 | + solution.p_wave_velocity * f_Vperror, 2.0 |
| 231 | + ) |
| 232 | + flags.append("p_wave_velocity") |
| 233 | + |
| 234 | + data = np.array(data) |
| 235 | + data_covariances = np.array(data_covariances) |
| 236 | + flags = np.array(flags) |
| 237 | + delta_params = np.array([1.0e-8, 1.0e-8, 1.0e-8]) |
| 238 | + bounds = np.array([[0, np.inf], [0, np.inf], [-np.inf, np.inf]]) |
| 239 | + |
| 240 | + fitted_eos = fit_XPTp_data( |
| 241 | + solution=solution, |
| 242 | + flags=flags, |
| 243 | + fit_params=fit_params, |
| 244 | + data=data, |
| 245 | + data_covariances=data_covariances, |
| 246 | + delta_params=delta_params, |
| 247 | + bounds=bounds, |
| 248 | + param_tolerance=1.0e-5, |
| 249 | + verbose=False, |
| 250 | + ) |
| 251 | + |
| 252 | + self.assertEqual(len(fitted_eos.popt), 3) |
| 253 | + |
| 254 | + cp_bands = burnman.nonlinear_fitting.confidence_prediction_bands( |
| 255 | + model=fitted_eos, |
| 256 | + x_array=data, |
| 257 | + confidence_interval=0.95, |
| 258 | + f=attribute_function(solution, "V"), |
| 259 | + flag="V", |
| 260 | + ) |
| 261 | + self.assertEqual(len(cp_bands[0]), len(data)) |
| 262 | + self.assertEqual(len(cp_bands), 4) |
| 263 | + |
| 264 | + good_data_confidence_interval = 0.9 |
| 265 | + _, indices, probabilities = burnman.nonlinear_fitting.extreme_values( |
| 266 | + fitted_eos.weighted_residuals, good_data_confidence_interval |
| 267 | + ) |
| 268 | + self.assertEqual(len(indices), 0) |
| 269 | + self.assertEqual(len(probabilities), 0) |
| 270 | + |
| 271 | + # Just check plotting doesn't return an error |
| 272 | + fig, ax = plt.subplots() |
| 273 | + burnman.nonlinear_fitting.plot_residuals( |
| 274 | + ax=ax, |
| 275 | + weighted_residuals=fitted_eos.weighted_residuals, |
| 276 | + flags=fitted_eos.flags, |
| 277 | + ) |
| 278 | + fig, ax = plt.subplots() |
| 279 | + burnman.nonlinear_fitting.weighted_residual_plot(ax, fitted_eos) |
| 280 | + |
| 281 | + fig, ax = burnman.nonlinear_fitting.corner_plot( |
| 282 | + fitted_eos.popt, fitted_eos.pcov |
| 283 | + ) |
| 284 | + |
146 | 285 |
|
147 | 286 | if __name__ == "__main__": |
148 | 287 | unittest.main() |
0 commit comments