Skip to content

Commit b051883

Browse files
committed
fix(polis): type wrapper parameters
Give the Polis wrapper explicit parameter types so downstream type checkers can use the typed package without falling back to unknown **kwargs.
1 parent 73d663c commit b051883

1 file changed

Lines changed: 78 additions & 10 deletions

File tree

reddwarf/implementations/polis.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,85 @@
1+
from collections.abc import Mapping, Sequence
2+
13
from reddwarf.implementations import base
24

35

46
# This is to not break things.
57
# TODO: Adde deprecation warning.
6-
def run_clustering(**kwargs) -> base.TypedPolisPipelineResult:
7-
return run_pipeline(**kwargs)
8+
def run_clustering(
9+
votes: Sequence[Mapping[str, object]],
10+
reducer: base.ReducerType = "pca",
11+
reducer_kwargs: dict[str, object] | None = None,
12+
clusterer: base.ClustererType = "kmeans",
13+
clusterer_kwargs: dict[str, object] | None = None,
14+
mod_out_statement_ids: list[int] | None = None,
15+
meta_statement_ids: list[int] | None = None,
16+
min_user_vote_threshold: int = 7,
17+
keep_participant_ids: list[int] | None = None,
18+
init_centers: list[list[float]] | None = None,
19+
max_group_count: int = 5,
20+
force_group_count: int | None = None,
21+
random_state: int | None = None,
22+
pick_max: int = 5,
23+
confidence: float = 0.9,
24+
consensus_mode: base.Literal["standard", "legacy"] = "legacy",
25+
candidate_group_counts: base.CandidateGroupCounts = None,
26+
) -> base.TypedPolisPipelineResult:
27+
return run_pipeline(
28+
votes=votes,
29+
reducer=reducer,
30+
reducer_kwargs=reducer_kwargs,
31+
clusterer=clusterer,
32+
clusterer_kwargs=clusterer_kwargs,
33+
mod_out_statement_ids=mod_out_statement_ids,
34+
meta_statement_ids=meta_statement_ids,
35+
min_user_vote_threshold=min_user_vote_threshold,
36+
keep_participant_ids=keep_participant_ids,
37+
init_centers=init_centers,
38+
max_group_count=max_group_count,
39+
force_group_count=force_group_count,
40+
random_state=random_state,
41+
pick_max=pick_max,
42+
confidence=confidence,
43+
consensus_mode=consensus_mode,
44+
candidate_group_counts=candidate_group_counts,
45+
)
846

947

10-
def run_pipeline(**kwargs) -> base.TypedPolisPipelineResult:
11-
kwargs = {
12-
"reducer": "pca",
13-
"clusterer": "kmeans",
14-
"consensus_mode": "legacy",
15-
**kwargs,
16-
}
17-
return base.run_pipeline(**kwargs)
48+
def run_pipeline(
49+
votes: Sequence[Mapping[str, object]],
50+
reducer: base.ReducerType = "pca",
51+
reducer_kwargs: dict[str, object] | None = None,
52+
clusterer: base.ClustererType = "kmeans",
53+
clusterer_kwargs: dict[str, object] | None = None,
54+
mod_out_statement_ids: list[int] | None = None,
55+
meta_statement_ids: list[int] | None = None,
56+
min_user_vote_threshold: int = 7,
57+
keep_participant_ids: list[int] | None = None,
58+
init_centers: list[list[float]] | None = None,
59+
max_group_count: int = 5,
60+
force_group_count: int | None = None,
61+
random_state: int | None = None,
62+
pick_max: int = 5,
63+
confidence: float = 0.9,
64+
consensus_mode: base.Literal["standard", "legacy"] = "legacy",
65+
candidate_group_counts: base.CandidateGroupCounts = None,
66+
) -> base.TypedPolisPipelineResult:
67+
return base.run_pipeline(
68+
votes=list(votes),
69+
reducer=reducer,
70+
reducer_kwargs=reducer_kwargs or {},
71+
clusterer=clusterer,
72+
clusterer_kwargs=clusterer_kwargs or {},
73+
mod_out_statement_ids=mod_out_statement_ids or [],
74+
meta_statement_ids=meta_statement_ids or [],
75+
min_user_vote_threshold=min_user_vote_threshold,
76+
keep_participant_ids=keep_participant_ids or [],
77+
init_centers=init_centers,
78+
max_group_count=max_group_count,
79+
force_group_count=force_group_count,
80+
random_state=random_state,
81+
pick_max=pick_max,
82+
confidence=confidence,
83+
consensus_mode=consensus_mode,
84+
candidate_group_counts=candidate_group_counts,
85+
)

0 commit comments

Comments
 (0)