Skip to content

Commit

Permalink
fix(hds): override the fit_predict and fit_predict_proba methods (#128)
Browse files Browse the repository at this point in the history
* fix(hds): override the fit_predict and fit_predict_proba methods

* fix: mypy error

* fix: mypy error

* fix: pre-commit
  • Loading branch information
shenxiangzhuang authored Jan 14, 2025
1 parent a90bdca commit 35f0661
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
28 changes: 28 additions & 0 deletions crowdkit/aggregation/classification/dawid_skene.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,31 @@ def fit(self, data: pd.DataFrame) -> "OneCoinDawidSkene": # type: ignore[overri
self.labels_ = get_most_probable_labels(probas)

return self

def fit_predict_proba(self, data: pd.DataFrame) -> pd.DataFrame: # type: ignore[override]
"""Fits the model to the training data and returns probability distributions of labels for each task.
Args:
data (DataFrame): The training dataset of workers' labeling results
which is represented as the `pandas.DataFrame` data containing `task`, `worker`, and `label` columns.
Returns:
DataFrame: Probability distributions of task labels.
The `pandas.DataFrame` data is indexed by `task` so that `result.loc[task, label]` is the probability that the `task` true label is equal to `label`.
Each probability is in the range from 0 to 1, all task probabilities must sum up to 1.
"""

self.fit(data)
assert self.probas_ is not None, "no probas_"
return self.probas_

def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]": # type: ignore[override]
"""Fits the model to the training data and returns the aggregated results.
Args:
data (DataFrame): The training dataset of workers' labeling results
which is represented as the `pandas.DataFrame` data containing `task`, `worker`, and `label` columns.
Returns:
Series: Task labels. The `pandas.Series` data is indexed by `task` so that `labels.loc[task]` is the most likely true label of tasks.
"""

self.fit(data)
assert self.labels_ is not None, "no labels_"
return self.labels_
22 changes: 22 additions & 0 deletions tests/aggregation/test_ds_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_aggregate_hds_on_toy_ysda(
toy_ground_truth_df.sort_index(),
)

assert_series_equal(
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
.fit_predict(toy_answers_df)
.sort_index(),
toy_ground_truth_df.sort_index(),
)

probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(toy_answers_df)
assert ((probas >= 0) & (probas <= 1)).all().all()


@pytest.mark.parametrize("n_iter, tol", [(10, 0), (100500, 1e-5)])
def test_aggregate_ds_on_simple(
Expand Down Expand Up @@ -432,6 +442,18 @@ def test_aggregate_hds_on_simple(
simple_ground_truth.sort_index(),
)

assert_series_equal(
OneCoinDawidSkene(n_iter=n_iter, tol=tol)
.fit_predict(simple_answers_df)
.sort_index(),
simple_ground_truth.sort_index(),
)

probas = OneCoinDawidSkene(n_iter=n_iter, tol=tol).fit_predict_proba(
simple_answers_df
)
assert ((probas >= 0) & (probas <= 1)).all().all()


def _make_probas(data: List[List[Any]]) -> pd.DataFrame:
# TODO: column should not be an index!
Expand Down

0 comments on commit 35f0661

Please sign in to comment.