Skip to content

Commit

Permalink
add random state for majority vote (#82)
Browse files Browse the repository at this point in the history
Co-authored-by: Lingjiao Chen <[email protected]>
  • Loading branch information
lchen001 and Lingjiao Chen authored Jan 24, 2025
1 parent 8bebccf commit cfc5144
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions eureka_ml_insights/data_utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,20 +370,21 @@ class MajorityVoteTransform:
id_col: str = "data_point_id" # Default column name for IDs
majority_vote_col: str = "majority_vote"

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
def transform(self, df: pd.DataFrame, random_state:int=0) -> pd.DataFrame:
"""
Transforms the dataframe by calculating the majority vote of model_output_col per id_col.
NaN values in model_output_col are dropped before calculating the majority vote.
Args:
df (pd.DataFrame): Input dataframe containing model_output_col and id_col.
random_state (int): Random seed used when sampling one value among tied modes, so the majority vote is reproducible.
Returns:
pd.DataFrame: Transformed dataframe with majority vote for each id_col.
"""
# Step 1: Group by 'ID' and calculate the majority vote within each group
df[self.majority_vote_col] = df.groupby(self.id_col)[self.model_output_col].transform(
lambda x: x.dropna().mode().sample(n=1).iloc[0] if not x.dropna().mode().empty else pd.NA
lambda x: x.dropna().mode().sample(n=1, random_state=random_state).iloc[0] if not x.dropna().mode().empty else pd.NA
)

return df

0 comments on commit cfc5144

Please sign in to comment.