Skip to content

Commit 71b2a63

Browse files
committed
feat(polis): add typed reusable APIs
Add typed success and insufficient-data outcomes for Polis and Agora clustering callers, plus reusable PCA projection and forced-k KMeans helpers so consumers can evaluate multiple group-count candidates without rebuilding the projection. Expose singleton-aware silhouette scoring through the k-means helper and cover the typed, Agora, and projection-reuse paths with focused tests.
1 parent d91cd23 commit 71b2a63

7 files changed

Lines changed: 522 additions & 9 deletions

File tree

reddwarf/implementations/agora.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class AgoraClusteringResult:
4343
ranked_consensus: RankedConsensusResult
4444

4545

46+
# Return type of the typed Agora entry point: either a success wrapper around
# AgoraClusteringResult or an insufficient-data outcome.
TypedAgoraClusteringResult = (
    base.AnalysisSuccess[AgoraClusteringResult] | base.AnalysisInsufficientData
)
47+
48+
4649
def compute_effective_agreement_gac(
4750
grouped_stats_df: pd.DataFrame,
4851
statement_ids,
@@ -132,3 +135,22 @@ def run_pipeline(
132135
ranked_repness=ranked_repness,
133136
ranked_consensus=ranked_consensus,
134137
)
138+
139+
140+
def run_pipeline_typed(**kwargs) -> TypedAgoraClusteringResult:
    """
    Run the Agora pipeline and report the outcome as a typed result.

    The inputs are screened with ``base.get_insufficient_data_reason`` first.
    ``run_pipeline`` is only invoked when that screening returns no reason.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to ``run_pipeline``.
            ``votes`` is required. The screening additionally reads
            ``mod_out_statement_ids`` (fallback ``[]``),
            ``min_user_vote_threshold`` (fallback ``7``),
            ``keep_participant_ids`` (fallback ``[]``) and
            ``force_group_count`` (fallback ``None``).

    Returns:
        ``base.AnalysisSuccess`` wrapping the ``run_pipeline`` result, or
        ``base.AnalysisInsufficientData`` carrying the screening reason.

    Raises:
        KeyError: If ``votes`` is not supplied.
    """
    blocking_reason = base.get_insufficient_data_reason(
        votes=kwargs["votes"],
        mod_out_statement_ids=kwargs.get("mod_out_statement_ids", []),
        min_user_vote_threshold=kwargs.get("min_user_vote_threshold", 7),
        keep_participant_ids=kwargs.get("keep_participant_ids", []),
        force_group_count=kwargs.get("force_group_count"),
    )

    if blocking_reason is None:
        # Screening passed, so it is safe to run the full pipeline.
        return base.AnalysisSuccess(
            outcome=base.AnalysisOutcome.SUCCESS,
            result=run_pipeline(**kwargs),
        )

    return base.AnalysisInsufficientData(
        outcome=base.AnalysisOutcome.INSUFFICIENT_DATA,
        reason=blocking_reason,
    )

reddwarf/implementations/base.py

Lines changed: 306 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import Optional, Literal
1+
from enum import Enum
2+
from typing import Generic, Optional, Literal, TypeVar
23
from dataclasses import dataclass
4+
import numpy as np
35
import pandas as pd
46
from pandas import DataFrame
57
from sklearn.decomposition import PCA
68
from reddwarf.types.polis import PolisRepness
79
from reddwarf.utils.clusterer.base import run_clusterer
10+
from reddwarf.utils.clusterer.kmeans import calculate_kmeans_silhouette_score
811
from reddwarf.utils.consensus import select_consensus_statements, ConsensusResult
912
from reddwarf.utils.matrix import (
1013
generate_raw_matrix,
@@ -20,6 +23,45 @@
2023
)
2124

2225

26+
# Type variable for the payload carried in AnalysisSuccess.result.
T = TypeVar("T")
27+
28+
29+
class AnalysisOutcome(str, Enum):
    """Discriminator tag stored in the ``outcome`` field of typed analysis results."""

    # A result payload is available.
    SUCCESS = "success"
    # No result was computed; an InsufficientDataReason accompanies this tag.
    INSUFFICIENT_DATA = "insufficient_data"
32+
33+
34+
class InsufficientDataReason(str, Enum):
    """Reasons for which ``get_insufficient_data_reason`` rejects an input."""

    # The supplied vote list is empty.
    EMPTY_VOTE_MATRIX = "empty_vote_matrix"
    # Fewer than two participants are eligible for clustering.
    NOT_ENOUGH_CLUSTERABLE_PARTICIPANTS = "not_enough_clusterable_participants"
    # Fewer than two distinct vote rows exist among the eligible participants.
    NOT_ENOUGH_UNIQUE_POINTS = "not_enough_unique_points"
    # Fewer eligible participants than the forced group count.
    NOT_ENOUGH_SAMPLES_FOR_GROUP_COUNT = "not_enough_samples_for_group_count"
39+
40+
41+
@dataclass(frozen=True)
class AnalysisSuccess(Generic[T]):
    """
    Immutable success variant of a typed analysis result.

    Attributes:
        outcome: Discriminator, annotated as ``AnalysisOutcome.SUCCESS``.
            The annotation is not enforced at runtime.
        result: The payload produced by the analysis.
    """

    outcome: Literal[AnalysisOutcome.SUCCESS]
    result: T
45+
46+
47+
@dataclass(frozen=True)
class AnalysisInsufficientData:
    """
    Immutable insufficient-data variant of a typed analysis result.

    Attributes:
        outcome: Discriminator, annotated as
            ``AnalysisOutcome.INSUFFICIENT_DATA``. The annotation is not
            enforced at runtime.
        reason: Why the input was judged insufficient.
    """

    outcome: Literal[AnalysisOutcome.INSUFFICIENT_DATA]
    reason: InsufficientDataReason
51+
52+
53+
@dataclass
class PcaProjectionResult:
    """
    Reusable output of ``prepare_pca_projection``.

    Attributes:
        raw_vote_matrix: Matrix built from the votes by ``generate_raw_matrix``.
        filtered_vote_matrix: The raw matrix after ``simple_filter_matrix``
            was applied with the moderated-out statement IDs.
        reducer: Reducer model returned by ``run_reducer``.
        participants_df: Per-participant ``x``/``y`` projection, indexed like
            ``filtered_vote_matrix``, plus a boolean ``to_cluster`` column.
        statements_df: Per-statement ``x``/``y`` projection with ``to_zero``
            and ``is_meta`` flags, plus any reducer-derived columns.
        participant_ids_to_cluster: IDs of the participants selected for
            clustering.
        participant_projections: Mapping of participant ID to its projection.
        statement_projections: Mapping of statement ID to its projection, or
            ``None`` when the reducer produced no statement projection.
    """

    raw_vote_matrix: DataFrame
    filtered_vote_matrix: DataFrame
    reducer: ReducerModel
    participants_df: DataFrame
    statements_df: DataFrame
    participant_ids_to_cluster: list[int]
    participant_projections: dict
    statement_projections: Optional[dict]
63+
64+
2365
@dataclass
2466
class PolisClusteringResult:
2567
"""
@@ -52,6 +94,269 @@ class PolisClusteringResult:
5294
consensus: ConsensusResult
5395
repness: PolisRepness
5496

97+
98+
# Return type of the typed Polis entry points: either a success wrapper around
# PolisClusteringResult or an insufficient-data outcome.
TypedPolisClusteringResult = AnalysisSuccess[PolisClusteringResult] | AnalysisInsufficientData
101+
102+
103+
def get_insufficient_data_reason(
    *,
    votes: list[dict],
    mod_out_statement_ids: list[int] | None = None,
    min_user_vote_threshold: int = 7,
    keep_participant_ids: list[int] | None = None,
    force_group_count: Optional[int] = None,
) -> InsufficientDataReason | None:
    """
    Screen pipeline inputs and report why clustering cannot proceed, if so.

    Checks run in this order, and the first one that fails is returned:
    empty vote list, fewer than two clusterable participants, fewer
    clusterable participants than ``force_group_count``, and fewer than two
    distinct vote rows among the clusterable participants.

    Args:
        votes: Vote records used to build the raw vote matrix.
        mod_out_statement_ids: Statement IDs passed to the matrix filter.
            ``None`` is treated as an empty list.
        min_user_vote_threshold: Vote threshold passed to
            ``get_clusterable_participant_ids``.
        keep_participant_ids: Participant IDs to include in clustering when
            present in the filtered matrix index, in addition to those
            selected by the threshold. ``None`` is treated as an empty list.
        force_group_count: Requested group count, or ``None`` to skip the
            group-count check.

    Returns:
        The matching ``InsufficientDataReason``, or ``None`` when every check
        passes.
    """
    if len(votes) == 0:
        return InsufficientDataReason.EMPTY_VOTE_MATRIX

    moderated_ids = mod_out_statement_ids or []
    forced_keep_ids = keep_participant_ids or []

    raw_vote_matrix = generate_raw_matrix(votes=votes)
    filtered_vote_matrix = simple_filter_matrix(
        vote_matrix=raw_vote_matrix,
        mod_out_statement_ids=moderated_ids,
    )

    clusterable_ids = get_clusterable_participant_ids(
        raw_vote_matrix,
        vote_threshold=min_user_vote_threshold,
    )
    if forced_keep_ids:
        # Only keep-IDs that actually appear in the matrix can be clustered.
        present_keep_ids = filtered_vote_matrix.index.intersection(
            forced_keep_ids,
        ).to_list()
        clusterable_ids = sorted(set(clusterable_ids + present_keep_ids))

    clusterable_count = len(clusterable_ids)
    if clusterable_count < 2:
        return InsufficientDataReason.NOT_ENOUGH_CLUSTERABLE_PARTICIPANTS
    if force_group_count is not None and clusterable_count < force_group_count:
        return InsufficientDataReason.NOT_ENOUGH_SAMPLES_FOR_GROUP_COUNT

    # Missing votes are filled with 0 before rows are compared for uniqueness.
    vote_rows = filtered_vote_matrix.loc[clusterable_ids, :].fillna(0).values
    distinct_rows = np.unique(vote_rows, axis=0)
    if len(distinct_rows) < 2:
        return InsufficientDataReason.NOT_ENOUGH_UNIQUE_POINTS

    return None
144+
145+
146+
def prepare_pca_projection(
    *,
    votes: list[dict],
    reducer_kwargs: dict | None = None,
    mod_out_statement_ids: list[int] | None = None,
    meta_statement_ids: list[int] | None = None,
    min_user_vote_threshold: int = 7,
    keep_participant_ids: list[int] | None = None,
    random_state: Optional[int] = None,
) -> PcaProjectionResult:
    """
    Build the vote matrices and the PCA projection once, for reuse.

    The returned bundle can be used to evaluate several group-count
    candidates without repeating matrix generation, filtering or reduction.

    Args:
        votes: Vote records used to build the raw vote matrix.
        reducer_kwargs: Extra keyword arguments forwarded to ``run_reducer``.
            ``None`` is treated as an empty dict.
        mod_out_statement_ids: Statement IDs passed to the matrix filter and
            flagged in the ``to_zero`` column. ``None`` means none.
        meta_statement_ids: Statement IDs flagged in the ``is_meta`` column.
            ``None`` means none.
        min_user_vote_threshold: Vote threshold passed to
            ``get_clusterable_participant_ids``.
        keep_participant_ids: Participant IDs to include in clustering when
            present in the projection, in addition to those selected by the
            threshold. ``None`` means none.
        random_state: Seed forwarded to ``run_reducer``.

    Returns:
        A ``PcaProjectionResult`` holding the matrices, reducer model,
        participant/statement dataframes and projection mappings.
    """
    extra_reducer_kwargs = reducer_kwargs or {}
    moderated_ids = mod_out_statement_ids or []
    meta_ids = meta_statement_ids or []
    forced_keep_ids = keep_participant_ids or []

    raw_vote_matrix = generate_raw_matrix(votes=votes)
    filtered_vote_matrix = simple_filter_matrix(
        vote_matrix=raw_vote_matrix,
        mod_out_statement_ids=moderated_ids,
    )

    X_participants, X_statements, reducer_model = run_reducer(
        vote_matrix=filtered_vote_matrix.values,
        reducer="pca",
        random_state=random_state,
        **extra_reducer_kwargs,
    )

    # Participant-side dataframe: projection coordinates plus clustering flag.
    participants_df = pd.DataFrame(
        X_participants,
        columns=pd.Index(["x", "y"]),
        index=filtered_vote_matrix.index,
    )
    clusterable_ids = get_clusterable_participant_ids(
        raw_vote_matrix,
        vote_threshold=min_user_vote_threshold,
    )
    if forced_keep_ids:
        # Only keep-IDs that actually appear in the projection can be added.
        present_keep_ids = participants_df.index.intersection(
            forced_keep_ids,
        ).to_list()
        clusterable_ids = sorted(set(clusterable_ids + present_keep_ids))
    participants_df["to_cluster"] = participants_df.index.isin(clusterable_ids)

    # Statement-side dataframe: projection coordinates plus moderation flags.
    statements_df = pd.DataFrame(
        X_statements,
        columns=pd.Index(["x", "y"]),
        index=filtered_vote_matrix.columns,
    )
    statements_df["to_zero"] = statements_df.index.isin(moderated_ids)
    statements_df["is_meta"] = statements_df.index.isin(meta_ids)

    if isinstance(reducer_model, PCA):

        def component_or_none(components, position):
            # A PCA fitted with fewer components simply has no such row.
            try:
                return components[position]
            except IndexError:
                return None

        statements_df["mean"] = reducer_model.mean_
        for position, column in enumerate(("pc1", "pc2", "pc3")):
            statements_df[column] = component_or_none(
                reducer_model.components_, position
            )
        statements_df = populate_priority_calculations_into_statements_df(
            statements_df=statements_df,
            vote_matrix=raw_vote_matrix.loc[clusterable_ids, :],
        )

    participant_projections = dict(zip(filtered_vote_matrix.index, X_participants))
    if X_statements is not None:
        statement_projections = dict(zip(filtered_vote_matrix.columns, X_statements))
    else:
        statement_projections = None

    return PcaProjectionResult(
        raw_vote_matrix=raw_vote_matrix,
        filtered_vote_matrix=filtered_vote_matrix,
        reducer=reducer_model,
        participants_df=participants_df,
        statements_df=statements_df,
        participant_ids_to_cluster=clusterable_ids,
        participant_projections=participant_projections,
        statement_projections=statement_projections,
    )
233+
234+
235+
def _build_clustering_result_from_projection(
    *,
    projection: PcaProjectionResult,
    clusterer_model: ClustererModel | None,
    mod_out_statement_ids: list[int],
    pick_max: int,
    confidence: float,
    consensus_mode: Literal["standard", "legacy"],
) -> PolisClusteringResult:
    """
    Assemble a full clustering result from a prepared projection.

    The dataframes held by ``projection`` are copied before being extended,
    so the same projection can be reused for further candidates.

    Args:
        projection: Prepared matrices, dataframes and projection mappings.
        clusterer_model: Fitted clusterer whose ``labels_`` are used as
            cluster assignments, or ``None`` when no clusterer is available.
        mod_out_statement_ids: Statement IDs forwarded to the consensus and
            representativeness selection.
        pick_max: Forwarded to the consensus and representativeness selection.
        confidence: Forwarded to the consensus and representativeness
            selection.
        consensus_mode: Forwarded to
            ``calculate_comment_statistics_dataframes``.

    Returns:
        A ``PolisClusteringResult`` combining the projection data with the
        per-group statistics, consensus and representativeness selections.
    """
    # Compare against None explicitly: relying on truthiness would treat any
    # model object that defines __len__/__bool__ as falsy as if it were absent.
    cluster_labels = clusterer_model.labels_ if clusterer_model is not None else None
    participants_df = projection.participants_df.copy()
    # Labels only cover the clustered participants; assignment aligns on the
    # index, leaving the remaining participants without a cluster_id.
    label_series = pd.Series(
        cluster_labels,
        index=projection.participant_ids_to_cluster,
        dtype="Int64",
    )
    participants_df["cluster_id"] = label_series

    grouped_stats_df, gac_df = calculate_comment_statistics_dataframes(
        vote_matrix=projection.raw_vote_matrix.loc[projection.participant_ids_to_cluster, :],
        cluster_labels=cluster_labels,
        consensus_mode=consensus_mode,
    )
    statements_df = pd.concat([projection.statements_df.copy(), gac_df], axis=1)
    group_aware_consensus = {
        "agree": statements_df["group-aware-consensus-agree"].to_dict(),
        "disagree": statements_df["group-aware-consensus-disagree"].to_dict(),
    }
    consensus = select_consensus_statements(
        vote_matrix=projection.raw_vote_matrix,
        mod_out_statement_ids=mod_out_statement_ids,
        pick_max=pick_max,
        confidence=confidence,
        prob_threshold=0.5,
    )
    repness = select_representative_statements(
        grouped_stats_df=grouped_stats_df,
        mod_out_statement_ids=mod_out_statement_ids,
        pick_max=pick_max,
        confidence=confidence,
    )

    return PolisClusteringResult(
        participant_projections=projection.participant_projections,
        statement_projections=projection.statement_projections,
        group_aware_consensus=group_aware_consensus,
        consensus=consensus,
        repness=repness,
        raw_vote_matrix=projection.raw_vote_matrix,
        filtered_vote_matrix=projection.filtered_vote_matrix,
        reducer=projection.reducer,
        clusterer=clusterer_model,
        group_comment_stats=grouped_stats_df,
        statements_df=statements_df,
        participants_df=participants_df,
    )
291+
292+
293+
def run_kmeans_on_pca_projection(
    *,
    projection: PcaProjectionResult,
    force_group_count: int,
    init_centers: Optional[list[list[float]]] = None,
    random_state: Optional[int] = None,
    mod_out_statement_ids: list[int] | None = None,
    pick_max: int = 5,
    confidence: float = 0.9,
    consensus_mode: Literal["standard", "legacy"] = "standard",
) -> PolisClusteringResult:
    """
    Evaluate one forced-k k-means candidate on a prepared PCA projection.

    Args:
        projection: Projection bundle produced beforehand; it is not rebuilt.
        force_group_count: Group count handed to the clusterer, both as the
            forced count and as the maximum count.
        init_centers: Optional initial centers forwarded to the clusterer.
        random_state: Seed forwarded to the clusterer.
        mod_out_statement_ids: Statement IDs forwarded to result assembly.
            ``None`` is treated as an empty list.
        pick_max: Forwarded to result assembly.
        confidence: Forwarded to result assembly.
        consensus_mode: Forwarded to result assembly.

    Returns:
        The ``PolisClusteringResult`` assembled for this candidate.
    """
    # Only the x/y coordinates of the participants selected for clustering.
    clusterable_points = projection.participants_df.loc[
        projection.participant_ids_to_cluster,
        ["x", "y"],
    ].values

    fitted_clusterer = run_clusterer(
        clusterer="kmeans",
        X_participants_clusterable=clusterable_points,
        max_group_count=force_group_count,
        force_group_count=force_group_count,
        init_centers=init_centers,
        random_state=random_state,
    )

    moderated_ids = mod_out_statement_ids or []
    return _build_clustering_result_from_projection(
        projection=projection,
        clusterer_model=fitted_clusterer,
        mod_out_statement_ids=moderated_ids,
        pick_max=pick_max,
        confidence=confidence,
        consensus_mode=consensus_mode,
    )
324+
325+
326+
def calculate_projection_silhouette_score(
    *,
    projection: PcaProjectionResult,
    clusterer_model: ClustererModel,
) -> float | None:
    """
    Score a clustering candidate against a prepared PCA projection.

    Args:
        projection: Projection bundle supplying the x/y coordinates of the
            participants selected for clustering.
        clusterer_model: Clusterer whose ``labels_`` are scored.

    Returns:
        Whatever ``calculate_kmeans_silhouette_score`` returns for those
        points and labels (annotated as ``float | None``).
    """
    clusterable_points = projection.participants_df.loc[
        projection.participant_ids_to_cluster,
        ["x", "y"],
    ].values
    candidate_labels = clusterer_model.labels_
    return calculate_kmeans_silhouette_score(
        X_to_cluster=clusterable_points,
        labels=candidate_labels,
    )
339+
340+
341+
def run_pipeline_typed(**kwargs) -> TypedPolisClusteringResult:
    """
    Run the pipeline and report the outcome as a typed result.

    The inputs are screened with ``get_insufficient_data_reason`` first.
    ``run_pipeline`` is only invoked when that screening returns no reason.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to ``run_pipeline``.
            ``votes`` is required. The screening additionally reads
            ``mod_out_statement_ids`` (fallback ``[]``),
            ``min_user_vote_threshold`` (fallback ``7``),
            ``keep_participant_ids`` (fallback ``[]``) and
            ``force_group_count`` (fallback ``None``).

    Returns:
        ``AnalysisSuccess`` wrapping the ``run_pipeline`` result, or
        ``AnalysisInsufficientData`` carrying the screening reason.

    Raises:
        KeyError: If ``votes`` is not supplied.
    """
    blocking_reason = get_insufficient_data_reason(
        votes=kwargs["votes"],
        mod_out_statement_ids=kwargs.get("mod_out_statement_ids", []),
        min_user_vote_threshold=kwargs.get("min_user_vote_threshold", 7),
        keep_participant_ids=kwargs.get("keep_participant_ids", []),
        force_group_count=kwargs.get("force_group_count"),
    )

    if blocking_reason is None:
        # Screening passed, so it is safe to run the full pipeline.
        return AnalysisSuccess(
            outcome=AnalysisOutcome.SUCCESS,
            result=run_pipeline(**kwargs),
        )

    return AnalysisInsufficientData(
        outcome=AnalysisOutcome.INSUFFICIENT_DATA,
        reason=blocking_reason,
    )
358+
359+
55360
def run_pipeline(
56361
votes: list[dict],
57362
reducer: ReducerType = "pca",

reddwarf/implementations/polis.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ def run_clustering(**kwargs) -> base.PolisClusteringResult:
77
return run_pipeline(**kwargs)
88

99

10+
def run_clustering_typed(**kwargs) -> base.TypedPolisClusteringResult:
    """Forward every keyword argument to ``run_pipeline_typed`` and return its result."""
    typed_result = run_pipeline_typed(**kwargs)
    return typed_result
12+
13+
1014
def run_pipeline(**kwargs) -> base.PolisClusteringResult:
1115
kwargs = {
1216
"reducer": "pca",
@@ -15,3 +19,13 @@ def run_pipeline(**kwargs) -> base.PolisClusteringResult:
1519
**kwargs,
1620
}
1721
return base.run_pipeline(**kwargs)
22+
23+
24+
def run_pipeline_typed(**kwargs) -> base.TypedPolisClusteringResult:
    """
    Run the typed base pipeline with the Polis defaults applied.

    Args:
        **kwargs: Keyword arguments for ``base.run_pipeline_typed``. Values
            supplied by the caller take precedence over the defaults
            ``reducer="pca"``, ``clusterer="kmeans"`` and
            ``consensus_mode="legacy"``.

    Returns:
        The typed result returned by ``base.run_pipeline_typed``.
    """
    polis_defaults = {
        "reducer": "pca",
        "clusterer": "kmeans",
        "consensus_mode": "legacy",
    }
    # Caller-supplied values are merged last so they override the defaults.
    merged_kwargs = {**polis_defaults, **kwargs}
    return base.run_pipeline_typed(**merged_kwargs)

0 commit comments

Comments
 (0)