11from numpy .typing import ArrayLike
22import pandas as pd
33import numpy as np
4+ from sklearn .base import BaseEstimator , TransformerMixin
5+ from sklearn .pipeline import FunctionTransformer
46from reddwarf .utils .matrix import VoteMatrix , generate_virtual_vote_matrix
57from reddwarf .sklearn .transformers import SparsityAwareCapturer , SparsityAwareScaler
68from reddwarf .sklearn .pipeline import PatchedPipeline
7- from typing import Tuple
9+ from typing import Optional , Tuple
810
911from sklearn .decomposition import PCA
1012from sklearn .impute import SimpleImputer
1113
14+ from reddwarf .utils .matrix import simple_filter_matrix
15+
16+ class ModerationFilterTransformer (BaseEstimator , TransformerMixin ):
17+ """
18+ Transformer that zero's out specific moderated columns.
19+ """
20+ def __init__ (self , columns_to_filter : list [int ] = []):
21+ self .columns_to_filter = columns_to_filter
22+
23+ def fit (self , X , y = None ):
24+ return self
25+
26+ def transform (self , X ):
27+ return simple_filter_matrix (X , self .columns_to_filter )
28+
1229
1330def run_pca (
1431 vote_matrix : VoteMatrix ,
1532 n_components : int = 2 ,
33+ mod_out_statement_ids : list [int ] = [],
1634) -> Tuple [ pd .DataFrame , pd .DataFrame , PCA ]:
1735 """
1836 Process a prepared vote matrix to be imputed and return projected participant data,
@@ -31,23 +49,35 @@ def run_pca(
3149 - explained_variance_ (List[float]): Explained variance of each principal component.
3250 - mean_ (list[float]): Means/centers of each column/statements/features.
3351 """
52+ X_raw = vote_matrix .values
53+ # moderation = ModerationFilterTransformer(columns_to_filter=mod_out_statement_ids)
54+ # X_moderated = simple_filter_matrix(X_raw)
55+
56+ # filtered_vote_matrix = simple_filter_matrix(
57+ # vote_matrix=vote_matrix,
58+ # mod_out_statement_ids=mod_out_statement_ids,
59+ # )
60+
3461 pipeline = PatchedPipeline ([
62+ ("moderate" , ModerationFilterTransformer ()),
3563 ("capture" , SparsityAwareCapturer ()),
3664 ("impute" , SimpleImputer (missing_values = np .nan , strategy = "mean" )),
3765 ("pca" , PCA (n_components = n_components )),
3866 ("scale" , SparsityAwareScaler (capture_step = "capture" )),
3967 ])
4068
41- pipeline .fit (vote_matrix .values )
4269
4370 # Generate projections of participants.
44- X_participants = pipeline .transform (vote_matrix .values )
71+ pipeline .named_steps ["moderate" ].columns_to_filter = mod_out_statement_ids
72+ pipeline .fit (X_raw )
73+ X_participants = pipeline .transform (X_raw )
4574
4675 # Generate projections of statements via virtual vote matrix.
4776 # This projects unit vectors for each feature/statement into PCA space to
4877 # understand their placement.
4978 n_statements = len (vote_matrix .columns )
5079 virtual_vote_matrix = generate_virtual_vote_matrix (n_statements )
80+ pipeline .named_steps ["moderate" ].columns_to_filter = []
5181 X_statements = pipeline .transform (virtual_vote_matrix )
5282
5383 DEFAULT_DIMENSION_LABELS = ["x" , "y" , "z" ]
0 commit comments