Skip to content

Commit ad01a7a

Browse files
committed
feat(polis): type pipeline outcomes
Make run_pipeline return explicit success or insufficient-data outcomes directly, replacing the separate run_pipeline_typed entry point. Add candidate_group_counts for opt-in multi-k k-means results while preserving force_group_count for single forced-k runs.
1 parent 71b2a63 commit ad01a7a

9 files changed

Lines changed: 437 additions & 86 deletions

File tree

docs/api_reference.md

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,10 +149,34 @@ use in Scikit-Learn workflows, pipelines, and APIs.
149149

150150
## Types
151151

152+
### ::: reddwarf.implementations.base.AnalysisSuccess
153+
options:
154+
show_root_heading: true
155+
156+
### ::: reddwarf.implementations.base.AnalysisInsufficientData
157+
options:
158+
show_root_heading: true
159+
160+
### ::: reddwarf.implementations.base.InsufficientDataReason
161+
options:
162+
show_root_heading: true
163+
152164
### ::: reddwarf.implementations.base.PolisClusteringResult
153165
options:
154166
show_root_heading: true
155167

168+
### ::: reddwarf.implementations.base.KMeansCandidatesResult
169+
options:
170+
show_root_heading: true
171+
172+
### ::: reddwarf.implementations.base.KMeansCandidateSuccess
173+
options:
174+
show_root_heading: true
175+
176+
### ::: reddwarf.implementations.base.KMeansCandidateInsufficientData
177+
options:
178+
show_root_heading: true
179+
156180
### ::: reddwarf.types.agora.RankedRepnessStatement
157181
options:
158182
show_root_heading: true

reddwarf/implementations/agora.py

Lines changed: 68 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,11 @@ class AgoraClusteringResult:
4444

4545

4646
TypedAgoraClusteringResult = base.AnalysisSuccess[AgoraClusteringResult] | base.AnalysisInsufficientData
47+
TypedAgoraKMeansCandidatesResult = (
48+
base.AnalysisSuccess[base.KMeansCandidatesResult[AgoraClusteringResult]]
49+
| base.AnalysisInsufficientData
50+
)
51+
TypedAgoraPipelineResult = TypedAgoraClusteringResult | TypedAgoraKMeansCandidatesResult
4752

4853

4954
def compute_effective_agreement_gac(
@@ -80,34 +85,21 @@ def compute_effective_agreement_gac(
8085
return {"agree": agree, "disagree": disagree}
8186

8287

83-
def run_pipeline(
84-
fdr_rate: float = 0.10,
85-
**kwargs,
88+
def _build_agora_result(
89+
*,
90+
base_result: base.PolisClusteringResult,
91+
mod_out_statement_ids: list[int],
92+
fdr_rate: float,
8693
) -> AgoraClusteringResult:
87-
"""
88-
Agora clustering pipeline. Runs the base pipeline and adds ranked
89-
representative/consensus statements with Benjamini-Hochberg selection.
90-
91-
Accepts all the same arguments as base.run_pipeline(), plus:
92-
93-
Args:
94-
fdr_rate (float): False discovery rate for Benjamini-Hochberg selection.
95-
**kwargs: All arguments forwarded to base.run_pipeline().
96-
97-
Returns:
98-
AgoraClusteringResult: Clustering results with ranked statement outputs.
99-
"""
100-
base_result = base.run_pipeline(**kwargs)
101-
10294
ranked_repness = rank_representative_statements(
10395
grouped_stats_df=base_result.group_comment_stats,
104-
mod_out_statement_ids=kwargs.get("mod_out_statement_ids", []),
96+
mod_out_statement_ids=mod_out_statement_ids,
10597
fdr_rate=fdr_rate,
10698
)
10799

108100
ranked_consensus = rank_consensus_statements(
109101
vote_matrix=base_result.raw_vote_matrix,
110-
mod_out_statement_ids=kwargs.get("mod_out_statement_ids", []),
102+
mod_out_statement_ids=mod_out_statement_ids,
111103
fdr_rate=fdr_rate,
112104
)
113105

@@ -137,20 +129,63 @@ def run_pipeline(
137129
)
138130

139131

140-
def run_pipeline_typed(**kwargs) -> TypedAgoraClusteringResult:
141-
reason = base.get_insufficient_data_reason(
142-
votes=kwargs["votes"],
143-
mod_out_statement_ids=kwargs.get("mod_out_statement_ids", []),
144-
min_user_vote_threshold=kwargs.get("min_user_vote_threshold", 7),
145-
keep_participant_ids=kwargs.get("keep_participant_ids", []),
146-
force_group_count=kwargs.get("force_group_count"),
147-
)
148-
if reason is not None:
149-
return base.AnalysisInsufficientData(
150-
outcome=base.AnalysisOutcome.INSUFFICIENT_DATA,
151-
reason=reason,
132+
def run_pipeline(
133+
fdr_rate: float = 0.10,
134+
**kwargs,
135+
) -> TypedAgoraPipelineResult:
136+
"""
137+
Agora clustering pipeline. Runs the base pipeline and adds ranked
138+
representative/consensus statements with Benjamini-Hochberg selection.
139+
140+
Accepts all the same arguments as base.run_pipeline(), plus:
141+
142+
Args:
143+
fdr_rate (float): False discovery rate for Benjamini-Hochberg selection.
144+
**kwargs: All arguments forwarded to base.run_pipeline().
145+
146+
Returns:
147+
AnalysisSuccess or AnalysisInsufficientData. On success, `result` is an
148+
AgoraClusteringResult, or a KMeansCandidatesResult when candidate_group_counts is set.
149+
"""
150+
base_pipeline_result = base.run_pipeline(**kwargs)
151+
if base_pipeline_result.outcome == base.AnalysisOutcome.INSUFFICIENT_DATA:
152+
return base_pipeline_result
153+
154+
mod_out_statement_ids = kwargs.get("mod_out_statement_ids", [])
155+
if isinstance(base_pipeline_result.result, base.KMeansCandidatesResult):
156+
candidates: list[
157+
base.KMeansCandidateSuccess[AgoraClusteringResult]
158+
| base.KMeansCandidateInsufficientData
159+
] = []
160+
for candidate in base_pipeline_result.result.candidates:
161+
if candidate.outcome == base.AnalysisOutcome.INSUFFICIENT_DATA:
162+
candidates.append(candidate)
163+
continue
164+
candidates.append(
165+
base.KMeansCandidateSuccess(
166+
group_count=candidate.group_count,
167+
outcome=base.AnalysisOutcome.SUCCESS,
168+
silhouette_score=candidate.silhouette_score,
169+
result=_build_agora_result(
170+
base_result=candidate.result,
171+
mod_out_statement_ids=mod_out_statement_ids,
172+
fdr_rate=fdr_rate,
173+
),
174+
)
175+
)
176+
return base.AnalysisSuccess(
177+
outcome=base.AnalysisOutcome.SUCCESS,
178+
result=base.KMeansCandidatesResult(
179+
projection=base_pipeline_result.result.projection,
180+
candidates=candidates,
181+
),
152182
)
183+
153184
return base.AnalysisSuccess(
154185
outcome=base.AnalysisOutcome.SUCCESS,
155-
result=run_pipeline(**kwargs),
186+
result=_build_agora_result(
187+
base_result=base_pipeline_result.result,
188+
mod_out_statement_ids=mod_out_statement_ids,
189+
fdr_rate=fdr_rate,
190+
),
156191
)

0 commit comments

Comments
 (0)