@@ -185,9 +185,17 @@ def _categorical_summary_plot(self, level_name, level_df, remaining_groups, grou
185185 ch .set_legend_location ("outside_bottom" )
186186 return ch
187187
188- def _difference_posteriors (self , data , level_1 , level_2 , absolute = True ):
189- posterior_1 = self ._sample_posterior (data .get_group (level_1 ))
190- posterior_2 = self ._sample_posterior (data .get_group (level_2 ))
188+ def _difference_posteriors (self , data , level_1 , level_2 , absolute = True , remaining_groups = None ):
189+ # When grouping with a length-1 list, get_group expects a tuple
190+ if isinstance (remaining_groups , list ) and len (remaining_groups ) == 1 :
191+ level_1_key = (level_1 ,)
192+ level_2_key = (level_2 ,)
193+ else :
194+ level_1_key = level_1
195+ level_2_key = level_2
196+
197+ posterior_1 = self ._sample_posterior (data .get_group (level_1_key ))
198+ posterior_2 = self ._sample_posterior (data .get_group (level_2_key ))
191199
192200 if absolute :
193201 difference_posterior = posterior_2 - posterior_1
@@ -256,7 +264,7 @@ def _difference_and_difference_posterior(self, level_df, remaining_groups, level
256264 self ._validate_levels (level_df , remaining_groups , level_2 )
257265 # difference is posterior_2 - posterior_1
258266 difference_posterior = self ._difference_posteriors (
259- level_df .groupby (remaining_groups ), level_1 , level_2 , absolute
267+ level_df .groupby (remaining_groups ), level_1 , level_2 , absolute , remaining_groups
260268 )
261269 difference_df = self ._differences (difference_posterior , level_1 , level_2 , absolute )
262270 return difference_df , difference_posterior
@@ -384,7 +392,11 @@ def _multiple_difference_joint_base(self, level_name, level_df, remaining_groups
384392
385393 self ._validate_levels (level_df , remaining_groups , level )
386394
387- posteriors = [self ._sample_posterior (grouped_df .get_group (level )) for level in grouped_df_keys ]
395+ # When grouping with a length-1 list, get_group expects a tuple
396+ if isinstance (remaining_groups , list ) and len (remaining_groups ) == 1 :
397+ posteriors = [self ._sample_posterior (grouped_df .get_group ((lvl ,))) for lvl in grouped_df_keys ]
398+ else :
399+ posteriors = [self ._sample_posterior (grouped_df .get_group (lvl )) for lvl in grouped_df_keys ]
388400
389401 var_indx = grouped_df_keys .index (level )
390402 other_indx = [i for i , value in enumerate (grouped_df_keys ) if value != level ]
@@ -627,7 +639,11 @@ def _categorical_multiple_difference_chart(
627639
628640 self ._validate_levels (level_df , remaining_groups , level )
629641
630- posteriors = [self ._sample_posterior (grouped_df .get_group (level )) for level in grouped_df_keys ]
642+ # When grouping with a length-1 list, get_group expects a tuple
643+ if isinstance (remaining_groups , list ) and len (remaining_groups ) == 1 :
644+ posteriors = [self ._sample_posterior (grouped_df .get_group ((lvl ,))) for lvl in grouped_df_keys ]
645+ else :
646+ posteriors = [self ._sample_posterior (grouped_df .get_group (lvl )) for lvl in grouped_df_keys ]
631647
632648 var_indx = grouped_df_keys .index (level )
633649
0 commit comments