Skip to content

Commit cfc5144

Browse files
lchen001Lingjiao Chen
andauthored
add random state for majority vote (#82)
Co-authored-by: Lingjiao Chen <[email protected]>
1 parent 8bebccf commit cfc5144

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

eureka_ml_insights/data_utils/transform.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,20 +370,21 @@ class MajorityVoteTransform:
370370
id_col: str = "data_point_id" # Default column name for IDs
371371
majority_vote_col: str = "majority_vote"
372372

373-
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
373+
def transform(self, df: pd.DataFrame, random_state:int=0) -> pd.DataFrame:
374374
"""
375375
Transforms the dataframe by calculating the majority vote of model_output_col per id_col.
376376
If the 'model_output' is NaN, it will be droped before calculating the majority vote.
377377
378378
Args:
379379
df (pd.DataFrame): Input dataframe containing model_output_col and id_col.
380+
random_state (int): Input random seed
380381
381382
Returns:
382383
pd.DataFrame: Transformed dataframe with majority vote for each id_col.
383384
"""
384385
# Step 1: Group by 'ID' and calculate the majority vote within each group
385386
df[self.majority_vote_col] = df.groupby(self.id_col)[self.model_output_col].transform(
386-
lambda x: x.dropna().mode().sample(n=1).iloc[0] if not x.dropna().mode().empty else pd.NA
387+
lambda x: x.dropna().mode().sample(n=1, random_state=random_state).iloc[0] if not x.dropna().mode().empty else pd.NA
387388
)
388389

389390
return df

0 commit comments

Comments
 (0)