File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
eureka_ml_insights/data_utils Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -370,20 +370,21 @@ class MajorityVoteTransform:
370
370
id_col : str = "data_point_id" # Default column name for IDs
371
371
majority_vote_col : str = "majority_vote"
372
372
373
- def transform (self , df : pd .DataFrame ) -> pd .DataFrame :
373
+ def transform (self , df : pd .DataFrame , random_state : int = 0 ) -> pd .DataFrame :
374
374
"""
375
375
Transforms the dataframe by calculating the majority vote of model_output_col per id_col.
376
376
If the 'model_output' is NaN, it will be dropped before calculating the majority vote.
377
377
378
378
Args:
379
379
df (pd.DataFrame): Input dataframe containing model_output_col and id_col.
380
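+ random_state (int): Seed passed to the sampling step that picks one value when several modes are tied, so that tie-breaking is reproducible.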
380
381
381
382
Returns:
382
383
pd.DataFrame: Transformed dataframe with majority vote for each id_col.
383
384
"""
384
385
# Step 1: Group by 'ID' and calculate the majority vote within each group
385
386
df [self .majority_vote_col ] = df .groupby (self .id_col )[self .model_output_col ].transform (
386
- lambda x : x .dropna ().mode ().sample (n = 1 ).iloc [0 ] if not x .dropna ().mode ().empty else pd .NA
387
+ lambda x : x .dropna ().mode ().sample (n = 1 , random_state = random_state ).iloc [0 ] if not x .dropna ().mode ().empty else pd .NA
387
388
)
388
389
389
390
return df
You can’t perform that action at this time.
0 commit comments