Skip to content

Commit 10744cb

Browse files
committed
WIP testing idea of CaptureMixin and CaptureWrapper for estimators.
1 parent 979cff0 commit 10744cb

4 files changed

Lines changed: 55 additions & 14 deletions

File tree

reddwarf/implementations/polis.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,13 @@ def run_clustering(
7171
"""
7272
raw_vote_matrix = generate_raw_matrix(votes=votes)
7373

74+
# Not used, just for output below, until we can get it into pipeline for inspection.
7475
filtered_vote_matrix = simple_filter_matrix(
7576
vote_matrix=raw_vote_matrix,
7677
mod_out_statement_ids=mod_out_statement_ids,
7778
)
7879

79-
projected_participants, projected_statements, pca = run_pca(vote_matrix=filtered_vote_matrix)
80+
projected_participants, projected_statements, pca = run_pca(vote_matrix=raw_vote_matrix, mod_out_statement_ids=mod_out_statement_ids)
8081

8182
participant_ids_clusterable = get_clusterable_participant_ids(raw_vote_matrix, vote_threshold=min_user_vote_threshold)
8283
if keep_participant_ids:

reddwarf/sklearn/transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def _calculate_scaling_factors(self):
102102
X_sparse = self._resolve_X_sparse()
103103
return calculate_scaling_factors(X_sparse=X_sparse)
104104

105+
# TODO: Replace to CaptureMixin and CaptureWrapper.
106+
# See: https://chatgpt.com/c/680a512a-f604-800b-8922-1992a8ddf491
105107
class SparsityAwareCapturer(BaseEstimator, TransformerMixin):
106108
"""
107109
A passthrough transformer that captures and stores the X it receives in

reddwarf/utils/matrix.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def get_unvoted_statement_ids(vote_matrix: VoteMatrix) -> List[int]:
112112
return null_column_ids
113113

114114
def simple_filter_matrix(
115-
vote_matrix: VoteMatrix,
115+
vote_matrix: VoteMatrix | np.ndarray,
116116
mod_out_statement_ids: list[int] = [],
117-
) -> VoteMatrix:
117+
) -> VoteMatrix | np.ndarray:
118118
"""
119119
The simple filter on the vote_matrix that is used by Polis prior to running PCA.
120120
@@ -125,14 +125,22 @@ def simple_filter_matrix(
125125
Returns:
126126
VoteMatrix: Copy of vote_matrix with statements zero'd out
127127
"""
128-
vote_matrix = vote_matrix.copy()
129-
for tid in mod_out_statement_ids:
130-
# Zero out column only if already exists (ie. has votes)
131-
if tid in vote_matrix.columns:
132-
# TODO: Add a flag to try np.nan instead of zero.
133-
vote_matrix.loc[:, tid] = 0
134-
135-
return vote_matrix
128+
if isinstance(vote_matrix, pd.DataFrame):
129+
vote_matrix = vote_matrix.copy()
130+
for col in mod_out_statement_ids:
131+
if col in vote_matrix.columns:
132+
vote_matrix[col] = 0
133+
return vote_matrix
134+
135+
elif isinstance(vote_matrix, np.ndarray):
136+
vote_matrix = vote_matrix.copy()
137+
for col in mod_out_statement_ids:
138+
if isinstance(col, int) and 0 <= col < vote_matrix.shape[1]:
139+
vote_matrix[:, col] = 0
140+
return vote_matrix
141+
142+
else:
143+
raise TypeError("vote_matrix must be a pandas DataFrame or a NumPy ndarray.")
136144

137145
def get_clusterable_participant_ids(vote_matrix: VoteMatrix, vote_threshold: int) -> list:
138146
"""

reddwarf/utils/pca.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,36 @@
11
from numpy.typing import ArrayLike
22
import pandas as pd
33
import numpy as np
4+
from sklearn.base import BaseEstimator, TransformerMixin
5+
from sklearn.pipeline import FunctionTransformer
46
from reddwarf.utils.matrix import VoteMatrix, generate_virtual_vote_matrix
57
from reddwarf.sklearn.transformers import SparsityAwareCapturer, SparsityAwareScaler
68
from reddwarf.sklearn.pipeline import PatchedPipeline
7-
from typing import Tuple
9+
from typing import Optional, Tuple
810

911
from sklearn.decomposition import PCA
1012
from sklearn.impute import SimpleImputer
1113

14+
from reddwarf.utils.matrix import simple_filter_matrix
15+
16+
class ModerationFilterTransformer(BaseEstimator, TransformerMixin):
17+
"""
18+
Transformer that zero's out specific moderated columns.
19+
"""
20+
def __init__(self, columns_to_filter: list[int] = []):
21+
self.columns_to_filter = columns_to_filter
22+
23+
def fit(self, X, y=None):
24+
return self
25+
26+
def transform(self, X):
27+
return simple_filter_matrix(X, self.columns_to_filter)
28+
1229

1330
def run_pca(
1431
vote_matrix: VoteMatrix,
1532
n_components: int = 2,
33+
mod_out_statement_ids: list[int] = [],
1634
) -> Tuple[ pd.DataFrame, pd.DataFrame, PCA ]:
1735
"""
1836
Process a prepared vote matrix to be imputed and return projected participant data,
@@ -31,23 +49,35 @@ def run_pca(
3149
- explained_variance_ (List[float]): Explained variance of each principal component.
3250
- mean_ (list[float]): Means/centers of each column/statements/features.
3351
"""
52+
X_raw = vote_matrix.values
53+
# moderation = ModerationFilterTransformer(columns_to_filter=mod_out_statement_ids)
54+
# X_moderated = simple_filter_matrix(X_raw)
55+
56+
# filtered_vote_matrix = simple_filter_matrix(
57+
# vote_matrix=vote_matrix,
58+
# mod_out_statement_ids=mod_out_statement_ids,
59+
# )
60+
3461
pipeline = PatchedPipeline([
62+
("moderate", ModerationFilterTransformer()),
3563
("capture", SparsityAwareCapturer()),
3664
("impute", SimpleImputer(missing_values=np.nan, strategy="mean")),
3765
("pca", PCA(n_components=n_components)),
3866
("scale", SparsityAwareScaler(capture_step="capture")),
3967
])
4068

41-
pipeline.fit(vote_matrix.values)
4269

4370
# Generate projections of participants.
44-
X_participants = pipeline.transform(vote_matrix.values)
71+
pipeline.named_steps["moderate"].columns_to_filter = mod_out_statement_ids
72+
pipeline.fit(X_raw)
73+
X_participants = pipeline.transform(X_raw)
4574

4675
# Generate projections of statements via virtual vote matrix.
4776
# This projects unit vectors for each feature/statement into PCA space to
4877
# understand their placement.
4978
n_statements = len(vote_matrix.columns)
5079
virtual_vote_matrix = generate_virtual_vote_matrix(n_statements)
80+
pipeline.named_steps["moderate"].columns_to_filter = []
5181
X_statements = pipeline.transform(virtual_vote_matrix)
5282

5383
DEFAULT_DIMENSION_LABELS = ["x", "y", "z"]

0 commit comments

Comments
 (0)