Skip to content

Commit 20679c7

Browse files
committed
Fix typing
1 parent 381530a commit 20679c7

File tree

12 files changed

+66
-61
lines changed

12 files changed

+66
-61
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v5.0.0
3+
rev: v6.0.0
44
hooks:
55
- id: trailing-whitespace
66
- id: end-of-file-fixer
77
- id: check-yaml
88
- id: check-added-large-files
99
- repo: https://github.com/psf/black
10-
rev: 24.10.0
10+
rev: 25.9.0
1111
hooks:
1212
- id: black
1313
- repo: https://github.com/pycqa/isort
14-
rev: 5.13.2
14+
rev: 7.0.0
1515
hooks:
1616
- id: isort
1717
- repo: https://github.com/pycqa/flake8
18-
rev: 7.1.1
18+
rev: 7.3.0
1919
hooks:
2020
- id: flake8
2121
- repo: https://github.com/asottile/pyupgrade
22-
rev: v3.19.1
22+
rev: v3.21.0
2323
hooks:
2424
- id: pyupgrade
2525
- repo: https://github.com/nbQA-dev/nbQA

crowdkit/aggregation/classification/dawid_skene.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["DawidSkene", "OneCoinDawidSkene"]
22

3-
from typing import Any, List, Literal, Optional, cast
3+
from typing import Any, List, Literal, Optional, Tuple, cast
44

