Skip to content

Commit f524d3e

Browse files
committed
Use Ruff TYPE_CHECKING blocks
1 parent 00ede45 commit f524d3e

File tree

8 files changed

+53
-48
lines changed

.github/workflows/test.yml

+7-5
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ jobs:
1212
fail-fast: false
1313
matrix:
1414
python-version: ["3.8", "3.9", "3.10"]
15-
pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1"]
15+
pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1", "2.2"]
1616
include:
1717
- python-version: 3.11
1818
pytorch-version: 2.0
1919
- python-version: 3.11
2020
pytorch-version: 2.1
21+
- python-version: 3.11
22+
pytorch-version: 2.2
2123
exclude:
2224
- python-version: 3.9
2325
pytorch-version: 1.4.0
@@ -56,16 +58,16 @@ jobs:
5658
pip install torch==${{ matrix.pytorch-version }} torchvision transformers
5759
pip install compressai
5860
- name: mypy
59-
if: ${{ matrix.pytorch-version == '2.1' }}
61+
if: ${{ matrix.pytorch-version == '2.2' }}
6062
run: |
61-
python -m pip install mypy==1.7.1
63+
python -m pip install mypy==1.9.0
6264
mypy --install-types --non-interactive .
6365
- name: pytest
64-
if: ${{ matrix.pytorch-version == '2.1' }}
66+
if: ${{ matrix.pytorch-version == '2.2' }}
6567
run: |
6668
pytest --cov=torchinfo --cov-report= --durations=0
6769
- name: pytest
68-
if: ${{ matrix.pytorch-version != '2.1' }}
70+
if: ${{ matrix.pytorch-version != '2.2' }}
6971
run: |
7072
pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers and not test_flan_t5_small"
7173
- name: codecov

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ ci:
22
skip: [mypy, pytest]
33
repos:
44
- repo: https://github.com/astral-sh/ruff-pre-commit
5-
rev: v0.2.2
5+
rev: v0.3.4
66
hooks:
77
- id: ruff
88
args: [--fix]

ruff.toml

