Skip to content

Commit 9bbb1ca

Browse files
committed
update plots
1 parent d8c0a3f commit 9bbb1ca

9 files changed: +99 / -99 lines changed

case_study1/helper_visualize.py

Lines changed: 74 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -256,35 +256,19 @@ def plot_benchmark_results(df, show_sampler, save_path=None):
256256
# Collect all unique model-sampler combinations across all problems
257257
all_combinations = []
258258
for model_key in get_model_name():
259-
if show_sampler == 'all':
260-
# Group by model across all samplers
261-
model_data = df[df['model'] == model_key]
262-
if len(model_data) > 0:
263-
samplers = model_data['sampler'].unique()
264-
for sampler in samplers:
265-
label = get_model_name(model_key)
266-
if 'consistency' not in model_key:
267-
label += f' ({get_sampler_name(sampler)})'
268-
all_combinations.append({
269-
'model': model_key,
270-
'sampler': sampler,
271-
'label': label,
272-
'color': colors.get(model_key, 'gray')
273-
})
274-
else:
275-
# Single sampler per model
276-
model_data = df[df['model'] == model_key]
277-
if len(model_data) > 0:
278-
sampler = model_data['sampler'].iloc[0]
279-
label = get_model_name(model_key)
280-
if 'consistency' not in model_key and show_sampler != 'all':
281-
label += f' ({get_sampler_name(sampler)})'
282-
all_combinations.append({
283-
'model': model_key,
284-
'sampler': sampler,
285-
'label': label,
286-
'color': colors.get(model_key, 'gray')
287-
})
259+
# Single sampler per model
260+
model_data = df[df['model'] == model_key]
261+
if len(model_data) > 0:
262+
sampler = model_data['sampler'].iloc[0]
263+
label = get_model_name(model_key)
264+
if 'consistency' not in model_key and show_sampler != 'all':
265+
label += f' ({get_sampler_name(sampler)})'
266+
all_combinations.append({
267+
'model': model_key,
268+
'sampler': sampler,
269+
'label': label,
270+
'color': colors.get(model_key, 'gray')
271+
})
288272

