Skip to content

Commit 4386fb6

Browse files
working on sensitivity analysis
1 parent 1cec46e commit 4386fb6

File tree

3 files changed

+194
-10
lines changed

3 files changed

+194
-10
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 191 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
TODO implementation of alternative methods:
44
- [ ] FAST
55
- [ ] Morris
6-
- [ ] Sampling based methods (distribution)
76
"""
87
import time
98
import multiprocessing
10-
from typing import Optional
9+
from typing import Optional, Any
1110
from pathlib import Path
12-
from dataclasses import dataclass
1311
from rich.progress import track
1412
from pymetadata.console import console
1513

1614
import numpy as np
1715
import pandas as pd
1816
import xarray as xr
17+
from scipy.stats import qmc
1918

2019
import roadrunner
2120

@@ -175,7 +174,6 @@ def simulate_samples(self) -> None:
175174
selections=self.sensitivity_simulation.selections,
176175
)
177176

178-
# FIXME: here the parallelization must take place
179177
for k in track(range(self.num_samples), description="Simulating samples"):
180178
changes = dict(zip(self.parameter_ids, self.samples[k, :].values))
181179
outputs = self.sensitivity_simulation.simulate(
@@ -389,10 +387,6 @@ def calculate_sensitivity(self):
389387
sensitivity_normalized[kp, ko] = sensitivity_raw[kp, ko] * p_ref/q_ref
390388

391389

392-
393-
394-
395-
@dataclass
396390
class SobolSensitivityAnalysis(SensitivityAnalysis):
397391
"""Global sensitivity analysis based on Sobol method.
398392
@@ -520,3 +514,192 @@ def plot_sobol_indices(
520514
ymax=np.max([1.05, ymax]),
521515
ymin=np.min([-0.05, ymin]),
522516
)
517+
518+
class SamplingSensitivityAnalysis(SensitivityAnalysis):
519+
"""Sensitivity/uncertainty analysis based on sampling."""
520+
521+
sensitivity_keys = [
522+
"mean",
523+
"median",
524+
"std",
525+
"cv",
526+
"min",
527+
"q005",
528+
"q095",
529+
"max"
530+
]
531+
532+
def __init__(self,
533+
sensitivity_simulation: SensitivitySimulation,
534+
parameters: list[SensitivityParameter],
535+
N: int,
536+
results_path: Path,
537+
):
538+
539+
super().__init__(sensitivity_simulation, parameters, results_path)
540+
self.N: int = N
541+
542+
543+
def create_samples(self) -> None:
544+
"""Create LHS samples.
545+
546+
Latin hypercube sampling (LHS) is a stratified sampling method used to
547+
generate near‑random samples from a multidimensional distribution for Monte
548+
Carlo simulations and computer experiments.
549+
550+
Use LHS sampling of parameters.
551+
"""
552+
# LHS sampling (uniform distributed in bounds)
553+
sampler = qmc.LatinHypercube(d=self.num_parameters) # number of dimensions
554+
u = sampler.random(n=self.N) # shape (n, d), in [0, 1], number of samples
555+
556+
# Scale to parameter bounds
557+
lower = np.array([p.lower_bound for p in self.parameters])
558+
upper = np.array([p.upper_bound for p in self.parameters])
559+
x = qmc.scale(u, lower, upper)
560+
561+
self.samples = xr.DataArray(
562+
x,
563+
dims=["sample", "parameter"],
564+
coords={"sample": range(self.N),
565+
"parameter": self.parameter_ids},
566+
name="samples"
567+
)
568+
569+
def calculate_sensitivity(self) -> None:
570+
"""Calculate the sensitivity matrices."""
571+
572+
# calculate readouts
573+
for key in self.sensitivity_keys:
574+
self.sensitivity[key] = xr.DataArray(
575+
np.full(self.num_outputs, np.nan),
576+
dims=["output"],
577+
coords={
578+
"output": self.output_ids},
579+
name=key
580+
)
581+
582+
for ko, oid in enumerate(self.outputs):
583+
# num_samples x num_outputs
584+
data = self.results.values[:, ko]
585+
for key in self.sensitivity_keys:
586+
if key == "mean":
587+
value = np.mean(data)
588+
elif key == "median":
589+
value = np.median(data)
590+
elif key == "std":
591+
value = np.std(data)
592+
elif key == "cv":
593+
value = np.std(data)/np.mean(data)
594+
elif key == "min":
595+
value = np.min(data)
596+
elif key == "q005":
597+
value = np.quantile(data, q=0.05)
598+
elif key == "q095":
599+
value = np.quantile(data, q=0.95)
600+
elif key == "max":
601+
value = np.max(data)
602+
else:
603+
raise KeyError(key)
604+
605+
self.sensitivity[key][ko] = value
606+
607+
def df_sampling_sensitivity(
608+
self,
609+
df_path: Path,
610+
):
611+
# dataframe with the values
612+
items = []
613+
for ko, output in enumerate(self.outputs):
614+
item: dict[str, Any] = {
615+
"uid": output.uid,
616+
"name": output.name,
617+
"N": self.N,
618+
}
619+
for key in self.sensitivity_keys:
620+
item[key] = self.sensitivity[key].values[ko]
621+
item["unit"] = output.unit
622+
623+
items.append(item)
624+
625+
df = pd.DataFrame(items)
626+
console.print(df)
627+
if df_path:
628+
df.to_csv(df_path, index=False, sep="\t")
629+
630+
# latex table
631+
latex_path = df_path.parent / f"{df_path.stem}.tex"
632+
df_latex: pd.DataFrame = df.copy()
633+
df_latex.drop('uid', axis=1, inplace=True)
634+
df_latex.to_latex(latex_path, index=False, float_format="{:.3g}".format)
635+
636+
return df
637+
638+
639+
640+
def plot_sampling_sensitivity(
641+
self,
642+
fig_path: Path,
643+
**kwargs
644+
):
645+
"""Boxplots for the Sampling sensitivity."""
646+
647+
# width
648+
figsize = (15, 15)
649+
label_fontsize = 15
650+
from matplotlib import pyplot as plt
651+
ncols = np.ceil(np.sqrt(self.num_outputs))
652+
n_empty = ncols*ncols - self.num_outputs
653+
n_empty_rows = np.floor(n_empty/ncols)
654+
655+
nrows = ncols-n_empty_rows
656+
657+
658+
f, axes = plt.subplots(figsize=figsize, nrows=int(nrows), ncols=int(ncols), layout="constrained")
659+
for ko, ax in enumerate(axes.flat):
660+
if ko > self.num_outputs-1:
661+
ax.axis('off')
662+
else:
663+
664+
output = self.outputs[ko]
665+
data = self.results.values[:, ko]
666+
667+
# outliers for scatter
668+
# Q1 = np.percentile(data, 25)
669+
# Q3 = np.percentile(data, 75)
670+
# IQR = Q3 - Q1
671+
# lower_fence = Q1 - 1.5 * IQR
672+
# upper_fence = Q3 + 1.5 * IQR
673+
# data_no_outliers = data[(data > lower_fence) & (data < upper_fence)]
674+
data_no_outliers = data
675+
676+
ax.boxplot(data, positions=[0.2], # labels=[output.name],
677+
patch_artist=True, showfliers=False,
678+
boxprops=dict(
679+
facecolor='lightblue',
680+
alpha=0.7
681+
)
682+
)
683+
# ax.violinplot(data, positions=[0.8], showmeans=True,
684+
# showmedians=True,
685+
# showextrema = False
686+
# )
687+
# jitter_width = 0.05 # Adjust for spacing
688+
# x_jitter = np.random.normal(0.8, jitter_width, len(data_no_outliers))
689+
# ax.scatter(x_jitter, data_no_outliers, alpha=0.7, s=30, color='darkgrey',
690+
# edgecolors='black'
691+
# )
692+
693+
# ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
694+
# ax.set_ylim(bottom=0)
695+
# ax.set_title(output.name, fontsize=15, fontweight="bold")
696+
ax.set_ylabel(f"{output.name} [{output.unit}]", fontsize=label_fontsize, fontweight="bold")
697+
ax.tick_params(axis='x', which='both', labelbottom=False)
698+
# ax.grid(True, axis="y")
699+
# ax.tick_params(axis='x', labelrotation=90)
700+
701+
# if title:
702+
# plt.suptitle(title, fontsize=20, fontweight="bold")
703+
if fig_path:
704+
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
705+
plt.show()

src/sbmlsim/sensitivity/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ class SensitivityOutput:
77
"""Output measurement for SensitivityAnalysis."""
88
uid: str
99
name: str
10-
# unit: Optional[str]
10+
unit: Optional[str]

src/sbmlsim/sensitivity/plots.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def sobol_barplot(
143143
edgecolor="black", yerr=S1_conf, capsize=5)
144144

145145

146-
ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
146+
# ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
147147
ax.set_ylabel('Sobol Index', fontsize=label_fontsize, fontweight="bold")
148148
ax.set_ylim(bottom=ymin, top=ymax)
149149
ax.grid(True, axis="y")
@@ -158,3 +158,4 @@ def sobol_barplot(
158158
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
159159
plt.show()
160160

161+

0 commit comments

Comments
 (0)