+22-24
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
target-version = "py38"
2-
select = ["ALL"]
3-
ignore = [
4-
"ANN101", # Missing type annotation for `self` in method
5-
"ANN102", # Missing type annotation for `cls` in classmethod
2+
lint.select = ["ALL"]
3+
lint.ignore = [
4+
"ANN101", # Missing type annotation for `self` in method
5+
"ANN102", # Missing type annotation for `cls` in classmethod
66
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
77
"C901", # function is too complex (12 > 10)
88
"COM812", # Trailing comma missing
@@ -15,36 +15,34 @@ ignore = [
1515
"FBT003", # Boolean positional value in function call
1616
"FIX002", # Line contains TODO
1717
"ISC001", # Isort
18-
"PLR0911", # Too many return statements (11 > 6)
18+
"PLR0911", # Too many return statements (11 > 6)
1919
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
20-
"PLR0912", # Too many branches
20+
"PLR0912", # Too many branches
2121
"PLR0913", # Too many arguments to function call
2222
"PLR0915", # Too many statements
2323
"S101", # Use of `assert` detected
2424
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
2525
"T201", # print() found
2626
"T203", # pprint() found
27-
"TCH001", # Move application import into a type-checking block
28-
"TCH003", # Move standard library import into a type-checking block
29-
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
30-
"TD003", # Missing issue link on the line following this TODO
31-
"TD005", # Missing issue description after `TODO`
32-
"TRY003", # Avoid specifying long messages outside the exception class
27+
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
28+
"TD003", # Missing issue link on the line following this TODO
29+
"TD005", # Missing issue description after `TODO`
30+
"TRY003", # Avoid specifying long messages outside the exception class
3331

3432
# torchinfo-specific ignores
35-
"N803", # Argument name `A_i` should be lowercase
36-
"N806", # Variable `G` in function should be lowercase
33+
"N803", # Argument name `A_i` should be lowercase
34+
"N806", # Variable `G` in function should be lowercase
3735
"BLE001", # Do not catch blind exception: `Exception`
38-
"PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
39-
"PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
40-
"PLW2901", # `for` loop variable `name` overwritten by assignment target
41-
"SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
42-
"SLF001", # Private member accessed: `_modules`
43-
"TCH002", # Move third-party import into a type-checking block
44-
"TRY004", # Prefer `TypeError` exception for invalid type
45-
"TRY301", # Abstract `raise` to an inner function
36+
"PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
37+
"PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
38+
"PLW2901", # `for` loop variable `name` overwritten by assignment target
39+
"SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
40+
"SLF001", # Private member accessed: `_modules`
41+
"TCH002", # Move third-party import into a type-checking block
42+
"TRY004", # Prefer `TypeError` exception for invalid type
43+
"TRY301", # Abstract `raise` to an inner function
4644
]
47-
exclude = ["tests"] # TODO: check tests too
45+
exclude = ["tests"] # TODO: check tests too
4846

49-
[flake8-pytest-style]
47+
[lint.flake8-pytest-style]
5048
fixture-parentheses = false

tests/fixtures/models.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(self) -> None:
302302
self.constant = 5
303303

304304
def forward(self, x: dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
305-
return cast(torch.Tensor, scale_factor * (x[256] + x[512][0]) * self.constant)
305+
return scale_factor * (x[256] + x[512][0]) * self.constant
306306

307307

308308
class ModuleDictModel(nn.Module):
@@ -358,7 +358,7 @@ def __int__(self) -> IntWithGetitem:
358358
return self
359359

360360
def __getitem__(self, val: int) -> torch.Tensor:
361-
return cast(torch.Tensor, self.tensor * val)
361+
return self.tensor * val
362362

363363

364364
class EdgecaseInputOutputModel(nn.Module):
@@ -575,7 +575,7 @@ def __init__(self) -> None:
575575
self.b = nn.Parameter(torch.empty(10), requires_grad=False)
576576

577577
def forward(self, x: torch.Tensor) -> torch.Tensor:
578-
return cast(torch.Tensor, self.w * x + self.b)
578+
return self.w * x + self.b
579579

580580

581581
class MixedTrainable(nn.Module):
@@ -717,7 +717,7 @@ def __init__(
717717
def forward(self, x: torch.Tensor) -> torch.Tensor:
718718
h = torch.mm(x, self.a) + self.b
719719
if self.output_dim is None:
720-
return cast(torch.Tensor, h)
720+
return h
721721
return cast(torch.Tensor, self.fc2(h))
722722

723723

torchinfo/formatting.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

33
import math
4-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
55

66
from .enums import ColumnSettings, RowSettings, Units, Verbosity
7-
from .layer_info import LayerInfo
7+
8+
if TYPE_CHECKING:
9+
from .layer_info import LayerInfo
810

911
HEADER_TITLES = {
1012
ColumnSettings.KERNEL_SIZE: "Kernel Shape",

torchinfo/layer_info.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ def calculate_size(
119119
size = list(inputs.size())
120120
elem_bytes = inputs.element_size()
121121

122-
elif isinstance(inputs, np.ndarray):
123-
inputs_ = torch.from_numpy(inputs)
122+
elif isinstance(inputs, np.ndarray): # type: ignore[unreachable]
123+
inputs_ = torch.from_numpy(inputs) # type: ignore[unreachable]
124124
size, elem_bytes = list(inputs_.size()), inputs_.element_size()
125125

126126
elif isinstance(inputs, (list, tuple)):
@@ -217,9 +217,9 @@ def calculate_num_params(self) -> None:
217217
final_name = name
218218
# Fix the final row to display more nicely
219219
if self.inner_layers:
220-
self.inner_layers[final_name][
221-
ColumnSettings.NUM_PARAMS
222-
] = f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}"
220+
self.inner_layers[final_name][ColumnSettings.NUM_PARAMS] = (
221+
f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}"
222+
)
223223

224224
def calculate_macs(self) -> None:
225225
"""
@@ -322,8 +322,9 @@ def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], i
322322
size, elem_bytes = nested_list_size(inputs.tensors)
323323
elif isinstance(inputs, torch.Tensor):
324324
size, elem_bytes = list(inputs.size()), inputs.element_size()
325-
elif isinstance(inputs, np.ndarray):
326-
inputs_torch = torch.from_numpy(inputs) # preserves dtype
325+
elif isinstance(inputs, np.ndarray): # type: ignore[unreachable]
326+
# preserves dtype
327+
inputs_torch = torch.from_numpy(inputs) # type: ignore[unreachable]
327328
size, elem_bytes = list(inputs_torch.size()), inputs_torch.element_size()
328329
elif not hasattr(inputs, "__getitem__") or not inputs:
329330
size, elem_bytes = [], 0
@@ -358,8 +359,8 @@ def rgetattr(module: nn.Module, attr: str) -> torch.Tensor | None:
358359
if not hasattr(module, attr_i):
359360
return None
360361
module = getattr(module, attr_i)
361-
assert isinstance(module, torch.Tensor)
362-
return module
362+
assert isinstance(module, torch.Tensor) # type: ignore[unreachable]
363+
return module # type: ignore[unreachable]
363364

364365

365366
def get_children_layers(summary_list: list[LayerInfo], index: int) -> list[LayerInfo]:

torchinfo/model_statistics.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
44

55
from .enums import Units
66
from .formatting import CONVERSION_FACTORS, FormattingOptions
7-
from .layer_info import LayerInfo
7+
8+
if TYPE_CHECKING:
9+
from .layer_info import LayerInfo
810

911

1012
class ModelStatistics:

torchinfo/torchinfo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def get_device(
484484
model_parameter = None
485485

486486
if model_parameter is not None and model_parameter.is_cuda:
487-
return model_parameter.device # type: ignore[no-any-return]
487+
return model_parameter.device
488488
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
489489
return None
490490

0 commit comments

Comments (0)