Skip to content

Commit 38b63ac

Browse files
Refactor depth filtering logic (#10200)
Implement depth limiting relative to the deepest package (in other words the leaf nodes of the specified package tree). If a user specifies nested package names, a warning is now emitted to prevent confusion. Co-authored-by: Pierre Sassoulas <[email protected]>
1 parent 0a1044b commit 38b63ac

File tree

8 files changed

+346
-190
lines changed

8 files changed

+346
-190
lines changed

pylint/pyreverse/diadefslib.py

+66-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from __future__ import annotations
88

99
import argparse
10-
from collections.abc import Generator
10+
import warnings
11+
from collections.abc import Generator, Sequence
1112
from typing import Any
1213

1314
import astroid
@@ -27,10 +28,28 @@ class DiaDefGenerator:
2728
def __init__(self, linker: Linker, handler: DiadefsHandler) -> None:
2829
"""Common Diagram Handler initialization."""
2930
self.config = handler.config
31+
self.args = handler.args
3032
self.module_names: bool = False
3133
self._set_default_options()
3234
self.linker = linker
3335
self.classdiagram: ClassDiagram # defined by subclasses
36+
# Only pre-calculate depths if user has requested a max_depth
37+
if handler.config.max_depth is not None:
38+
# Detect which of the args are leaf nodes
39+
leaf_nodes = self.get_leaf_nodes()
40+
41+
# Emit a warning if any of the args are not leaf nodes
42+
diff = set(self.args).difference(set(leaf_nodes))
43+
if len(diff) > 0:
44+
warnings.warn(
45+
"Detected nested names within the specified packages. "
46+
f"The following packages: {sorted(diff)} will be ignored for "
47+
f"depth calculations, using only: {sorted(leaf_nodes)} as the base for limiting "
48+
"package depth.",
49+
stacklevel=2,
50+
)
51+
52+
self.args_depths = {module: module.count(".") for module in leaf_nodes}
3453

3554
def get_title(self, node: nodes.ClassDef) -> str:
3655
"""Get title for objects."""
@@ -39,6 +58,22 @@ def get_title(self, node: nodes.ClassDef) -> str:
3958
title = f"{node.root().name}.{title}"
4059
return title # type: ignore[no-any-return]
4160

61+
def get_leaf_nodes(self) -> list[str]:
62+
"""
63+
Get the leaf nodes from the list of args in the generator.
64+
65+
A leaf node is one that is not a prefix (with an extra dot) of any other node.
66+
"""
67+
leaf_nodes = [
68+
module
69+
for module in self.args
70+
if not any(
71+
other != module and other.startswith(module + ".")
72+
for other in self.args
73+
)
74+
]
75+
return leaf_nodes
76+
4277
def _set_option(self, option: bool | None) -> bool:
4378
"""Activate some options if not explicitly deactivated."""
4479
# if we have a class diagram, we want more information by default;
@@ -67,6 +102,30 @@ def _get_levels(self) -> tuple[int, int]:
67102
"""Help function for search levels."""
68103
return self.anc_level, self.association_level
69104

105+
def _should_include_by_depth(self, node: nodes.NodeNG) -> bool:
106+
"""Check if a node should be included based on depth.
107+
108+
A node will be included if it is at or below the max_depth relative to the
109+
specified base packages. A node is considered to be a base package if it is the
110+
deepest package in the list of specified packages. In other words the base nodes
111+
are the leaf nodes of the specified package tree.
112+
"""
113+
# If max_depth is not set, include all nodes
114+
if self.config.max_depth is None:
115+
return True
116+
117+
# Calculate the absolute depth of the node
118+
name = node.root().name
119+
absolute_depth = name.count(".")
120+
121+
# Retrieve the base depth to compare against
122+
relative_depth = next(
123+
(v for k, v in self.args_depths.items() if name.startswith(k)), None
124+
)
125+
return relative_depth is not None and bool(
126+
(absolute_depth - relative_depth) <= self.config.max_depth
127+
)
128+
70129
def show_node(self, node: nodes.ClassDef) -> bool:
71130
"""Determine if node should be shown based on config."""
72131
if node.root().name == "builtins":
@@ -75,7 +134,8 @@ def show_node(self, node: nodes.ClassDef) -> bool:
75134
if is_stdlib_module(node.root().name):
76135
return self.config.show_stdlib # type: ignore[no-any-return]
77136

78-
return True
137+
# Filter node by depth
138+
return self._should_include_by_depth(node)
79139

80140
def add_class(self, node: nodes.ClassDef) -> None:
81141
"""Visit one class and add it to diagram."""
@@ -163,7 +223,7 @@ def visit_module(self, node: nodes.Module) -> None:
163223
164224
add this class to the package diagram definition
165225
"""
166-
if self.pkgdiagram:
226+
if self.pkgdiagram and self._should_include_by_depth(node):
167227
self.linker.visit(node)
168228
self.pkgdiagram.add_object(node.name, node)
169229

@@ -177,7 +237,7 @@ def visit_classdef(self, node: nodes.ClassDef) -> None:
177237

178238
def visit_importfrom(self, node: nodes.ImportFrom) -> None:
179239
"""Visit astroid.ImportFrom and catch modules for package diagram."""
180-
if self.pkgdiagram:
240+
if self.pkgdiagram and self._should_include_by_depth(node):
181241
self.pkgdiagram.add_from_depend(node, node.modname)
182242

183243

@@ -208,8 +268,9 @@ def class_diagram(self, project: Project, klass: nodes.ClassDef) -> ClassDiagram
208268
class DiadefsHandler:
209269
"""Get diagram definitions from user (i.e. xml files) or generate them."""
210270

211-
def __init__(self, config: argparse.Namespace) -> None:
271+
def __init__(self, config: argparse.Namespace, args: Sequence[str]) -> None:
212272
self.config = config
273+
self.args = args
213274

214275
def get_diadefs(self, project: Project, linker: Linker) -> list[ClassDiagram]:
215276
"""Get the diagram's configuration data.

pylint/pyreverse/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def run(self) -> int:
354354
verbose=self.config.verbose,
355355
)
356356
linker = Linker(project, tag=True)
357-
handler = DiadefsHandler(self.config)
357+
handler = DiadefsHandler(self.config, self.args)
358358
diadefs = handler.get_diadefs(project, linker)
359359
writer.DiagramWriter(self.config).write(diadefs)
360360
return 0

