Skip to content

Commit e8d17ee

Browse files
committed
update tutorial
1 parent b38dd89 commit e8d17ee

File tree

6 files changed

+453
-156
lines changed

6 files changed

+453
-156
lines changed

case_study1/helper_visualize.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
import plotly.graph_objects as go
1010
from plotly.subplots import make_subplots
1111

12+
import warnings
13+
warnings.filterwarnings(
14+
"ignore",
15+
message="JULIA_SYSIMAGE_DIFFEQTORCH not set"
16+
)
17+
1218
BASE = Path(__file__).resolve().parent
1319

1420

@@ -415,7 +421,6 @@ def plot_benchmark_results_plotly(df):
415421

416422
problem_names = sbibm.get_available_tasks()
417423
problem_names_nice = np.array([sbibm.get_task(p).name_display for p in problem_names])
418-
task_name_nice = problem_names_nice.copy()
419424
problem_dim = [sbibm.get_task(p).dim_parameters for p in problem_names]
420425
data_dim = [sbibm.get_task(p).dim_data for p in problem_names]
421426

@@ -457,8 +462,8 @@ def plot_benchmark_results_plotly(df):
457462
shown_labels = set()
458463

459464
for plot_idx, problem_idx in enumerate(problem_order):
460-
col = plot_idx // (n_problems // 2) + 1
461-
row = plot_idx % (n_problems // 2) + 1
465+
row = plot_idx // 2 + 1
466+
col = plot_idx % 2 + 1
462467

463468
task_name = problem_names[problem_idx]
464469
subset = df[df['problem'] == task_name]
@@ -505,7 +510,7 @@ def plot_benchmark_results_plotly(df):
505510
'<extra></extra>'
506511
),
507512
customdata=[[
508-
task_name_nice[problem_idx],
513+
problem_names_nice[problem_idx],
509514
get_model_name_plotly(model_key),
510515
get_sampler_name(sampler),
511516
model_data['std'].iloc[0]

case_study1/load_results_benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111

1212
from case_study1.model_settings_benchmark import MODELS, SAMPLER_SETTINGS, is_compatible
1313

14+
import warnings
15+
warnings.filterwarnings(
16+
"ignore",
17+
message="JULIA_SYSIMAGE_DIFFEQTORCH not set"
18+
)
1419

1520
def model_sampler_key(_model: str, _sampler: str) -> str:
1621
return f"{_model}-{_sampler}"

0 commit comments

Comments
 (0)