|
| 1 | +from collections.abc import Mapping, Sequence |
| 2 | + |
1 | 3 | from reddwarf.implementations import base |
2 | 4 |
|
3 | 5 |
|
4 | 6 | # This is to not break things. |
5 | 7 | # TODO: Adde deprecation warning. |
6 | | -def run_clustering(**kwargs) -> base.TypedPolisPipelineResult: |
7 | | - return run_pipeline(**kwargs) |
| 8 | +def run_clustering( |
| 9 | + votes: Sequence[Mapping[str, object]], |
| 10 | + reducer: base.ReducerType = "pca", |
| 11 | + reducer_kwargs: dict[str, object] | None = None, |
| 12 | + clusterer: base.ClustererType = "kmeans", |
| 13 | + clusterer_kwargs: dict[str, object] | None = None, |
| 14 | + mod_out_statement_ids: list[int] | None = None, |
| 15 | + meta_statement_ids: list[int] | None = None, |
| 16 | + min_user_vote_threshold: int = 7, |
| 17 | + keep_participant_ids: list[int] | None = None, |
| 18 | + init_centers: list[list[float]] | None = None, |
| 19 | + max_group_count: int = 5, |
| 20 | + force_group_count: int | None = None, |
| 21 | + random_state: int | None = None, |
| 22 | + pick_max: int = 5, |
| 23 | + confidence: float = 0.9, |
| 24 | + consensus_mode: base.Literal["standard", "legacy"] = "legacy", |
| 25 | + candidate_group_counts: base.CandidateGroupCounts = None, |
| 26 | +) -> base.TypedPolisPipelineResult: |
| 27 | + return run_pipeline( |
| 28 | + votes=votes, |
| 29 | + reducer=reducer, |
| 30 | + reducer_kwargs=reducer_kwargs, |
| 31 | + clusterer=clusterer, |
| 32 | + clusterer_kwargs=clusterer_kwargs, |
| 33 | + mod_out_statement_ids=mod_out_statement_ids, |
| 34 | + meta_statement_ids=meta_statement_ids, |
| 35 | + min_user_vote_threshold=min_user_vote_threshold, |
| 36 | + keep_participant_ids=keep_participant_ids, |
| 37 | + init_centers=init_centers, |
| 38 | + max_group_count=max_group_count, |
| 39 | + force_group_count=force_group_count, |
| 40 | + random_state=random_state, |
| 41 | + pick_max=pick_max, |
| 42 | + confidence=confidence, |
| 43 | + consensus_mode=consensus_mode, |
| 44 | + candidate_group_counts=candidate_group_counts, |
| 45 | + ) |
8 | 46 |
|
9 | 47 |
|
10 | | -def run_pipeline(**kwargs) -> base.TypedPolisPipelineResult: |
11 | | - kwargs = { |
12 | | - "reducer": "pca", |
13 | | - "clusterer": "kmeans", |
14 | | - "consensus_mode": "legacy", |
15 | | - **kwargs, |
16 | | - } |
17 | | - return base.run_pipeline(**kwargs) |
| 48 | +def run_pipeline( |
| 49 | + votes: Sequence[Mapping[str, object]], |
| 50 | + reducer: base.ReducerType = "pca", |
| 51 | + reducer_kwargs: dict[str, object] | None = None, |
| 52 | + clusterer: base.ClustererType = "kmeans", |
| 53 | + clusterer_kwargs: dict[str, object] | None = None, |
| 54 | + mod_out_statement_ids: list[int] | None = None, |
| 55 | + meta_statement_ids: list[int] | None = None, |
| 56 | + min_user_vote_threshold: int = 7, |
| 57 | + keep_participant_ids: list[int] | None = None, |
| 58 | + init_centers: list[list[float]] | None = None, |
| 59 | + max_group_count: int = 5, |
| 60 | + force_group_count: int | None = None, |
| 61 | + random_state: int | None = None, |
| 62 | + pick_max: int = 5, |
| 63 | + confidence: float = 0.9, |
| 64 | + consensus_mode: base.Literal["standard", "legacy"] = "legacy", |
| 65 | + candidate_group_counts: base.CandidateGroupCounts = None, |
| 66 | +) -> base.TypedPolisPipelineResult: |
| 67 | + return base.run_pipeline( |
| 68 | + votes=list(votes), |
| 69 | + reducer=reducer, |
| 70 | + reducer_kwargs=reducer_kwargs or {}, |
| 71 | + clusterer=clusterer, |
| 72 | + clusterer_kwargs=clusterer_kwargs or {}, |
| 73 | + mod_out_statement_ids=mod_out_statement_ids or [], |
| 74 | + meta_statement_ids=meta_statement_ids or [], |
| 75 | + min_user_vote_threshold=min_user_vote_threshold, |
| 76 | + keep_participant_ids=keep_participant_ids or [], |
| 77 | + init_centers=init_centers, |
| 78 | + max_group_count=max_group_count, |
| 79 | + force_group_count=force_group_count, |
| 80 | + random_state=random_state, |
| 81 | + pick_max=pick_max, |
| 82 | + confidence=confidence, |
| 83 | + consensus_mode=consensus_mode, |
| 84 | + candidate_group_counts=candidate_group_counts, |
| 85 | + ) |
0 commit comments