Skip to content

Commit 35f0661

Browse files
fix(hds): override the fit_predict and fit_predict_proba methods (#128)
* fix(hds): override the fit_predict and fit_predict_proba methods
* fix: mypy error
* fix: mypy error
* fix: pre-commit
1 parent a90bdca commit 35f0661

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

crowdkit/aggregation/classification/dawid_skene.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,31 @@ def fit(self, data: pd.DataFrame) -> "OneCoinDawidSkene": # type: ignore[overri
578578
self.labels_ = get_most_probable_labels(probas)
579579

580580
return self
581+
582+
def fit_predict_proba(self, data: pd.DataFrame) -> pd.DataFrame: # type: ignore[override]
583+
"""Fits the model to the training data and returns the probability distributions of labels for each task.
584+
Args:
585+
data (DataFrame): The training dataset of workers' labeling results,
586+
represented as a `pandas.DataFrame` containing the `task`, `worker`, and `label` columns.
587+
Returns:
588+
DataFrame: Probability distributions of task labels.
589+
The `pandas.DataFrame` is indexed by `task` so that `result.loc[task, label]` is the probability that the true label of `task` is equal to `label`.
590+
Each probability is in the range from 0 to 1, and the probabilities of each task sum up to 1.
591+
"""
592+
593+
self.fit(data)  # fit on the given data first; the result is read back from `probas_` below
594+
assert self.probas_ is not None, "no probas_"  # fails with "no probas_" if fitting left the attribute unset
595+
return self.probas_
596+
597+
def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]": # type: ignore[override]
598+
"""Fits the model to the training data and returns the aggregated results.
599+
Args:
600+
data (DataFrame): The training dataset of workers' labeling results,
601+
represented as a `pandas.DataFrame` containing the `task`, `worker`, and `label` columns.
602+
Returns:
603+
Series: Task labels. The `pandas.Series` is indexed by `task` so that `labels.loc[task]` is the most likely true label of that task.
604+
"""
605+
606+
self.fit(data)  # fit on the given data first; the result is read back from `labels_` below
607+
assert self.labels_ is not None, "no labels_"  # fails with "no labels_" if fitting left the attribute unset
608+
return self.labels_

tests/aggregation/test_ds_aggregation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,16 @@ def test_aggregate_hds_on_toy_ysda(
404404
toy_ground_truth_df.sort_index(),
405405
)
406406

407+
assert_series_equal(
408+
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
409+
.fit_predict(toy_answers_df)
410+
.sort_index(),
411+
toy_ground_truth_df.sort_index(),
412+
)
413+
414+
probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(toy_answers_df)
415+
assert ((probas >= 0) & (probas <= 1)).all().all()
416+
407417

408418
@pytest.mark.parametrize("n_iter, tol", [(10, 0), (100500, 1e-5)])
409419
def test_aggregate_ds_on_simple(
@@ -432,6 +442,18 @@ def test_aggregate_hds_on_simple(
432442
simple_ground_truth.sort_index(),
433443
)
434444

445+
assert_series_equal(
446+
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
447+
.fit_predict(simple_answers_df)
448+
.sort_index(),
449+
simple_ground_truth.sort_index(),
450+
)
451+
452+
probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(
453+
simple_answers_df
454+
)
455+
assert ((probas >= 0) & (probas <= 1)).all().all()
456+
435457

436458
def _make_probas(data: List[List[Any]]) -> pd.DataFrame:
437459
# TODO: column should not be an index!

0 commit comments

Comments
 (0)