Skip to content

Commit 50fa8f2

Browse files
authored
Merge pull request #47 from polis-community/full-stats-dataframe
Return full stats in participants and statements dataframe in PolisClusteringResult
2 parents 757ec32 + fc4e4f1 commit 50fa8f2

17 files changed

Lines changed: 287 additions & 293 deletions

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Allow `is_strict_moderation` to be inferred from not just API data, but file data.
77
- Better handle numpy divide-by-zero edge-cases in two-property test. ([#28](https://github.com/polis-community/red-dwarf/pull/28))
88
- Fix bug where `vote_matrix` was modified directly, leading to subtle side-effects.
9+
- Fix bug in `select_representative_statements()` where mod-out statements weren't ignored.
910

1011
### Changes
1112
- Fixed participant projections to map more closely to Polis with `utils.pca.sparsity_aware_project_ptpt()`.
@@ -38,6 +39,11 @@
3839
- Add group statement stats to MultiIndex DataFrame.
3940
- Add `reddwarf.data_presenter.print_repness()` for printing representative statements.
4041
- Add support for `Loader()` importing data from alternative Polis instances via `polis_instance_url` arg.
42+
- Patch sklearn with a simple `PatchedPipeline`, to allow pipeline steps to access other steps.
43+
- Modify `SparsityAwareScaler` to be able to use captured output from `SparsityAwareCapturer`.
44+
- Remove ported Polis PCA functions that are no longer used.
45+
- Remove old `impute_missing_votes()` function that's no longer used.
46+
- In `PolisClusteringResult`, created new `statements_df` and `participants_df` with all raw calculation values.
4147

4248
### Chores
4349
- Moved agora implementation from `reddwarf.agora` to `reddwarf.implementations.agora` (deprecation warning).

docs/api_reference.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ use in Scikit-Learn workflows, pipelines, and APIs.
7777
options:
7878
show_root_heading: true
7979

80-
### ::: reddwarf.utils.impute_missing_votes
81-
options:
82-
show_root_heading: true
83-
8480
### ::: reddwarf.utils.get_unvoted_statement_ids
8581
options:
8682
show_root_heading: true

docs/notebooks/polis-implementation-demo.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@
269269
"from reddwarf.utils.stats import select_representative_statements\n",
270270
"from reddwarf.data_presenter import print_repness\n",
271271
"\n",
272-
"repness = select_representative_statements(grouped_stats_df=result.group_comment_stats)\n",
272+
"repness = select_representative_statements(grouped_stats_df=result.group_comment_stats, mod_out_statement_ids=mod_out_statement_ids)\n",
273273
"print_repness(repness=repness, statements_data=statements)\n"
274274
],
275275
"metadata": {},

reddwarf/implementations/polis.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Optional
2-
from numpy.typing import NDArray
32
from pandas import DataFrame
43
from sklearn.decomposition import PCA
54
from reddwarf.sklearn.cluster import PolisKMeans
65
from reddwarf.utils.matrix import generate_raw_matrix, simple_filter_matrix, get_clusterable_participant_ids
76
from reddwarf.utils.pca import run_pca
87
from reddwarf.utils.clustering import find_optimal_k
98
from dataclasses import dataclass
9+
import pandas as pd
1010

11-
from reddwarf.utils.stats import calculate_comment_statistics_dataframes
11+
from reddwarf.utils.stats import calculate_comment_statistics_dataframes, populate_priority_calculations_into_statements_df
1212

1313
@dataclass
1414
class PolisClusteringResult:
@@ -22,6 +22,8 @@ class PolisClusteringResult:
2222
kmeans (PolisKMeans): Scikit-Learn KMeans object for selected group count, including `labels_` and `cluster_centers_`. See `PolisKMeans`.
2323
group_aware_consensus (DataFrame): Group-aware consensus scores for each statement.
2424
group_comment_stats (DataFrame): A multi-index dataframe of statement stats, indexed by group ID and statement.
25+
statements_df (DataFrame): A dataframe with all intermediary and final statement data/calculations/metadata.
26+
participants_df (DataFrame): A dataframe with all intermediary and final participant data/calculations/metadata.
2527
"""
2628
raw_vote_matrix: DataFrame
2729
filtered_vote_matrix: DataFrame
@@ -31,10 +33,13 @@ class PolisClusteringResult:
3133
kmeans: PolisKMeans | None
3234
group_aware_consensus: DataFrame
3335
group_comment_stats: DataFrame
36+
statements_df: DataFrame
37+
participants_df: DataFrame
3438

3539
def run_clustering(
3640
votes: list[dict],
3741
mod_out_statement_ids: list[int] = [],
42+
meta_statement_ids: list[int] = [],
3843
min_user_vote_threshold: int = 7,
3944
keep_participant_ids: list[int] = [],
4045
init_centers: Optional[list[list[float]]] = None,
@@ -53,6 +58,7 @@ def run_clustering(
5358
Args:
5459
votes (list[dict]): Raw list of vote dicts, with keys for "participant_id", "statement_id", "vote" and "modified"
5560
mod_out_statement_ids (list[int]): List of statement IDs to moderate/zero out
61+
meta_statement_ids (list[int]): List of meta statement IDs
5662
min_user_vote_threshold (int): Minimum number of votes a participant must make to be included in clustering
5763
keep_participant_ids (list[int]): List of participant IDs to keep in clustering algorithm, regardless of normal filters.
5864
max_group_count (int): Max number of groups (k-values) to test using k-means and silhouette scores
@@ -70,42 +76,69 @@ def run_clustering(
7076
mod_out_statement_ids=mod_out_statement_ids,
7177
)
7278

73-
projected_participants, projected_statements, pca = run_pca(vote_matrix=filtered_vote_matrix)
79+
# Run PCA and generate participant/statement projections.
80+
# DataFrames each have "x" and "y" columns.
81+
participants_df, statements_df, pca = run_pca(vote_matrix=filtered_vote_matrix)
7482

75-
participant_ids_clusterable = get_clusterable_participant_ids(raw_vote_matrix, vote_threshold=min_user_vote_threshold)
83+
participant_ids_to_cluster = get_clusterable_participant_ids(raw_vote_matrix, vote_threshold=min_user_vote_threshold)
7684
if keep_participant_ids:
77-
participant_ids_clusterable = list(set(participant_ids_clusterable + keep_participant_ids))
85+
# TODO: Make this an intersection, in case there are members of
86+
# keep_participant_ids list that aren't represented in vote_matrix.
87+
participant_ids_to_cluster = sorted(list(set(participant_ids_to_cluster + keep_participant_ids)))
7888

7989
if force_group_count:
8090
k_bounds = [force_group_count, force_group_count]
8191
else:
8292
k_bounds = [2, max_group_count]
8393

84-
projected_participants_clusterable = projected_participants.loc[participant_ids_clusterable, :]
8594
_, _, kmeans = find_optimal_k(
86-
projected_data=projected_participants_clusterable,
95+
projected_data=participants_df.loc[participant_ids_to_cluster, :],
8796
k_bounds=k_bounds,
8897
# Force polis strategy of initiating cluster centers. See: PolisKMeans.
8998
init="polis",
9099
init_centers=init_centers,
91100
random_state=random_state,
92101
)
93-
projected_participants_clusterable = projected_participants_clusterable.assign(
94-
cluster_id=kmeans.labels_ if kmeans else None,
102+
label_series = pd.Series(
103+
kmeans.labels_ if kmeans else None,
104+
index=participant_ids_to_cluster,
105+
dtype="Int64", # Allows nullable/NaN values.
95106
)
107+
participants_df["to_cluster"] = participants_df.index.isin(participant_ids_to_cluster)
108+
participants_df["cluster_id"] = label_series
96109

97110
grouped_stats_df, gac_df = calculate_comment_statistics_dataframes(
98-
vote_matrix=raw_vote_matrix.loc[participant_ids_clusterable, :],
111+
vote_matrix=raw_vote_matrix.loc[participant_ids_to_cluster, :],
99112
cluster_labels=kmeans.labels_,
100113
)
101114

115+
def get_with_default(lst, idx, default=None):
116+
try:
117+
return lst[idx]
118+
except IndexError:
119+
return default
120+
121+
statements_df["to_zero"] = statements_df.index.isin(mod_out_statement_ids)
122+
statements_df["is_meta"] = statements_df.index.isin(meta_statement_ids)
123+
statements_df["mean"] = pca.mean_
124+
statements_df["pc1"] = get_with_default(pca.components_, 0)
125+
statements_df["pc2"] = get_with_default(pca.components_, 1)
126+
statements_df["pc3"] = get_with_default(pca.components_, 2)
127+
statements_df = pd.concat([statements_df, gac_df], axis=1)
128+
statements_df = populate_priority_calculations_into_statements_df(
129+
statements_df=statements_df,
130+
vote_matrix=raw_vote_matrix.loc[participant_ids_to_cluster, :],
131+
)
132+
102133
return PolisClusteringResult(
103134
raw_vote_matrix=raw_vote_matrix,
104135
filtered_vote_matrix=filtered_vote_matrix,
105136
pca=pca,
106-
projected_participants=projected_participants_clusterable,
107-
projected_statements=projected_statements,
137+
projected_participants=participants_df.loc[participant_ids_to_cluster, ["x", "y", "cluster_id"]], # deprecate?
138+
projected_statements=statements_df.loc[:, ["x", "y"]], # deprecate?
108139
kmeans=kmeans,
109-
group_aware_consensus=gac_df,
140+
group_aware_consensus=gac_df, # deprecate?
110141
group_comment_stats=grouped_stats_df,
142+
statements_df=statements_df,
143+
participants_df=participants_df,
111144
)

reddwarf/sklearn/pipeline.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from sklearn.pipeline import Pipeline
2+
3+
4+
class PatchedPipeline(Pipeline):
5+
"""
6+
A subclass of sklearn's Pipeline that injects a `_parent_pipeline` attribute into each step.
7+
8+
This allows individual transformers in the pipeline to access their parent pipeline and,
9+
by extension, other steps within it. Useful for custom transformers that depend on
10+
intermediate results from earlier steps (e.g., SparsityAwareScaler using SparsityAwareCapturer output).
11+
12+
Example:
13+
```
14+
pipeline = PatchedPipeline([
15+
("capture", SparsityAwareCapturer()),
16+
("scale", SparsityAwareScaler(capture_step="capture")),
17+
])
18+
19+
# Inside SparsityAwareScaler.transform():
20+
# capture_step = self._parent_pipeline.named_steps["capture"]
21+
# X_sparse = capture_step.X_captured_
22+
```
23+
24+
Note:
25+
- Steps must support attribute assignment (`__dict__`) to receive the reference.
26+
- `_parent_pipeline` is injected once during initialization.
27+
"""
28+
def __init__(self, steps, **kwargs):
29+
super().__init__(steps, **kwargs)
30+
self._patch_steps()
31+
32+
def _patch_steps(self):
33+
for _, step in self.steps:
34+
if hasattr(step, '__dict__'):
35+
step._parent_pipeline = self

reddwarf/sklearn/transformers.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
Array2D = NDArray[np.float64]
88

99
def calculate_scaling_factors(X_sparse: Array1D | Array2D) -> Array1D:
10+
"""
11+
Calculate row-based scaling factors from the sparse vote matrix.
12+
13+
(Outside estimator so available for re-use.)
14+
"""
1015
# This allows function to work for 2D (full vote_matrix) and 1D (participant_votes).
1116
# It essentially nests a 1D matrix in a 2D one.
1217
X_sparse = np.atleast_2d(X_sparse)
@@ -29,14 +34,16 @@ def calculate_scaling_factors(X_sparse: Array1D | Array2D) -> Array1D:
2934

3035
class SparsityAwareScaler(BaseEstimator, TransformerMixin):
3136
"""
32-
Scale projected points (participant/statements) based on sparsity of vote
37+
Scale projected points (participants or statements) based on sparsity of vote
3338
matrix, to account for any small number of votes by a participant and
3439
prevent those participants from bunching up in the center.
3540
3641
Attributes:
42+
capture_step (str | int | None): Name or index of the capture step in the pipeline.
3743
X_sparse (np.ndarray | None): A sparse array with shape (n_features,)
3844
"""
39-
def __init__(self, X_sparse: Optional[Array1D | Array2D] = None):
45+
def __init__(self, capture_step: Optional[str | int] = None, X_sparse: Optional[Array1D | Array2D] = None):
46+
self.capture_step = capture_step
4047
self.X_sparse = X_sparse
4148

4249
# See: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Tags.html#sklearn.utils.Tags
@@ -57,10 +64,53 @@ def inverse_transform(self, X):
5764
scaling_factors = self._calculate_scaling_factors()
5865
return X / scaling_factors[:, np.newaxis]
5966

60-
def _calculate_scaling_factors(self):
61-
if self.X_sparse is None:
67+
68+
def _get_pipeline_step(self, step):
69+
"""
70+
Fetch a step (by name or index) from the parent pipeline, which is available when used inside a PatchedPipeline.
71+
"""
72+
parent = getattr(self, "_parent_pipeline", None)
73+
if parent is None:
74+
raise RuntimeError(
75+
f"{self.__class__.__name__} cannot resolve `capture_step={step}` "
76+
"because it is not being used inside a `PatchedPipeline`. "
77+
"Either use a `PatchedPipeline` or pass `X_sparse` directly."
78+
)
79+
if isinstance(step, str):
80+
return parent.named_steps[step]
81+
elif isinstance(step, int):
82+
return parent.steps[step][1]
83+
else:
84+
raise ValueError("`capture_step` must be a string (name) or int (index).")
85+
86+
def _resolve_X_sparse(self):
87+
"""
88+
Resolve X_sparse (a sparse vote matrix) from argument or prior capture step.
89+
"""
90+
if self.X_sparse is not None:
91+
return self.X_sparse
92+
93+
capture = self._get_pipeline_step(self.capture_step)
94+
if not hasattr(capture, "X_captured_"):
6295
raise AttributeError(
63-
"Missing `X_sparse`. Pass `X_sparse` when initializing SparsityAwareScaler."
96+
f"Step '{self.capture_step}' does not contain `.X_captured_`. "
97+
f"Did you run `fit/transform` on the pipeline?"
6498
)
99+
return capture.X_captured_
65100

66-
return calculate_scaling_factors(X_sparse=self.X_sparse)
101+
def _calculate_scaling_factors(self):
102+
X_sparse = self._resolve_X_sparse()
103+
return calculate_scaling_factors(X_sparse=X_sparse)
104+
105+
class SparsityAwareCapturer(BaseEstimator, TransformerMixin):
106+
"""
107+
A passthrough transformer that captures and stores the X it receives in
108+
`self.X_captured_`. Useful in pipelines where a later step needs access to
109+
this intermediate result.
110+
"""
111+
def fit(self, X, y=None):
112+
return self
113+
114+
def transform(self, X):
115+
self.X_captured_ = X # Store the actual input value
116+
return X

reddwarf/utils/matrix.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,6 @@
77

88
VoteMatrix: TypeAlias = pd.DataFrame
99

10-
def impute_missing_votes(vote_matrix: VoteMatrix) -> VoteMatrix:
11-
"""
12-
Imputes missing votes in a voting matrix using column-wise mean. All columns must have at least one vote.
13-
14-
Reference:
15-
Small, C. (2021). "Polis: Scaling Deliberation by Mapping High Dimensional Opinion Spaces."
16-
Specific highlight: <https://hyp.is/8zUyWM5fEe-uIO-J34vbkg/gwern.net/doc/sociology/2021-small.pdf>
17-
18-
Args:
19-
vote_matrix (pd.DataFrame): A vote matrix DataFrame with `NaN`/`None` values where: \
20-
1. rows are voters, \
21-
2. columns are statements, and \
22-
3. values are votes.
23-
24-
Returns:
25-
imputed_matrix (pd.DataFrame): The same vote matrix DataFrame imputing missing values with column mean.
26-
"""
27-
if vote_matrix.isna().all(axis="rows").any():
28-
raise RedDwarfError("impute_missing_votes does not support vote matrices containing statement columns with no votes.")
29-
30-
mean_imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
31-
imputed_matrix = pd.DataFrame(
32-
mean_imputer.fit_transform(vote_matrix),
33-
columns=vote_matrix.columns,
34-
index=vote_matrix.index,
35-
)
36-
return imputed_matrix
37-
3810
def filter_votes(
3911
votes: List[Dict],
4012
cutoff: Optional[int] = None,
@@ -232,4 +204,21 @@ def filter_matrix(
232204
elif unvoted_filter_type == 'zero':
233205
vote_matrix[unvoted_statement_ids] = 0
234206

235-
return vote_matrix
207+
return vote_matrix
208+
209+
def generate_virtual_vote_matrix(n_statements: int):
210+
"""
211+
Creates a matrix of virtual participants, each of whom votes agree on a
212+
single statement, with no other votes. (This is a variation of an "identity
213+
matrix", with votes going across the diagonal of a full NaN matrix.)
214+
"""
215+
# Build a basic identity matrix
216+
virtual_vote_matrix = np.eye(n_statements)
217+
218+
# Replace 1s with +1 and 0s with NaN
219+
# TODO: Why does Polis use -1 (disagree) here? is it the same? BUG?
220+
AGREE_VAL = 1
221+
MISSING_VAL = np.nan
222+
virtual_vote_matrix = np.where(virtual_vote_matrix == 1, AGREE_VAL, MISSING_VAL)
223+
224+
return virtual_vote_matrix

0 commit comments

Comments
 (0)