diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77f0296..499b36a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,14 +35,14 @@ repos: - id: text-unicode-replacement-char - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.8.6" + rev: "v0.11.4" hooks: - id: ruff args: - --fix - repo: https://github.com/psf/black - rev: 24.10.0 + rev: 25.1.0 hooks: - id: black additional_dependencies: [toml] @@ -54,7 +54,7 @@ repos: additional_dependencies: [black] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.14.1" + rev: "v1.15.0" hooks: - id: mypy additional_dependencies: diff --git a/pyproject.toml b/pyproject.toml index d818c36..bafdc68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,9 +143,11 @@ where = ["src"] [tool.ruff] target-version = "py310" line-length = 88 + +[tool.ruff.lint] select = ["ALL"] ignore = [ - "ANN101", "ANN102", "ANN401", + "ANN401", "ARG001", "ARG002", "D105", "D107", "D203", "D213", "ERA001", @@ -169,14 +171,14 @@ ignore = [ "TD", ] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.isort] +[tool.ruff.lint.isort] force-sort-within-sections = true known-third-party = ["numpy", "pytest", "torch"] known-first-party = ["stream_mapper.core"] known-local-folder = ["stream_mapper.pytorch"] -[tool.ruff.pylint] +[tool.ruff.lint.pylint] max-args = 6 diff --git a/src/stream_mapper/pytorch/_base.py b/src/stream_mapper/pytorch/_base.py index 9067732..b5f299d 100644 --- a/src/stream_mapper/pytorch/_base.py +++ b/src/stream_mapper/pytorch/_base.py @@ -5,10 +5,11 @@ __all__: tuple[str, ...] = () from dataclasses import KW_ONLY, dataclass -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any import torch as xp from torch import nn +from typing_extensions import Self from stream_mapper.core import ModelBase as CoreModelBase from stream_mapper.core._connect.nn_namespace import NN_NAMESPACE @@ -22,8 +23,6 @@ if TYPE_CHECKING: from stream_mapper.core import Data - Self = TypeVar("Self", bound="ModelBase") - @dataclass(unsafe_hash=True, repr=False) class ModelBase(nn.Module, CoreModelBase[Array, NNModel]): diff --git a/src/stream_mapper/pytorch/_multi.py b/src/stream_mapper/pytorch/_multi.py index 1416287..a0fcedb 100644 --- a/src/stream_mapper/pytorch/_multi.py +++ b/src/stream_mapper/pytorch/_multi.py @@ -5,9 +5,10 @@ __all__: tuple[str, ...] = () from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any from torch import nn +from typing_extensions import Self from stream_mapper.core import BACKGROUND_KEY, NNField from stream_mapper.core import IndependentModels as CoreIndependentModels @@ -23,8 +24,6 @@ if TYPE_CHECKING: from stream_mapper.core import Data - Self = TypeVar("Self", bound="MixtureModel") - @dataclass class ModelsBase(nn.Module, CoreModelsBase[Array, NNModel]): @@ -122,7 +121,7 @@ class MixtureModel(ModelsBase, CoreMixtureModel[Array, NNModel]): net: NNField[NNModel, NNModel] = NNField(default=MISSING) - def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: # noqa: ARG003 + def __new__(cls: type[Self], *_: Any, **__: Any) -> Self: """Initialize the model. This is needed for PyTorch.""" self: Self = super().__new__(cls) # PyTorch needs to be initialized before attributes are assigned.