@@ -256,35 +256,19 @@ def plot_benchmark_results(df, show_sampler, save_path=None):
256256 # Collect all unique model-sampler combinations across all problems
257257 all_combinations = []
258258 for model_key in get_model_name ():
259- if show_sampler == 'all' :
260- # Group by model across all samplers
261- model_data = df [df ['model' ] == model_key ]
262- if len (model_data ) > 0 :
263- samplers = model_data ['sampler' ].unique ()
264- for sampler in samplers :
265- label = get_model_name (model_key )
266- if 'consistency' not in model_key :
267- label += f' ({ get_sampler_name (sampler )} )'
268- all_combinations .append ({
269- 'model' : model_key ,
270- 'sampler' : sampler ,
271- 'label' : label ,
272- 'color' : colors .get (model_key , 'gray' )
273- })
274- else :
275- # Single sampler per model
276- model_data = df [df ['model' ] == model_key ]
277- if len (model_data ) > 0 :
278- sampler = model_data ['sampler' ].iloc [0 ]
279- label = get_model_name (model_key )
280- if 'consistency' not in model_key and show_sampler != 'all' :
281- label += f' ({ get_sampler_name (sampler )} )'
282- all_combinations .append ({
283- 'model' : model_key ,
284- 'sampler' : sampler ,
285- 'label' : label ,
286- 'color' : colors .get (model_key , 'gray' )
287- })
259+ # Single sampler per model
260+ model_data = df [df ['model' ] == model_key ]
261+ if len (model_data ) > 0 :
262+ sampler = model_data ['sampler' ].iloc [0 ]
263+ label = get_model_name (model_key )
264+ if 'consistency' not in model_key and show_sampler != 'all' :
265+ label += f' ({ get_sampler_name (sampler )} )'
266+ all_combinations .append ({
267+ 'model' : model_key ,
268+ 'sampler' : sampler ,
269+ 'label' : label ,
270+ 'color' : colors .get (model_key , 'gray' )
271+ })
288272
289273 # Create a consistent mapping from combination to x-position
290274 combination_to_x = {
@@ -485,7 +469,6 @@ def plot_by_sampler(df, col='c2st', col_std='std', save_path=None):
485469 axes [0 ].set_ylabel (r'Time [s]' , fontsize = 10 )
486470 axes [n_cols ].set_ylabel (r'Time [s]' , fontsize = 10 )
487471 axes [0 ].set_yscale ('log' )
488- #axes[0].set_ylim(0.01, 125)
489472
490473 # Hide unused subplots
491474 if n_samplers < len (axes ):
@@ -612,59 +595,71 @@ def plot_by_model(df, col='c2st', col_std='std', save_path=None):
612595 plt .show ()
613596
614597
def plot_low_budget_c2st_by_problem(df, save_path=None):
    """Plot C2ST results for low-budget samplers, one subplot per problem.

    The figure has two rows of subplots, one subplot per task reported by
    ``sbibm.get_available_tasks()``, ordered by parameter dimension and then
    by data dimension. Each subplot shows one error-bar series per low-budget
    sampler, with the models on the x-axis.

    Args:
        df: DataFrame with the columns ``problem``, ``sampler``, ``model``,
            ``c2st`` and ``std``. Nothing is plotted when the frame is empty
            or when the first row's ``problem`` is ``'all'``.
        save_path: Optional path; when given, the figure is written there
            before it is shown.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # An empty frame has no first row to inspect; treat it like the 'all' case.
    if df.empty or df['problem'].iloc[0] == 'all':
        return

    samplers = ['ode-euler-mini', 'ode-euler-small', 'ode-euler']
    colors = ["#0072B2", "#E69F00", "#009E73"]

    problem_names = sbibm.get_available_tasks()
    # Build each task once instead of once per queried attribute.
    tasks = [sbibm.get_task(p) for p in problem_names]
    problem_names_nice = [task.name_display for task in tasks]
    problem_dim = [task.dim_parameters for task in tasks]
    data_dim = [task.dim_data for task in tasks]
    # lexsort uses the LAST key as the primary one: parameter dim, then data dim.
    problem_order = np.lexsort((data_dim, problem_dim))
    n_problems = len(problem_names)

    # Round up so an odd number of problems still gets enough axes.
    n_rows = 2
    n_cols = max(1, (n_problems + 1) // n_rows)

    fig, axs = plt.subplots(
        ncols=n_cols,
        nrows=n_rows,
        figsize=(12, 4),
        layout='constrained',
        sharey=True,
        sharex=True,
        squeeze=False,  # always a 2D array, whatever the grid shape
    )
    axs = axs.flatten()

    for p_i, problem_idx in enumerate(problem_order):
        ax = axs[p_i]
        df_p = df[df['problem'] == problem_names[problem_idx]]

        # Model labels come from the first sampler; the other samplers are
        # assumed to list the same models in the same order.
        subset_models = df_p[df_p.sampler == samplers[0]]
        model_labels = [get_model_name(m) for m in subset_models.model.values]
        x = np.arange(len(model_labels))

        for s_i, s in enumerate(samplers):
            subset = df_p[df_p.sampler == s]
            ax.errorbar(
                x=x,
                y=subset.c2st.values,
                yerr=subset['std'].values,
                marker='o',
                markersize=5,
                color=colors[s_i],
                # Label only once so the figure legend has no duplicates.
                label=get_sampler_name(s) if p_i == 0 else None,
            )

        ax.set_title(problem_names_nice[problem_idx], fontsize=10)
        ax.set_xticks(x, model_labels, rotation=90, ha='right')
        ax.grid(True, alpha=0.3)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Y-label on the first subplot of every row.
        if p_i % n_cols == 0:
            ax.set_ylabel(r'$\mathrm{C2ST}$', fontsize=10)

    # Hide unused subplots (present when the number of problems is odd).
    for ax in axs[n_problems:]:
        ax.set_visible(False)

    fig.legend(
        fontsize=10,
        ncols=3,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.15),
        fancybox=False,
    )

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
0 commit comments