289273
# Create a consistent mapping from combination to x-position
290274
combination_to_x = {
@@ -485,7 +469,6 @@ def plot_by_sampler(df, col='c2st', col_std='std', save_path=None):
485469
axes[0].set_ylabel(r'Time [s]', fontsize=10)
486470
axes[n_cols].set_ylabel(r'Time [s]', fontsize=10)
487471
axes[0].set_yscale('log')
488-
#axes[0].set_ylim(0.01, 125)
489472

490473
# Hide unused subplots
491474
if n_samplers < len(axes):
@@ -612,59 +595,71 @@ def plot_by_model(df, col='c2st', col_std='std', save_path=None):
612595
plt.show()
613596

614597

615-
def plot_low_budget_results(df, save_path=None):
616-
"""Plot C2ST and time results for low-budget samplers across models."""
617-
if df['problem'].iloc[0] != 'all':
598+
def plot_low_budget_c2st_by_problem(df, save_path=None):
599+
"""Plot C2ST results for low-budget samplers, one subplot per problem."""
600+
if df['problem'].iloc[0] == 'all':
618601
return
619-
620-
# C2ST plot
621-
labels = []
622-
data_to_plot = []
623-
data_to_plot_std = []
624602
samplers = ['ode-euler-mini', 'ode-euler-small', 'ode-euler']
625603
colors = ["#0072B2", "#E69F00", "#009E73"]
626604

627-
for s in samplers:
628-
subset = df[df.sampler == s]
629-
data_to_plot.append(subset.c2st.values)
630-
data_to_plot_std.append(subset['std'].values)
631-
labels.append(get_sampler_name(s))
632-
633-
# Get model names for x-axis
634-
subset = df[df.sampler == samplers[0]]
635-
model_labels = [get_model_name(m) for m in subset.model.values]
636-
637-
fig, axs = plt.subplots(ncols=2, figsize=(12, 4), layout='constrained')
638-
ax = axs[0]
639-
for s_i, s in enumerate(samplers):
640-
ax.errorbar(x=np.arange(len(model_labels)), y=np.array(data_to_plot)[s_i],
641-
yerr=np.array(data_to_plot_std)[s_i], marker='o', markersize=5, color=colors[s_i])
642-
ax.set_ylabel(r'$\mathrm{C2ST}$', fontsize=10)
643-
ax.set_xticks(ticks=np.arange(len(model_labels)), labels=model_labels, rotation=45, ha='right')
644-
ax.grid(True, alpha=0.3)
645-
ax.spines['right'].set_visible(False)
646-
ax.spines['top'].set_visible(False)
647-
648-
# Time plot
649-
labels = []
650-
data_to_plot = []
651-
data_to_plot_std = []
652-
for s in samplers:
653-
subset = df[df.sampler == s]
654-
data_to_plot.append(subset.time.values)
655-
data_to_plot_std.append(subset.time_std.values)
656-
labels.append(get_sampler_name(s))
657-
658-
ax = axs[1]
659-
for s_i, s in enumerate(samplers):
660-
ax.errorbar(x=np.arange(len(model_labels)), y=np.array(data_to_plot)[s_i], yerr=np.array(data_to_plot_std)[s_i],
661-
marker='o', markersize=5, color=colors[s_i])
662-
ax.set_ylabel(r'Time [s]', fontsize=10)
663-
ax.set_xticks(ticks=np.arange(len(model_labels)), labels=model_labels, rotation=45, ha='right')
664-
ax.grid(True, alpha=0.3)
665-
ax.spines['right'].set_visible(False)
666-
ax.spines['top'].set_visible(False)
667-
fig.legend(labels, fontsize=10, ncols=3, loc='lower center', bbox_to_anchor=(0.5, -0.1), fancybox=False)
605+
problem_names = sbibm.get_available_tasks()
606+
problem_names_nice = np.array([sbibm.get_task(p).name_display for p in problem_names])
607+
problem_dim = [sbibm.get_task(p).dim_parameters for p in problem_names]
608+
data_dim = [sbibm.get_task(p).dim_data for p in problem_names]
609+
problem_order = np.lexsort((data_dim, problem_dim))
610+
n_problems = len(problem_names)
611+
612+
fig, axs = plt.subplots(
613+
ncols=n_problems // 2,
614+
nrows=2,
615+
figsize=(12, 4),
616+
layout='constrained',
617+
sharey=True, sharex=True
618+
)
619+
620+
# Ensure axs is iterable for n_problems == 1
621+
if n_problems == 1:
622+
axs = np.array([axs])
623+
axs = axs.flatten()
624+
625+
for p_i, problem_idx in enumerate(problem_order):
626+
ax = axs[p_i]
627+
df_p = df[df['problem'] == problem_names[problem_idx]]
628+
629+
# Model labels from the first sampler (assumed consistent)
630+
subset_models = df_p[df_p.sampler == samplers[0]]
631+
model_labels = [get_model_name(m) for m in subset_models.model.values]
632+
x = np.arange(len(model_labels))
633+
634+
for s_i, s in enumerate(samplers):
635+
subset = df_p[df_p.sampler == s]
636+
ax.errorbar(
637+
x=x,
638+
y=subset.c2st.values,
639+
yerr=subset['std'].values,
640+
marker='o',
641+
markersize=5,
642+
color=colors[s_i],
643+
label=get_sampler_name(s) if p_i == 0 else None
644+
)
645+
646+
ax.set_title(problem_names_nice[problem_idx], fontsize=10)
647+
ax.set_xticks(x, model_labels, rotation=90, ha='right')
648+
ax.grid(True, alpha=0.3)
649+
ax.spines['right'].set_visible(False)
650+
ax.spines['top'].set_visible(False)
651+
652+
if p_i == 0 or p_i == n_problems // 2:
653+
ax.set_ylabel(r'$\mathrm{C2ST}$', fontsize=10)
654+
655+
fig.legend(
656+
fontsize=10,
657+
ncols=3,
658+
loc='lower center',
659+
bbox_to_anchor=(0.5, -0.15),
660+
fancybox=False
661+
)
662+
668663
if save_path is not None:
669664
fig.savefig(save_path, bbox_inches='tight')
670665
plt.show()
0 Bytes
Binary file not shown.
-47.7 KB
Binary file not shown.
0 Bytes
Binary file not shown.
41.2 KB
Binary file not shown.
-24.1 KB
Binary file not shown.
0 Bytes
Binary file not shown.
5.18 KB
Binary file not shown.

case_study1/visualize_results.py

Lines changed: 25 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def _():
2222
sys.path.append(str(PROJECT_ROOT))
2323

2424
from case_study1.model_settings_benchmark import SAMPLER_SETTINGS
25-
from case_study1.helper_visualize import plot_benchmark_results, plot_by_sampler, plot_by_model, plot_low_budget_results, pareto_best_sampler
25+
from case_study1.helper_visualize import plot_benchmark_results, plot_by_sampler, plot_by_model, plot_low_budget_c2st_by_problem, pareto_best_sampler
2626
return (
2727
BASE,
2828
SAMPLER_SETTINGS,
@@ -32,7 +32,7 @@ def _():
3232
plot_benchmark_results,
3333
plot_by_model,
3434
plot_by_sampler,
35-
plot_low_budget_results,
35+
plot_low_budget_c2st_by_problem,
3636
plt,
3737
sbibm,
3838
)
@@ -49,7 +49,7 @@ def _(BASE, pd):
4949
@app.cell
5050
def _(SAMPLER_SETTINGS):
5151
all_samplers= ['best', 'merge_problems'] + [k for k in SAMPLER_SETTINGS.keys()]
52-
SHOW_SAMPLER = all_samplers[1]
52+
SHOW_SAMPLER = all_samplers[0]
5353
print(SHOW_SAMPLER)
5454
return (SHOW_SAMPLER,)
5555

@@ -92,10 +92,13 @@ def _(SHOW_SAMPLER, pareto_best_sampler, results):
9292
].reset_index(drop=True)
9393

