|
3 | 3 | TODO implementation of alternative methods: |
4 | 4 | - [ ] FAST |
5 | 5 | - [ ] Morris |
6 | | - - [ ] Sampling based methods (distribution) |
7 | 6 | """ |
8 | 7 | import time |
9 | 8 | import multiprocessing |
10 | | -from typing import Optional |
| 9 | +from typing import Optional, Any |
11 | 10 | from pathlib import Path |
12 | | -from dataclasses import dataclass |
13 | 11 | from rich.progress import track |
14 | 12 | from pymetadata.console import console |
15 | 13 |
|
16 | 14 | import numpy as np |
17 | 15 | import pandas as pd |
18 | 16 | import xarray as xr |
| 17 | +from scipy.stats import qmc |
19 | 18 |
|
20 | 19 | import roadrunner |
21 | 20 |
|
@@ -175,7 +174,6 @@ def simulate_samples(self) -> None: |
175 | 174 | selections=self.sensitivity_simulation.selections, |
176 | 175 | ) |
177 | 176 |
|
178 | | - # FIXME: here the parallelization must take place |
179 | 177 | for k in track(range(self.num_samples), description="Simulating samples"): |
180 | 178 | changes = dict(zip(self.parameter_ids, self.samples[k, :].values)) |
181 | 179 | outputs = self.sensitivity_simulation.simulate( |
@@ -389,10 +387,6 @@ def calculate_sensitivity(self): |
389 | 387 | sensitivity_normalized[kp, ko] = sensitivity_raw[kp, ko] * p_ref/q_ref |
390 | 388 |
|
391 | 389 |
|
392 | | - |
393 | | - |
394 | | - |
395 | | -@dataclass |
396 | 390 | class SobolSensitivityAnalysis(SensitivityAnalysis): |
397 | 391 | """Global sensitivity analysis based on Sobol method. |
398 | 392 |
|
@@ -520,3 +514,192 @@ def plot_sobol_indices( |
520 | 514 | ymax=np.max([1.05, ymax]), |
521 | 515 | ymin=np.min([-0.05, ymin]), |
522 | 516 | ) |
| 517 | + |
| 518 | +class SamplingSensitivityAnalysis(SensitivityAnalysis): |
| 519 | + """Sensitivity/uncertainty analysis based on sampling.""" |
| 520 | + |
| 521 | + sensitivity_keys = [ |
| 522 | + "mean", |
| 523 | + "median", |
| 524 | + "std", |
| 525 | + "cv", |
| 526 | + "min", |
| 527 | + "q005", |
| 528 | + "q095", |
| 529 | + "max" |
| 530 | + ] |
| 531 | + |
| 532 | + def __init__(self, |
| 533 | + sensitivity_simulation: SensitivitySimulation, |
| 534 | + parameters: list[SensitivityParameter], |
| 535 | + N: int, |
| 536 | + results_path: Path, |
| 537 | + ): |
| 538 | + |
| 539 | + super().__init__(sensitivity_simulation, parameters, results_path) |
| 540 | + self.N: int = N |
| 541 | + |
| 542 | + |
| 543 | + def create_samples(self) -> None: |
| 544 | + """Create LHS samples. |
| 545 | +
|
| 546 | + Latin hypercube sampling (LHS) is a stratified sampling method used to |
| 547 | + generate near‑random samples from a multidimensional distribution for Monte |
| 548 | + Carlo simulations and computer experiments. |
| 549 | +
|
| 550 | + Use LHS sampling of parameters. |
| 551 | + """ |
| 552 | + # LHS sampling (uniform distributed in bounds) |
| 553 | + sampler = qmc.LatinHypercube(d=self.num_parameters) # number of dimensions |
| 554 | + u = sampler.random(n=self.N) # shape (n, d), in [0, 1], number of samples |
| 555 | + |
| 556 | + # Scale to parameter bounds |
| 557 | + lower = np.array([p.lower_bound for p in self.parameters]) |
| 558 | + upper = np.array([p.upper_bound for p in self.parameters]) |
| 559 | + x = qmc.scale(u, lower, upper) |
| 560 | + |
| 561 | + self.samples = xr.DataArray( |
| 562 | + x, |
| 563 | + dims=["sample", "parameter"], |
| 564 | + coords={"sample": range(self.N), |
| 565 | + "parameter": self.parameter_ids}, |
| 566 | + name="samples" |
| 567 | + ) |
| 568 | + |
| 569 | + def calculate_sensitivity(self) -> None: |
| 570 | + """Calculate the sensitivity matrices.""" |
| 571 | + |
| 572 | + # calculate readouts |
| 573 | + for key in self.sensitivity_keys: |
| 574 | + self.sensitivity[key] = xr.DataArray( |
| 575 | + np.full(self.num_outputs, np.nan), |
| 576 | + dims=["output"], |
| 577 | + coords={ |
| 578 | + "output": self.output_ids}, |
| 579 | + name=key |
| 580 | + ) |
| 581 | + |
| 582 | + for ko, oid in enumerate(self.outputs): |
| 583 | + # num_samples x num_outputs |
| 584 | + data = self.results.values[:, ko] |
| 585 | + for key in self.sensitivity_keys: |
| 586 | + if key == "mean": |
| 587 | + value = np.mean(data) |
| 588 | + elif key == "median": |
| 589 | + value = np.median(data) |
| 590 | + elif key == "std": |
| 591 | + value = np.std(data) |
| 592 | + elif key == "cv": |
| 593 | + value = np.std(data)/np.mean(data) |
| 594 | + elif key == "min": |
| 595 | + value = np.min(data) |
| 596 | + elif key == "q005": |
| 597 | + value = np.quantile(data, q=0.05) |
| 598 | + elif key == "q095": |
| 599 | + value = np.quantile(data, q=0.95) |
| 600 | + elif key == "max": |
| 601 | + value = np.max(data) |
| 602 | + else: |
| 603 | + raise KeyError(key) |
| 604 | + |
| 605 | + self.sensitivity[key][ko] = value |
| 606 | + |
| 607 | + def df_sampling_sensitivity( |
| 608 | + self, |
| 609 | + df_path: Path, |
| 610 | + ): |
| 611 | + # dataframe with the values |
| 612 | + items = [] |
| 613 | + for ko, output in enumerate(self.outputs): |
| 614 | + item: dict[str, Any] = { |
| 615 | + "uid": output.uid, |
| 616 | + "name": output.name, |
| 617 | + "N": self.N, |
| 618 | + } |
| 619 | + for key in self.sensitivity_keys: |
| 620 | + item[key] = self.sensitivity[key].values[ko] |
| 621 | + item["unit"] = output.unit |
| 622 | + |
| 623 | + items.append(item) |
| 624 | + |
| 625 | + df = pd.DataFrame(items) |
| 626 | + console.print(df) |
| 627 | + if df_path: |
| 628 | + df.to_csv(df_path, index=False, sep="\t") |
| 629 | + |
| 630 | + # latex table |
| 631 | + latex_path = df_path.parent / f"{df_path.stem}.tex" |
| 632 | + df_latex: pd.DataFrame = df.copy() |
| 633 | + df_latex.drop('uid', axis=1, inplace=True) |
| 634 | + df_latex.to_latex(latex_path, index=False, float_format="{:.3g}".format) |
| 635 | + |
| 636 | + return df |
| 637 | + |
| 638 | + |
| 639 | + |
| 640 | + def plot_sampling_sensitivity( |
| 641 | + self, |
| 642 | + fig_path: Path, |
| 643 | + **kwargs |
| 644 | + ): |
| 645 | + """Boxplots for the Sampling sensitivity.""" |
| 646 | + |
| 647 | + # width |
| 648 | + figsize = (15, 15) |
| 649 | + label_fontsize = 15 |
| 650 | + from matplotlib import pyplot as plt |
| 651 | + ncols = np.ceil(np.sqrt(self.num_outputs)) |
| 652 | + n_empty = ncols*ncols - self.num_outputs |
| 653 | + n_empty_rows = np.floor(n_empty/ncols) |
| 654 | + |
| 655 | + nrows = ncols-n_empty_rows |
| 656 | + |
| 657 | + |
| 658 | + f, axes = plt.subplots(figsize=figsize, nrows=int(nrows), ncols=int(ncols), layout="constrained") |
| 659 | + for ko, ax in enumerate(axes.flat): |
| 660 | + if ko > self.num_outputs-1: |
| 661 | + ax.axis('off') |
| 662 | + else: |
| 663 | + |
| 664 | + output = self.outputs[ko] |
| 665 | + data = self.results.values[:, ko] |
| 666 | + |
| 667 | + # outliers for scatter |
| 668 | + # Q1 = np.percentile(data, 25) |
| 669 | + # Q3 = np.percentile(data, 75) |
| 670 | + # IQR = Q3 - Q1 |
| 671 | + # lower_fence = Q1 - 1.5 * IQR |
| 672 | + # upper_fence = Q3 + 1.5 * IQR |
| 673 | + # data_no_outliers = data[(data > lower_fence) & (data < upper_fence)] |
| 674 | + data_no_outliers = data |
| 675 | + |
| 676 | + ax.boxplot(data, positions=[0.2], # labels=[output.name], |
| 677 | + patch_artist=True, showfliers=False, |
| 678 | + boxprops=dict( |
| 679 | + facecolor='lightblue', |
| 680 | + alpha=0.7 |
| 681 | + ) |
| 682 | + ) |
| 683 | + # ax.violinplot(data, positions=[0.8], showmeans=True, |
| 684 | + # showmedians=True, |
| 685 | + # showextrema = False |
| 686 | + # ) |
| 687 | + # jitter_width = 0.05 # Adjust for spacing |
| 688 | + # x_jitter = np.random.normal(0.8, jitter_width, len(data_no_outliers)) |
| 689 | + # ax.scatter(x_jitter, data_no_outliers, alpha=0.7, s=30, color='darkgrey', |
| 690 | + # edgecolors='black' |
| 691 | + # ) |
| 692 | + |
| 693 | + # ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold") |
| 694 | + # ax.set_ylim(bottom=0) |
| 695 | + # ax.set_title(output.name, fontsize=15, fontweight="bold") |
| 696 | + ax.set_ylabel(f"{output.name} [{output.unit}]", fontsize=label_fontsize, fontweight="bold") |
| 697 | + ax.tick_params(axis='x', which='both', labelbottom=False) |
| 698 | + # ax.grid(True, axis="y") |
| 699 | + # ax.tick_params(axis='x', labelrotation=90) |
| 700 | + |
| 701 | + # if title: |
| 702 | + # plt.suptitle(title, fontsize=20, fontweight="bold") |
| 703 | + if fig_path: |
| 704 | + plt.savefig(fig_path, dpi=300, bbox_inches="tight") |
| 705 | + plt.show() |
0 commit comments