11from typing import Optional
2- from numpy .typing import NDArray
32from pandas import DataFrame
43from sklearn .decomposition import PCA
54from reddwarf .sklearn .cluster import PolisKMeans
65from reddwarf .utils .matrix import generate_raw_matrix , simple_filter_matrix , get_clusterable_participant_ids
76from reddwarf .utils .pca import run_pca
87from reddwarf .utils .clustering import find_optimal_k
98from dataclasses import dataclass
9+ import pandas as pd
1010
11- from reddwarf .utils .stats import calculate_comment_statistics_dataframes
11+ from reddwarf .utils .stats import calculate_comment_statistics_dataframes , populate_priority_calculations_into_statements_df
1212
1313@dataclass
1414class PolisClusteringResult :
@@ -22,6 +22,8 @@ class PolisClusteringResult:
2222 kmeans (PolisKMeans): Scikit-Learn KMeans object for selected group count, including `labels_` and `cluster_centers_`. See `PolisKMeans`.
2323 group_aware_consensus (DataFrame): Group-aware consensus scores for each statement.
2424 group_comment_stats (DataFrame): A multi-index dataframe of statistics for each statement, indexed by group ID and statement.
25+ statements_df (DataFrame): A dataframe with all intermediary and final statement data/calculations/metadata.
26+ participants_df (DataFrame): A dataframe with all intermediary and final participant data/calculations/metadata.
2527 """
2628 raw_vote_matrix : DataFrame
2729 filtered_vote_matrix : DataFrame
@@ -31,10 +33,13 @@ class PolisClusteringResult:
3133 kmeans : PolisKMeans | None
3234 group_aware_consensus : DataFrame
3335 group_comment_stats : DataFrame
36+ statements_df : DataFrame
37+ participants_df : DataFrame
3438
3539def run_clustering (
3640 votes : list [dict ],
3741 mod_out_statement_ids : list [int ] = [],
42+ meta_statement_ids : list [int ] = [],
3843 min_user_vote_threshold : int = 7 ,
3944 keep_participant_ids : list [int ] = [],
4045 init_centers : Optional [list [list [float ]]] = None ,
@@ -53,6 +58,7 @@ def run_clustering(
5358 Args:
5459 votes (list[dict]): Raw list of vote dicts, with keys for "participant_id", "statement_id", "vote" and "modified"
5560 mod_out_statement_ids (list[int]): List of statement IDs to moderate/zero out
61+ meta_statement_ids (list[int]): List of meta statement IDs
5662 min_user_vote_threshold (int): Minimum number of votes a participant must make to be included in clustering
5763 keep_participant_ids (list[int]): List of participant IDs to keep in clustering algorithm, regardless of normal filters.
5864 max_group_count (): Max number of groups (k-values) to test using k-means and silhouette scores
@@ -70,42 +76,69 @@ def run_clustering(
7076 mod_out_statement_ids = mod_out_statement_ids ,
7177 )
7278
73- projected_participants , projected_statements , pca = run_pca (vote_matrix = filtered_vote_matrix )
79+ # Run PCA and generate participant/statement projections.
80+ # DataFrames each have "x" and "y" columns.
81+ participants_df , statements_df , pca = run_pca (vote_matrix = filtered_vote_matrix )
7482
75- participant_ids_clusterable = get_clusterable_participant_ids (raw_vote_matrix , vote_threshold = min_user_vote_threshold )
83+ participant_ids_to_cluster = get_clusterable_participant_ids (raw_vote_matrix , vote_threshold = min_user_vote_threshold )
7684 if keep_participant_ids :
77- participant_ids_clusterable = list (set (participant_ids_clusterable + keep_participant_ids ))
85+ # TODO: Make this an intersection, in case there are members of
86+ # keep_participant_ids list that aren't represented in vote_matrix.
87+ participant_ids_to_cluster = sorted (list (set (participant_ids_to_cluster + keep_participant_ids )))
7888
7989 if force_group_count :
8090 k_bounds = [force_group_count , force_group_count ]
8191 else :
8292 k_bounds = [2 , max_group_count ]
8393
84- projected_participants_clusterable = projected_participants .loc [participant_ids_clusterable , :]
8594 _ , _ , kmeans = find_optimal_k (
86- projected_data = projected_participants_clusterable ,
95+ projected_data = participants_df . loc [ participant_ids_to_cluster , :] ,
8796 k_bounds = k_bounds ,
8897 # Force polis strategy of initiating cluster centers. See: PolisKMeans.
8998 init = "polis" ,
9099 init_centers = init_centers ,
91100 random_state = random_state ,
92101 )
93- projected_participants_clusterable = projected_participants_clusterable .assign (
94- cluster_id = kmeans .labels_ if kmeans else None ,
102+ label_series = pd .Series (
103+ kmeans .labels_ if kmeans else None ,
104+ index = participant_ids_to_cluster ,
105+ dtype = "Int64" , # Allows nullable/NaN values.
95106 )
107+ participants_df ["to_cluster" ] = participants_df .index .isin (participant_ids_to_cluster )
108+ participants_df ["cluster_id" ] = label_series
96109
97110 grouped_stats_df , gac_df = calculate_comment_statistics_dataframes (
98- vote_matrix = raw_vote_matrix .loc [participant_ids_clusterable , :],
111+ vote_matrix = raw_vote_matrix .loc [participant_ids_to_cluster , :],
99112 cluster_labels = kmeans .labels_ ,
100113 )
101114
115+ def get_with_default (lst , idx , default = None ):
116+ try :
117+ return lst [idx ]
118+ except IndexError :
119+ return default
120+
121+ statements_df ["to_zero" ] = statements_df .index .isin (mod_out_statement_ids )
122+ statements_df ["is_meta" ] = statements_df .index .isin (meta_statement_ids )
123+ statements_df ["mean" ] = pca .mean_
124+ statements_df ["pc1" ] = get_with_default (pca .components_ , 0 )
125+ statements_df ["pc2" ] = get_with_default (pca .components_ , 1 )
126+ statements_df ["pc3" ] = get_with_default (pca .components_ , 2 )
127+ statements_df = pd .concat ([statements_df , gac_df ], axis = 1 )
128+ statements_df = populate_priority_calculations_into_statements_df (
129+ statements_df = statements_df ,
130+ vote_matrix = raw_vote_matrix .loc [participant_ids_to_cluster , :],
131+ )
132+
102133 return PolisClusteringResult (
103134 raw_vote_matrix = raw_vote_matrix ,
104135 filtered_vote_matrix = filtered_vote_matrix ,
105136 pca = pca ,
106- projected_participants = projected_participants_clusterable ,
107- projected_statements = projected_statements ,
137+ projected_participants = participants_df . loc [ participant_ids_to_cluster , [ "x" , "y" , "cluster_id" ]], # deprecate?
138+ projected_statements = statements_df . loc [:, [ "x" , "y" ]], # deprecate?
108139 kmeans = kmeans ,
109- group_aware_consensus = gac_df ,
140+ group_aware_consensus = gac_df , # deprecate?
110141 group_comment_stats = grouped_stats_df ,
142+ statements_df = statements_df ,
143+ participants_df = participants_df ,
111144 )
0 commit comments