9494
# only flow matching
95-
long_df_fm = long_df[
95+
long_df_reduced_fm = long_df[
9696
long_df["model"].str.contains('flow_matching')
9797
].reset_index(drop=True)
98-
return long_df, long_df_fm, long_df_reduced
98+
long_df_fm = long_df_copy[
99+
long_df_copy["model"].str.contains('flow_matching')
100+
].reset_index(drop=True)
101+
return long_df, long_df_fm, long_df_reduced, long_df_reduced_fm
99102

100103

101104
@app.cell
@@ -157,29 +160,31 @@ def _(
157160
BASE,
158161
SHOW_SAMPLER,
159162
long_df_fm,
160-
plot_benchmark_results,
161-
plot_low_budget_results,
163+
long_df_reduced_fm,
164+
plot_by_sampler,
165+
plot_low_budget_c2st_by_problem,
162166
):
163-
plot_benchmark_results(
164-
long_df_fm,
165-
SHOW_SAMPLER,
166-
BASE / 'plots' / f"c2st_benchmark_boxplot_{SHOW_SAMPLER}_fm.pdf"
167+
plot_by_sampler(
168+
long_df_reduced_fm,
169+
col='time',
170+
col_std='time_std',
171+
save_path=BASE / 'plots' / f"time_benchmark_boxplot_{SHOW_SAMPLER}_fm.pdf"
167172
)
168173

169-
#plot_by_sampler(
170-
# long_df_fm,
171-
# col='time',
172-
# col_std='time_std',
173-
# save_path=BASE / 'plots' / f"time_benchmark_boxplot_{SHOW_SAMPLER}_fm.pdf"
174-
#)
175-
176-
plot_low_budget_results(
174+
plot_low_budget_c2st_by_problem(
177175
long_df_fm,
178-
BASE / 'plots' / f"euler_benchmark_boxplot_{SHOW_SAMPLER}_fm.pdf"
176+
BASE / 'plots' / f"euler_benchmark_boxplot_fm.pdf"
179177
)
180178
return
181179

182180

181+
@app.cell
182+
def _():
183+
184+
185+
return
186+
187+
183188
@app.cell
184189
def _():
185190
import bayesflow as bf

0 commit comments

Comments
 (0)