Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions crowdkit/aggregation/classification/dawid_skene.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,31 @@ def fit(self, data: pd.DataFrame) -> "OneCoinDawidSkene": # type: ignore[overri
self.labels_ = get_most_probable_labels(probas)

return self

def fit_predict_proba(self, data: pd.DataFrame) -> pd.DataFrame: # type: ignore[override]
    """Fit the model to the training data and return per-task label probability distributions.

    Args:
        data (DataFrame): The training dataset of workers' labeling results,
            given as a `pandas.DataFrame` containing `task`, `worker`, and `label` columns.

    Returns:
        DataFrame: Probability distributions of task labels.
            The `pandas.DataFrame` is indexed by `task`, so `result.loc[task, label]`
            is the probability that the true label of `task` equals `label`.
            Each probability lies in the range from 0 to 1, and the probabilities
            of a single task sum up to 1.
    """
    self.fit(data)
    task_probas = self.probas_
    # Guard against `fit` having left `probas_` unset before handing it to the caller.
    assert task_probas is not None, "no probas_"
    return task_probas

def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]": # type: ignore[override]
    """Fit the model to the training data and return the aggregated task labels.

    Args:
        data (DataFrame): The training dataset of workers' labeling results,
            given as a `pandas.DataFrame` containing `task`, `worker`, and `label` columns.

    Returns:
        Series: Task labels. The `pandas.Series` is indexed by `task`,
            so `labels.loc[task]` is the most likely true label of that task.
    """
    self.fit(data)
    aggregated = self.labels_
    # Guard against `fit` having left `labels_` unset before handing it to the caller.
    assert aggregated is not None, "no labels_"
    return aggregated
22 changes: 22 additions & 0 deletions tests/aggregation/test_ds_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_aggregate_hds_on_toy_ysda(
toy_ground_truth_df.sort_index(),
)

assert_series_equal(
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
.fit_predict(toy_answers_df)
.sort_index(),
toy_ground_truth_df.sort_index(),
)

probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(toy_answers_df)
assert ((probas >= 0) & (probas <= 1)).all().all()


@pytest.mark.parametrize("n_iter, tol", [(10, 0), (100500, 1e-5)])
def test_aggregate_ds_on_simple(
Expand Down Expand Up @@ -432,6 +442,18 @@ def test_aggregate_hds_on_simple(
simple_ground_truth.sort_index(),
)

assert_series_equal(
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
.fit_predict(simple_answers_df)
.sort_index(),
simple_ground_truth.sort_index(),
)

probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(
simple_answers_df
)
assert ((probas >= 0) & (probas <= 1)).all().all()


def _make_probas(data: List[List[Any]]) -> pd.DataFrame:
# TODO: column should not be an index!
Expand Down