Skip to content

Commit 0895607

Browse files
committed
Fix pandas deprecation warnings
1 parent d611f9b commit 0895607

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

spotify_confidence/analysis/bayesian/bayesian_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ def _categorical_multiple_difference_plot(self, level, absolute, groupby, level_
447447
@staticmethod
448448
def _validate_levels(level_df, remaining_groups, level):
449449
try:
450-
level_df.groupby(remaining_groups).get_group(level)
450+
# When grouping with a length-1 list, get_group expects a tuple
451+
group_key = (level,) if isinstance(remaining_groups, list) and len(remaining_groups) == 1 else level
452+
level_df.groupby(remaining_groups).get_group(group_key)
451453
except (KeyError, ValueError):
452454
raise ValueError(
453455
"""

spotify_confidence/analysis/bayesian/bayesian_models.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,17 @@ def _categorical_summary_plot(self, level_name, level_df, remaining_groups, grou
185185
ch.set_legend_location("outside_bottom")
186186
return ch
187187

188-
def _difference_posteriors(self, data, level_1, level_2, absolute=True):
189-
posterior_1 = self._sample_posterior(data.get_group(level_1))
190-
posterior_2 = self._sample_posterior(data.get_group(level_2))
188+
def _difference_posteriors(self, data, level_1, level_2, absolute=True, remaining_groups=None):
189+
# When grouping with a length-1 list, get_group expects a tuple
190+
if isinstance(remaining_groups, list) and len(remaining_groups) == 1:
191+
level_1_key = (level_1,)
192+
level_2_key = (level_2,)
193+
else:
194+
level_1_key = level_1
195+
level_2_key = level_2
196+
197+
posterior_1 = self._sample_posterior(data.get_group(level_1_key))
198+
posterior_2 = self._sample_posterior(data.get_group(level_2_key))
191199

192200
if absolute:
193201
difference_posterior = posterior_2 - posterior_1
@@ -256,7 +264,7 @@ def _difference_and_difference_posterior(self, level_df, remaining_groups, level
256264
self._validate_levels(level_df, remaining_groups, level_2)
257265
# difference is posterior_2 - posterior_1
258266
difference_posterior = self._difference_posteriors(
259-
level_df.groupby(remaining_groups), level_1, level_2, absolute
267+
level_df.groupby(remaining_groups), level_1, level_2, absolute, remaining_groups
260268
)
261269
difference_df = self._differences(difference_posterior, level_1, level_2, absolute)
262270
return difference_df, difference_posterior
@@ -384,7 +392,11 @@ def _multiple_difference_joint_base(self, level_name, level_df, remaining_groups
384392

385393
self._validate_levels(level_df, remaining_groups, level)
386394

387-
posteriors = [self._sample_posterior(grouped_df.get_group(level)) for level in grouped_df_keys]
395+
# When grouping with a length-1 list, get_group expects a tuple
396+
if isinstance(remaining_groups, list) and len(remaining_groups) == 1:
397+
posteriors = [self._sample_posterior(grouped_df.get_group((lvl,))) for lvl in grouped_df_keys]
398+
else:
399+
posteriors = [self._sample_posterior(grouped_df.get_group(lvl)) for lvl in grouped_df_keys]
388400

389401
var_indx = grouped_df_keys.index(level)
390402
other_indx = [i for i, value in enumerate(grouped_df_keys) if value != level]
@@ -627,7 +639,11 @@ def _categorical_multiple_difference_chart(
627639

628640
self._validate_levels(level_df, remaining_groups, level)
629641

630-
posteriors = [self._sample_posterior(grouped_df.get_group(level)) for level in grouped_df_keys]
642+
# When grouping with a length-1 list, get_group expects a tuple
643+
if isinstance(remaining_groups, list) and len(remaining_groups) == 1:
644+
posteriors = [self._sample_posterior(grouped_df.get_group((lvl,))) for lvl in grouped_df_keys]
645+
else:
646+
posteriors = [self._sample_posterior(grouped_df.get_group(lvl)) for lvl in grouped_df_keys]
631647

632648
var_indx = grouped_df_keys.index(level)
633649

spotify_confidence/analysis/confidence_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def get_all_categorical_group_columns(
9898
def validate_levels(df: DataFrame, level_columns: Union[str, Iterable], levels: Iterable):
9999
for level in levels:
100100
try:
101-
df.groupby(level_columns).get_group(level)
101+
# When grouping with a length-1 list, get_group expects a tuple
102+
group_key = (level,) if isinstance(level_columns, list) and len(level_columns) == 1 else level
103+
df.groupby(level_columns).get_group(group_key)
102104
except (KeyError, ValueError):
103105
raise ValueError(
104106
"""

0 commit comments

Comments
 (0)