pylint/pyreverse/writer.py

-66
Original file line numberDiff line numberDiff line change
@@ -54,38 +54,6 @@ def write(self, diadefs: Iterable[ClassDiagram | PackageDiagram]) -> None:
5454
self.write_classes(diagram)
5555
self.save()
5656

57-
def should_show_node(self, qualified_name: str, is_class: bool = False) -> bool:
58-
"""Determine if a node should be shown based on depth settings.
59-
60-
Depth is calculated by counting dots in the qualified name:
61-
- depth 0: top-level packages (no dots)
62-
- depth 1: first level sub-packages (one dot)
63-
- depth 2: second level sub-packages (two dots)
64-
65-
For classes, depth is measured from their containing module, excluding
66-
the class name itself from the depth calculation.
67-
"""
68-
# If no depth limit is set ==> show all nodes
69-
if self.max_depth is None:
70-
return True
71-
72-
# For classes, we want to measure depth from their containing module
73-
if is_class:
74-
# Get the module part (everything before the last dot)
75-
last_dot = qualified_name.rfind(".")
76-
if last_dot == -1:
77-
module_path = ""
78-
else:
79-
module_path = qualified_name[:last_dot]
80-
81-
# Count module depth
82-
module_depth = module_path.count(".")
83-
return bool(module_depth <= self.max_depth)
84-
85-
# For packages/modules, count full depth
86-
node_depth = qualified_name.count(".")
87-
return bool(node_depth <= self.max_depth)
88-
8957
def write_packages(self, diagram: PackageDiagram) -> None:
9058
"""Write a package diagram."""
9159
module_info: dict[str, dict[str, int]] = {}
@@ -94,10 +62,6 @@ def write_packages(self, diagram: PackageDiagram) -> None:
9462
for module in sorted(diagram.modules(), key=lambda x: x.title):
9563
module.fig_id = module.node.qname()
9664