55
import attr
66
import numpy as np
@@ -235,7 +235,7 @@ def _evidence_lower_bound(
235235
priors = priors.rename(index={True: "True", False: "False"}, copy=False)
236236
priors.clip(lower=_EPS, inplace=True)
237237

238-
joined.loc[:, priors.index] = joined.loc[:, priors.index].add(np.log(priors)) # type: ignore
238+
joined.loc[:, priors.index] = joined.loc[:, priors.index].add(np.log(priors))
239239

240240
joined.set_index(["task", "worker"], inplace=True)
241241
joint_expectation = (
@@ -487,25 +487,29 @@ class OneCoinDawidSkene(DawidSkene):
487487
"""
488488

489489
@staticmethod
490-
def _assign_skills(row: "pd.Series[Any]", skills: pd.DataFrame) -> pd.DataFrame:
490+
def _assign_skills(
491+
row: "pd.Series[Any]", skills: "pd.Series[Any]"
492+
) -> "pd.Series[Any]":
491493
"""
492494
Assigns user skills to error matrix row by row.
493495
"""
494496
num_categories = len(row)
495497
for column_name, _ in row.items():
496-
if column_name == row.name[1]: # type: ignore
497-
row[column_name] = skills[row.name[0]] # type: ignore
498+
if column_name == cast(Tuple[Any, Any], row.name)[1]:
499+
row[column_name] = skills.loc[cast(Tuple[Any, Any], row.name)[0]]
498500
else:
499-
row[column_name] = (1 - skills[row.name[0]]) / (num_categories - 1) # type: ignore
500-
return row # type: ignore
501+
row[column_name] = (
502+
1 - skills.loc[cast(Tuple[Any, Any], row.name)[0]]
503+
) / (num_categories - 1)
504+
return row
501505

502506
@staticmethod
503507
def _process_skills_to_errors(
504508
data: pd.DataFrame, probas: pd.DataFrame, skills: "pd.Series[Any]"
505509
) -> pd.DataFrame:
506510
errors = DawidSkene._m_step(data, probas)
507511

508-
errors = errors.apply(OneCoinDawidSkene._assign_skills, args=(skills,), axis=1) # type: ignore
512+
errors = errors.apply(OneCoinDawidSkene._assign_skills, args=(skills,), axis=1)
509513
errors.clip(lower=_EPS, upper=1 - _EPS, inplace=True)
510514

511515
return errors

crowdkit/aggregation/classification/glad.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def _optimize_df(self, x: npt.NDArray[Any]) -> npt.NDArray[Any]:
226226
dQalpha, dQbeta = self._gradient_Q(self._current_data)
227227

228228
minus_grad = np.zeros_like(x)
229-
minus_grad[: len(self.workers_)] = -dQalpha[self.workers_].values # type: ignore
230-
minus_grad[len(self.workers_) :] = -dQbeta[self.tasks_].values # type: ignore
229+
minus_grad[: len(self.workers_)] = -dQalpha[self.workers_].values # type: ignore[operator,index]
230+
minus_grad[len(self.workers_) :] = -dQbeta[self.tasks_].values # type: ignore[operator,index]
231231
return minus_grad
232232

233233
def _update_alphas_betas(
@@ -245,9 +245,9 @@ def _update_alphas_betas(
245245
def _get_alphas_betas_by_point(
246246
self, x: npt.NDArray[Any]
247247
) -> Tuple["pd.Series[Any]", "pd.Series[Any]"]:
248-
alphas = pd.Series(x[: len(self.workers_)], index=self.workers_, name="alpha") # type: ignore
248+
alphas = pd.Series(x[: len(self.workers_)], index=self.workers_, name="alpha")
249249
alphas.index.name = "worker"
250-
betas = pd.Series(x[len(self.workers_) :], index=self.tasks_, name="beta") # type: ignore
250+
betas = pd.Series(x[len(self.workers_) :], index=self.tasks_, name="beta")
251251
betas.index.name = "task"
252252
return alphas, betas
253253

@@ -268,15 +268,15 @@ def _m_step(self, data: pd.DataFrame) -> pd.DataFrame:
268268
return self._current_data
269269

270270
def _init(self, data: pd.DataFrame) -> None:
271-
self.alphas_ = pd.Series(1.0, index=pd.unique(data.worker)) # type: ignore
272-
self.betas_ = pd.Series(1.0, index=pd.unique(data.task)) # type: ignore
271+
self.alphas_ = pd.Series(1.0, index=pd.unique(data.worker))
272+
self.betas_ = pd.Series(1.0, index=pd.unique(data.task))
273273
self.tasks_ = pd.unique(data["task"])
274274
self.workers_ = pd.unique(data["worker"])
275275
self.priors_ = self.labels_priors
276276
if self.priors_ is None:
277277
self.prior_labels_ = pd.unique(data["label"])
278278
self.priors_ = pd.Series(
279-
1.0 / len(self.prior_labels_), index=self.prior_labels_ # type: ignore
279+
1.0 / len(self.prior_labels_), index=self.prior_labels_
280280
)
281281
else:
282282
self.prior_labels_ = self.priors_.index # type: ignore

crowdkit/aggregation/classification/mace.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["MACE"]
22

3-
from typing import Any, Iterator, List, Optional, Tuple, Union
3+
from typing import Any, Iterator, List, Optional, Tuple, Union, cast
44

55
import attr
66
import numpy as np
@@ -24,11 +24,14 @@ def normalize(x: NDArray[np.float64], smoothing: float) -> NDArray[np.float64]:
2424
np.ndarray: Normalized array
2525
"""
2626
norm = (x + smoothing).sum(axis=1)
27-
return np.divide(
28-
x + smoothing,
29-
norm[:, np.newaxis],
30-
out=np.zeros_like(x),
31-
where=~np.isclose(norm[:, np.newaxis], np.zeros_like(norm[:, np.newaxis])),
27+
return cast(
28+
NDArray[np.float64],
29+
np.divide(
30+
x + smoothing,
31+
norm[:, np.newaxis],
32+
out=np.zeros_like(x),
33+
where=~np.isclose(norm[:, np.newaxis], np.zeros_like(norm[:, np.newaxis])),
34+
),
3235
)
3336

3437

@@ -46,11 +49,14 @@ def variational_normalize(
4649
"""
4750
norm = (x + hparams).sum(axis=1)
4851
norm = np.exp(digamma(norm))
49-
return np.divide(
50-
np.exp(digamma(x + hparams)),
51-
norm[:, np.newaxis],
52-
out=np.zeros_like(x),
53-
where=~np.isclose(norm[:, np.newaxis], np.zeros_like(norm[:, np.newaxis])),
52+
return cast(
53+
NDArray[np.float64],
54+
np.divide(
55+
np.exp(digamma(x + hparams)),
56+
norm[:, np.newaxis],
57+
out=np.zeros_like(x),
58+
where=~np.isclose(norm[:, np.newaxis], np.zeros_like(norm[:, np.newaxis])),
59+
),
5460
)
5561

5662

crowdkit/aggregation/texts/text_hrrasa.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,16 @@ def fit_predict( # type: ignore
112112
self._encode_data(data), self._encode_true_objects(true_objects)
113113
)
114114
self.texts_ = (
115-
hrrasa_results.reset_index()[["task", "output"]] # type: ignore
115+
hrrasa_results.reset_index()[["task", "output"]]
116116
.rename(columns={"output": "text"})
117-
.set_index("task")
117+
.set_index("task")["text"]
118118
)
119119
return self.texts_
120120

121121
def _encode_data(self, data: pd.DataFrame) -> pd.DataFrame:
122122
data = data[["task", "worker", "text"]].rename(columns={"text": "output"})
123-
data["embedding"] = data.output.apply(self.encoder) # type: ignore
123+
data["embedding"] = data.output.apply(self.encoder) # type: ignore[arg-type]
124124
return data
125125

126126
def _encode_true_objects(self, true_objects: "pd.Series[Any]") -> "pd.Series[Any]":
127-
return true_objects and true_objects.apply(self.encoder) # type: ignore
127+
return true_objects and true_objects.apply(self.encoder)

crowdkit/aggregation/texts/text_rasa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def fit_predict( # type: ignore
114114
self._encode_data(data), self._encode_true_objects(true_objects)
115115
)
116116
self.texts_ = (
117-
rasa_results.reset_index()[["task", "output"]] # type: ignore
117+
rasa_results.reset_index()[["task", "output"]]
118118
.rename(columns={"output": "text"})
119-
.set_index("task")
119+
.set_index("task")["text"]
120120
)
121121
return self.texts_
122122

@@ -126,4 +126,4 @@ def _encode_data(self, data: pd.DataFrame) -> pd.DataFrame:
126126
return data
127127

128128
def _encode_true_objects(self, true_objects: "pd.Series[Any]") -> "pd.Series[Any]":
129-
return true_objects and true_objects.apply(self.encoder) # type: ignore
129+
return true_objects and true_objects.apply(self.encoder)

crowdkit/datasets/_base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from os import environ, listdir, makedirs, rename
55
from os.path import basename, exists, expanduser, join, splitext
66
from shutil import unpack_archive
7-
from typing import AnyStr, Optional, cast
7+
from typing import Optional
88
from urllib.request import urlretrieve
99

1010

11-
def get_data_dir(data_dir: Optional[AnyStr] = None) -> AnyStr:
11+
def get_data_dir(data_dir: Optional[str] = None) -> str:
1212
"""Return the path of the crowd-kit data dir.
1313
1414
This folder is used by some large dataset loaders to avoid downloading the
@@ -26,9 +26,7 @@ def get_data_dir(data_dir: Optional[AnyStr] = None) -> AnyStr:
2626
is `~/crowdkit_data`.
2727
"""
2828
if data_dir is None:
29-
data_dir = cast(
30-
AnyStr, environ.get("CROWDKIT_DATA", join("~", "crowdkit_data"))
31-
)
29+
data_dir = environ.get("CROWDKIT_DATA", join("~", "crowdkit_data"))
3230
data_dir = expanduser(data_dir)
3331

3432
if not exists(data_dir):

crowdkit/learning/text_summarization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]":
8282

8383
data = data[["task", "worker", "text"]]
8484

85-
self.model = self.model.to(self.device)
85+
self.model = self.model.to(self.device) # type: ignore[arg-type]
8686
self.texts_ = data.groupby("task")["text"].apply(self._aggregate_one)
8787
return self.texts_
8888

@@ -117,5 +117,5 @@ def _generate_output(
117117
input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(
118118
self.device
119119
)
120-
outputs = self.model.generate(input_ids, num_beams=self.num_beams)
121-
return cast(str, self.tokenizer.decode(outputs[0], skip_special_tokens=True))
120+
outputs = self.model.generate(input_ids, num_beams=self.num_beams) # type: ignore[operator]
121+
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ version = {attr = "crowdkit.__version__"}
9696

9797
[tool.mypy]
9898
ignore_missing_imports = true
99-
plugins = ["numpy.typing.mypy_plugin"]
10099
strict = true
101100

102101
[tool.isort]

tests/aggregation/test_ds_aggregation.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Testing all boundary conditions and asserts
44
"""
55

6-
from typing import Any, List, Literal, Optional, cast
6+
from typing import Any, List, Literal, Optional
77

88
import numpy as np
99
import pandas as pd
@@ -461,14 +461,12 @@ def _make_probas(data: List[List[Any]]) -> pd.DataFrame:
461461
return pd.DataFrame(data, columns=columns).set_index("task")
462462

463463

464-
def _make_tasks_labels(data: List[List[Any]]) -> pd.DataFrame:
464+
def _make_tasks_labels(data: List[List[Any]]) -> "pd.Series[Any]":
465465
# TODO: should task be indexed?
466-
return cast(
467-
pd.DataFrame,
466+
return (
468467
pd.DataFrame(data, columns=["task", "label"])
469-
.set_index("task")
470-
.squeeze()
471-
.rename("agg_label"),
468+
.set_index("task")["label"]
469+
.rename("agg_label")
472470
)
473471

474472

@@ -532,7 +530,7 @@ def priors_iter_0() -> "pd.Series[Any]":
532530

533531

534532
@pytest.fixture
535-
def tasks_labels_iter_0() -> pd.DataFrame:
533+
def tasks_labels_iter_0() -> "pd.Series[Any]":
536534
return _make_tasks_labels(
537535
[
538536
["t1", "no"],
@@ -581,7 +579,7 @@ def priors_iter_1() -> "pd.Series[Any]":
581579

582580

583581
@pytest.fixture
584-
def tasks_labels_iter_1() -> pd.DataFrame:
582+
def tasks_labels_iter_1() -> "pd.Series[Any]":
585583
return _make_tasks_labels(
586584
[
587585
["t1", "yes"],
@@ -670,7 +668,7 @@ def test_dawid_skene_overlap(overlap: int) -> None:
670668
assert ds.priors_ is not None, "no priors_"
671669
assert ds.labels_ is not None, "no labels_"
672670
assert_frame_equal(expected_probas, ds.probas_, check_like=True, atol=0.005)
673-
assert_series_equal(expected_labels, ds.labels_, atol=0.005) # type: ignore
671+
assert_series_equal(expected_labels, ds.labels_, atol=0.005)
674672
assert_series_equal(
675673
pd.Series([1 / 3, 2 / 3], pd.Index(["no", "yes"], name="label"), name="prior"),
676674
ds.priors_,

0 commit comments

Comments
 (0)