Skip to content

Commit f8b2937

Browse files
committed
Added previous_result arg to polis.run_clustering() to seed args from prior run.
1 parent a2ab643 commit f8b2937

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
- Modify `SparsityAwareScaler` to be able to use captured output from SparsityAware Capture.
4444
- Remove ported Polis PCA functions that are no longer used.
4545
- Remove old `impute_missing_votes()` function that's no longer used.
46+
- Add arg to polis implementation, to seed args with previous result object, locking group number.
4647

4748
### Chores
4849
- Moved agora implementation from `reddwarf.agora` to `reddwarf.implementations.agora` (deprecation warning).

reddwarf/implementations/polis.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def run_clustering(
4646
max_group_count: int = 5,
4747
force_group_count: Optional[int] = None,
4848
random_state: Optional[int] = None,
49+
previous_result: Optional[PolisClusteringResult] = None,
4950
) -> PolisClusteringResult:
5051
"""
5152
An essentially feature-complete implementation of the Polis clustering algorithm.
@@ -65,10 +66,20 @@ def run_clustering(
6566
init_centers (list[list[float]]): Initial guesses of [x,y] coordinates for k-means (Length of list must match max_group_count)
6667
force_group_count (int): Instead of using silhouette scores, force a specific number of groups (k value)
6768
random_state (int): If set, will force determinism during k-means clustering
69+
previous_result (PolisClusteringResult): The result of a previous run of this function, to seed args.
6870
6971
Returns:
7072
PolisClusteringResult: A dataclass containing clustering results, including intermediate calculations.
7173
"""
74+
if previous_result:
75+
prev_kmeans = previous_result.kmeans
76+
if prev_kmeans:
77+
init_centers = prev_kmeans.cluster_centers_
78+
# TODO: Implement some variant of k-smoothing to stabilize this instead of locking it.
79+
prev_k = len(prev_kmeans.cluster_centers_)
80+
force_group_count = prev_k
81+
keep_participant_ids = list(previous_result.participants_df.query('to_cluster').index)
82+
7283
raw_vote_matrix = generate_raw_matrix(votes=votes)
7384

7485
filtered_vote_matrix = simple_filter_matrix(

0 commit comments

Comments
 (0)