@@ -44,6 +44,11 @@ class AgoraClusteringResult:
4444
4545
4646TypedAgoraClusteringResult = base .AnalysisSuccess [AgoraClusteringResult ] | base .AnalysisInsufficientData
47+ TypedAgoraKMeansCandidatesResult = (
48+ base .AnalysisSuccess [base .KMeansCandidatesResult [AgoraClusteringResult ]]
49+ | base .AnalysisInsufficientData
50+ )
51+ TypedAgoraPipelineResult = TypedAgoraClusteringResult | TypedAgoraKMeansCandidatesResult
4752
4853
4954def compute_effective_agreement_gac (
@@ -80,34 +85,21 @@ def compute_effective_agreement_gac(
8085 return {"agree" : agree , "disagree" : disagree }
8186
8287
83- def run_pipeline (
84- fdr_rate : float = 0.10 ,
85- ** kwargs ,
88+ def _build_agora_result (
89+ * ,
90+ base_result : base .PolisClusteringResult ,
91+ mod_out_statement_ids : list [int ],
92+ fdr_rate : float ,
8693) -> AgoraClusteringResult :
87- """
88- Agora clustering pipeline. Runs the base pipeline and adds ranked
89- representative/consensus statements with Benjamini-Hochberg selection.
90-
91- Accepts all the same arguments as base.run_pipeline(), plus:
92-
93- Args:
94- fdr_rate (float): False discovery rate for Benjamini-Hochberg selection.
95- **kwargs: All arguments forwarded to base.run_pipeline().
96-
97- Returns:
98- AgoraClusteringResult: Clustering results with ranked statement outputs.
99- """
100- base_result = base .run_pipeline (** kwargs )
101-
10294 ranked_repness = rank_representative_statements (
10395 grouped_stats_df = base_result .group_comment_stats ,
104- mod_out_statement_ids = kwargs . get ( " mod_out_statement_ids" , []) ,
96+ mod_out_statement_ids = mod_out_statement_ids ,
10597 fdr_rate = fdr_rate ,
10698 )
10799
108100 ranked_consensus = rank_consensus_statements (
109101 vote_matrix = base_result .raw_vote_matrix ,
110- mod_out_statement_ids = kwargs . get ( " mod_out_statement_ids" , []) ,
102+ mod_out_statement_ids = mod_out_statement_ids ,
111103 fdr_rate = fdr_rate ,
112104 )
113105
@@ -137,20 +129,63 @@ def run_pipeline(
137129 )
138130
139131
140- def run_pipeline_typed (** kwargs ) -> TypedAgoraClusteringResult :
141- reason = base .get_insufficient_data_reason (
142- votes = kwargs ["votes" ],
143- mod_out_statement_ids = kwargs .get ("mod_out_statement_ids" , []),
144- min_user_vote_threshold = kwargs .get ("min_user_vote_threshold" , 7 ),
145- keep_participant_ids = kwargs .get ("keep_participant_ids" , []),
146- force_group_count = kwargs .get ("force_group_count" ),
147- )
148- if reason is not None :
149- return base .AnalysisInsufficientData (
150- outcome = base .AnalysisOutcome .INSUFFICIENT_DATA ,
151- reason = reason ,
132+ def run_pipeline (
133+ fdr_rate : float = 0.10 ,
134+ ** kwargs ,
135+ ) -> TypedAgoraPipelineResult :
136+ """
137+ Agora clustering pipeline. Runs the base pipeline and adds ranked
138+ representative/consensus statements with Benjamini-Hochberg selection.
139+
140+ Accepts all the same arguments as base.run_pipeline(), plus:
141+
142+ Args:
143+ fdr_rate (float): False discovery rate for Benjamini-Hochberg selection.
144+ **kwargs: All arguments forwarded to base.run_pipeline().
145+
146+ Returns:
147+ AnalysisSuccess or AnalysisInsufficientData. On success, `result` contains either
148+ AgoraClusteringResult or KMeansCandidatesResult when candidate_group_counts is set.
149+ """
150+ base_pipeline_result = base .run_pipeline (** kwargs )
151+ if base_pipeline_result .outcome == base .AnalysisOutcome .INSUFFICIENT_DATA :
152+ return base_pipeline_result
153+
154+ mod_out_statement_ids = kwargs .get ("mod_out_statement_ids" , [])
155+ if isinstance (base_pipeline_result .result , base .KMeansCandidatesResult ):
156+ candidates : list [
157+ base .KMeansCandidateSuccess [AgoraClusteringResult ]
158+ | base .KMeansCandidateInsufficientData
159+ ] = []
160+ for candidate in base_pipeline_result .result .candidates :
161+ if candidate .outcome == base .AnalysisOutcome .INSUFFICIENT_DATA :
162+ candidates .append (candidate )
163+ continue
164+ candidates .append (
165+ base .KMeansCandidateSuccess (
166+ group_count = candidate .group_count ,
167+ outcome = base .AnalysisOutcome .SUCCESS ,
168+ silhouette_score = candidate .silhouette_score ,
169+ result = _build_agora_result (
170+ base_result = candidate .result ,
171+ mod_out_statement_ids = mod_out_statement_ids ,
172+ fdr_rate = fdr_rate ,
173+ ),
174+ )
175+ )
176+ return base .AnalysisSuccess (
177+ outcome = base .AnalysisOutcome .SUCCESS ,
178+ result = base .KMeansCandidatesResult (
179+ projection = base_pipeline_result .result .projection ,
180+ candidates = candidates ,
181+ ),
152182 )
183+
153184 return base .AnalysisSuccess (
154185 outcome = base .AnalysisOutcome .SUCCESS ,
155- result = run_pipeline (** kwargs ),
186+ result = _build_agora_result (
187+ base_result = base_pipeline_result .result ,
188+ mod_out_statement_ids = mod_out_statement_ids ,
189+ fdr_rate = fdr_rate ,
190+ ),
156191 )
0 commit comments