97-
# Filter nodes based on depth
98-
if not self.should_show_node(module.fig_id):
99-
continue
100-
10165
if self.config.no_standalone and not any(
10266
module in (rel.from_object, rel.to_object)
10367
for rel in diagram.get_relationships("depends")
@@ -120,10 +84,6 @@ def write_packages(self, diagram: PackageDiagram) -> None:
12084
from_id = rel.from_object.fig_id
12185
to_id = rel.to_object.fig_id
12286

123-
# Filter nodes based on depth ==> skip if either source or target nodes is beyond the max depth
124-
if not self.should_show_node(from_id) or not self.should_show_node(to_id):
125-
continue
126-
12787
self.printer.emit_edge(
12888
from_id,
12989
to_id,
@@ -137,10 +97,6 @@ def write_packages(self, diagram: PackageDiagram) -> None:
13797
from_id = rel.from_object.fig_id
13898
to_id = rel.to_object.fig_id
13999

140-
# Filter nodes based on depth ==> skip if either source or target nodes is beyond the max depth
141-
if not self.should_show_node(from_id) or not self.should_show_node(to_id):
142-
continue
143-
144100
self.printer.emit_edge(
145101
from_id,
146102
to_id,
@@ -161,10 +117,6 @@ def write_classes(self, diagram: ClassDiagram) -> None:
161117
for obj in sorted(diagram.objects, key=lambda x: x.title):
162118
obj.fig_id = obj.node.qname()
163119

164-
# Filter class based on depth setting
165-
if not self.should_show_node(obj.fig_id, is_class=True):
166-
continue
167-
168120
if self.config.no_standalone and not any(
169121
obj in (rel.from_object, rel.to_object)
170122
for rel_type in ("specialization", "association", "aggregation")
@@ -179,12 +131,6 @@ def write_classes(self, diagram: ClassDiagram) -> None:
179131
)
180132
# inheritance links
181133
for rel in diagram.get_relationships("specialization"):
182-
# Filter nodes based on depth setting
183-
if not self.should_show_node(
184-
rel.from_object.fig_id, is_class=True
185-
) or not self.should_show_node(rel.to_object.fig_id, is_class=True):
186-
continue
187-
188134
self.printer.emit_edge(
189135
rel.from_object.fig_id,
190136
rel.to_object.fig_id,
@@ -193,12 +139,6 @@ def write_classes(self, diagram: ClassDiagram) -> None:
193139
associations: dict[str, set[str]] = defaultdict(set)
194140
# generate associations
195141
for rel in diagram.get_relationships("association"):
196-
# Filter nodes based on depth setting
197-
if not self.should_show_node(
198-
rel.from_object.fig_id, is_class=True
199-
) or not self.should_show_node(rel.to_object.fig_id, is_class=True):
200-
continue
201-
202142
associations[rel.from_object.fig_id].add(rel.to_object.fig_id)
203143
self.printer.emit_edge(
204144
rel.from_object.fig_id,
@@ -208,12 +148,6 @@ def write_classes(self, diagram: ClassDiagram) -> None:
208148
)
209149
# generate aggregations
210150
for rel in diagram.get_relationships("aggregation"):
211-
# Filter nodes based on depth setting
212-
if not self.should_show_node(
213-
rel.from_object.fig_id, is_class=True
214-
) or not self.should_show_node(rel.to_object.fig_id, is_class=True):
215-
continue
216-
217151
if rel.to_object.fig_id in associations[rel.from_object.fig_id]:
218152
continue
219153
self.printer.emit_edge(

pylint/typing.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
import argparse
10-
from collections.abc import Iterable
10+
from collections.abc import Iterable, Sequence
1111
from pathlib import Path
1212
from re import Pattern
1313
from typing import (
@@ -24,8 +24,10 @@
2424

2525
if TYPE_CHECKING:
2626
from pylint.config.callback_actions import _CallbackAction
27+
from pylint.pyreverse.diadefslib import DiaDefGenerator
2728
from pylint.pyreverse.inspector import Project
2829
from pylint.reporters.ureports.nodes import Section
30+
from pylint.testutils.pyreverse import PyreverseConfig
2931
from pylint.utils import LinterStats
3032

3133

@@ -134,3 +136,9 @@ class GetProjectCallable(Protocol):
134136
def __call__(
135137
self, module: str, name: str | None = "No Name"
136138
) -> Project: ... # pragma: no cover
139+
140+
141+
class GeneratorFactory(Protocol):
142+
def __call__(
143+
self, config: PyreverseConfig | None = None, args: Sequence[str] | None = None
144+
) -> DiaDefGenerator: ...

tests/pyreverse/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from collections.abc import Callable
7+
from collections.abc import Callable, Sequence
88

99
import pytest
1010
from astroid.nodes.scoped_nodes import Module
@@ -15,8 +15,15 @@
1515
from pylint.typing import GetProjectCallable
1616

1717

18+
@pytest.fixture()
19+
def default_args() -> Sequence[str]:
20+
"""Provides default command-line arguments for tests."""
21+
return ["data"]
22+
23+
1824
@pytest.fixture()
1925
def default_config() -> PyreverseConfig:
26+
"""Provides default configuration for tests."""
2027
return PyreverseConfig()
2128

2229

0 commit comments

Comments
 (0)