7
7
from __future__ import annotations
8
8
9
9
import argparse
10
- from collections .abc import Generator
10
+ import warnings
11
+ from collections .abc import Generator , Sequence
11
12
from typing import Any
12
13
13
14
import astroid
@@ -27,10 +28,28 @@ class DiaDefGenerator:
27
28
def __init__ (self , linker : Linker , handler : DiadefsHandler ) -> None :
28
29
"""Common Diagram Handler initialization."""
29
30
self .config = handler .config
31
+ self .args = handler .args
30
32
self .module_names : bool = False
31
33
self ._set_default_options ()
32
34
self .linker = linker
33
35
self .classdiagram : ClassDiagram # defined by subclasses
36
+ # Only pre-calculate depths if user has requested a max_depth
37
+ if handler .config .max_depth is not None :
38
+ # Detect which of the args are leaf nodes
39
+ leaf_nodes = self .get_leaf_nodes ()
40
+
41
+ # Emit a warning if any of the args are not leaf nodes
42
+ diff = set (self .args ).difference (set (leaf_nodes ))
43
+ if len (diff ) > 0 :
44
+ warnings .warn (
45
+ "Detected nested names within the specified packages. "
46
+ f"The following packages: { sorted (diff )} will be ignored for "
47
+ f"depth calculations, using only: { sorted (leaf_nodes )} as the base for limiting "
48
+ "package depth." ,
49
+ stacklevel = 2 ,
50
+ )
51
+
52
+ self .args_depths = {module : module .count ("." ) for module in leaf_nodes }
34
53
35
54
def get_title (self , node : nodes .ClassDef ) -> str :
36
55
"""Get title for objects."""
@@ -39,6 +58,22 @@ def get_title(self, node: nodes.ClassDef) -> str:
39
58
title = f"{ node .root ().name } .{ title } "
40
59
return title # type: ignore[no-any-return]
41
60
61
+ def get_leaf_nodes (self ) -> list [str ]:
62
+ """
63
+ Get the leaf nodes from the list of args in the generator.
64
+
65
+ A leaf node is one that is not a prefix (with an extra dot) of any other node.
66
+ """
67
+ leaf_nodes = [
68
+ module
69
+ for module in self .args
70
+ if not any (
71
+ other != module and other .startswith (module + "." )
72
+ for other in self .args
73
+ )
74
+ ]
75
+ return leaf_nodes
76
+
42
77
def _set_option (self , option : bool | None ) -> bool :
43
78
"""Activate some options if not explicitly deactivated."""
44
79
# if we have a class diagram, we want more information by default;
@@ -67,6 +102,30 @@ def _get_levels(self) -> tuple[int, int]:
67
102
"""Help function for search levels."""
68
103
return self .anc_level , self .association_level
69
104
105
+ def _should_include_by_depth (self , node : nodes .NodeNG ) -> bool :
106
+ """Check if a node should be included based on depth.
107
+
108
+ A node will be included if it is at or below the max_depth relative to the
109
+ specified base packages. A node is considered to be a base package if it is the
110
+ deepest package in the list of specified packages. In other words the base nodes
111
+ are the leaf nodes of the specified package tree.
112
+ """
113
+ # If max_depth is not set, include all nodes
114
+ if self .config .max_depth is None :
115
+ return True
116
+
117
+ # Calculate the absolute depth of the node
118
+ name = node .root ().name
119
+ absolute_depth = name .count ("." )
120
+
121
+ # Retrieve the base depth to compare against
122
+ relative_depth = next (
123
+ (v for k , v in self .args_depths .items () if name .startswith (k )), None
124
+ )
125
+ return relative_depth is not None and bool (
126
+ (absolute_depth - relative_depth ) <= self .config .max_depth
127
+ )
128
+
70
129
def show_node (self , node : nodes .ClassDef ) -> bool :
71
130
"""Determine if node should be shown based on config."""
72
131
if node .root ().name == "builtins" :
@@ -75,7 +134,8 @@ def show_node(self, node: nodes.ClassDef) -> bool:
75
134
if is_stdlib_module (node .root ().name ):
76
135
return self .config .show_stdlib # type: ignore[no-any-return]
77
136
78
- return True
137
+ # Filter node by depth
138
+ return self ._should_include_by_depth (node )
79
139
80
140
def add_class (self , node : nodes .ClassDef ) -> None :
81
141
"""Visit one class and add it to diagram."""
@@ -163,7 +223,7 @@ def visit_module(self, node: nodes.Module) -> None:
163
223
164
224
add this class to the package diagram definition
165
225
"""
166
- if self .pkgdiagram :
226
+ if self .pkgdiagram and self . _should_include_by_depth ( node ) :
167
227
self .linker .visit (node )
168
228
self .pkgdiagram .add_object (node .name , node )
169
229
@@ -177,7 +237,7 @@ def visit_classdef(self, node: nodes.ClassDef) -> None:
177
237
178
238
def visit_importfrom (self , node : nodes .ImportFrom ) -> None :
179
239
"""Visit astroid.ImportFrom and catch modules for package diagram."""
180
- if self .pkgdiagram :
240
+ if self .pkgdiagram and self . _should_include_by_depth ( node ) :
181
241
self .pkgdiagram .add_from_depend (node , node .modname )
182
242
183
243
@@ -208,8 +268,9 @@ def class_diagram(self, project: Project, klass: nodes.ClassDef) -> ClassDiagram
208
268
class DiadefsHandler :
209
269
"""Get diagram definitions from user (i.e. xml files) or generate them."""
210
270
211
- def __init__ (self , config : argparse .Namespace ) -> None :
271
+ def __init__ (self , config : argparse .Namespace , args : Sequence [ str ] ) -> None :
212
272
self .config = config
273
+ self .args = args
213
274
214
275
def get_diadefs (self , project : Project , linker : Linker ) -> list [ClassDiagram ]:
215
276
"""Get the diagram's configuration data.
0 commit comments