|
20 | 20 | from autoemulate.plotting import _plot_model
|
21 | 21 | from autoemulate.printing import _print_setup
|
22 | 22 | from autoemulate.save import ModelSerialiser
|
| 23 | +from autoemulate.sensitivity_analysis import plot_sensitivity_analysis |
| 24 | +from autoemulate.sensitivity_analysis import sensitivity_analysis |
| 25 | +from autoemulate.utils import _ensure_2d |
23 | 26 | from autoemulate.utils import _get_full_model_name
|
24 | 27 | from autoemulate.utils import _redirect_warnings
|
25 | 28 | from autoemulate.utils import get_model_name
|
@@ -522,3 +525,64 @@ def plot_eval(
|
522 | 525 | )
|
523 | 526 |
|
524 | 527 | return fig
|
| 528 | + |
| 529 | + def sensitivity_analysis( |
| 530 | + self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True |
| 531 | + ): |
| 532 | + """Perform Sobol sensitivity analysis on a fitted emulator. |
| 533 | +
|
| 534 | + Parameters |
| 535 | + ---------- |
| 536 | + model : object, optional |
| 537 | + Fitted model. If None, uses the best model from cross-validation. |
| 538 | + problem : dict, optional |
| 539 | + The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'. |
| 540 | + If None, the problem is generated from X using minimum and maximum values of the features as bounds. |
| 541 | +
|
| 542 | + Example: |
| 543 | + ```python |
| 544 | + problem = { |
| 545 | + "num_vars": 2, |
| 546 | + "names": ["x1", "x2"], |
| 547 | + "bounds": [[0, 1], [0, 1]], |
| 548 | + } |
| 549 | + ``` |
| 550 | + N : int, optional |
| 551 | + Number of samples to generate. Default is 1024. |
| 552 | + conf_level : float, optional |
| 553 | + Confidence level for the confidence intervals. Default is 0.95. |
| 554 | + as_df : bool, optional |
| 555 | + If True, return a long-format pandas DataFrame (default is True). |
| 556 | + """ |
| 557 | + if model is None: |
| 558 | + if not hasattr(self, "best_model"): |
| 559 | + raise RuntimeError("Must run compare() before sensitivity_analysis()") |
| 560 | + model = self.refit(self.best_model) |
| 561 | + self.logger.info( |
| 562 | + f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data." |
| 563 | + ) |
| 564 | + |
| 565 | + Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df) |
| 566 | + return Si |
| 567 | + |
| 568 | + def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None): |
| 569 | + """ |
| 570 | + Plot the sensitivity analysis results. |
| 571 | +
|
| 572 | + Parameters: |
| 573 | + ----------- |
| 574 | + results : pd.DataFrame |
| 575 | + The results from sobol_results_to_df. |
| 576 | + index : str, default "S1" |
| 577 | + The type of sensitivity index to plot. |
| 578 | + - "S1": first-order indices |
| 579 | + - "S2": second-order/interaction indices |
| 580 | + - "ST": total-order indices |
| 581 | + n_cols : int, optional |
| 582 | + The number of columns in the plot. Defaults to 3 if there are 3 or more outputs, |
| 583 | + otherwise the number of outputs. |
| 584 | + figsize : tuple, optional |
| 585 | + Figure size as (width, height) in inches.If None, automatically calculated. |
| 586 | +
|
| 587 | + """ |
| 588 | + return plot_sensitivity_analysis(results, index, n_cols, figsize) |
0 commit comments