101101
102102IN_SCOPE_MARKER = "# router-trust: in-scope"
103103
104- # An import like ``from foo import bar`` produces a dotted target with at
105- # least one dot ("foo.bar"). Names with fewer dots cannot be resolved to a
106- # (module, name) pair.
107- _MIN_DOTTED_PARTS = 2
108-
109104
110105# ---------------------------------------------------------------------------
111106# Per-file collected info
112107# ---------------------------------------------------------------------------
113108
114109
110+ @dataclass (frozen = True )
111+ class ImportTarget :
112+ """What an alias in a file's namespace refers to.
113+
114+ Two kinds:
115+ * ``kind="from"`` -- ``from <module> import <name> [as alias]``;
116+ the alias names a *value* (router, function, etc.) inside
117+ ``<module>``.
118+ * ``kind="module"`` -- ``import <module> [as alias]``; the alias
119+ names a *module*. Accessing ``.x`` either descends into a
120+ submodule or pulls a value out of the module.
121+
122+ Kept as its own dataclass so the resolver can branch on the import
123+ shape without re-deriving it from the raw string.
124+ """
125+
126+ kind : str # "from" or "module"
127+ module : str # the module path (for "from": the source module; for "module": the imported module path)
128+ name : str | None = None # for "from": the imported name; for "module": None
129+
130+
115131@dataclass
116132class IncludeCall :
117- parent_var : str
118- child_var : str
133+ parent_chain : tuple [ str , ...]
134+ child_chain : tuple [ str , ...]
119135 child_prefix : str | None
120136 lineno : int
121137
122138
123139@dataclass
124140class DecoratorRef :
125- router_var : str
141+ router_chain : tuple [ str , ...]
126142 method : str
127143 path : str
128144 lineno : int
@@ -134,8 +150,8 @@ class FileInfo:
134150 path : Path
135151 # var -> prefix string (None if no prefix kwarg)
136152 local_routers : dict [str , str | None ] = field (default_factory = dict )
137- # alias -> "<module>.<name>" (resolved import target)
138- imports : dict [str , str ] = field (default_factory = dict )
153+ # alias -> ImportTarget describing what the alias refers to
154+ imports : dict [str , ImportTarget ] = field (default_factory = dict )
139155 include_calls : list [IncludeCall ] = field (default_factory = list )
140156 decorators : list [DecoratorRef ] = field (default_factory = list )
141157 has_marker : bool = False
@@ -216,6 +232,26 @@ def _module_to_file(module: str) -> Path | None:
216232 return None
217233
218234
235+ def _attribute_chain (node : ast .AST ) -> tuple [str , ...] | None :
236+ """Flatten an ``ast.Name`` / ``ast.Attribute`` chain into a tuple of segments.
237+
238+ ``foo`` -> ``("foo",)``
239+ ``foo.bar`` -> ``("foo", "bar")``
240+ ``foo.bar.baz`` -> ``("foo", "bar", "baz")``
241+ Anything else (subscript, call, etc.) -> ``None``.
242+ """
243+ parts : list [str ] = []
244+ current : ast .AST = node
245+ while isinstance (current , ast .Attribute ):
246+ parts .append (current .attr )
247+ current = current .value
248+ if not isinstance (current , ast .Name ):
249+ return None
250+ parts .append (current .id )
251+ parts .reverse ()
252+ return tuple (parts )
253+
254+
219255def parse_file (path : Path ) -> FileInfo | None :
220256 """Parse one file with AST; return None on syntax error."""
221257 try :
@@ -238,55 +274,62 @@ def parse_file(path: Path) -> FileInfo | None:
238274 prefix = _kwarg_string (node .value , "prefix" )
239275 info .local_routers [target_name ] = prefix
240276
241- # Imports
277+ # Imports. We record the *kind* of import so the resolver can
278+ # distinguish ``from X import Y`` (Y is a value) from ``import X.Y``
279+ # (where X.Y is itself a module and ``X.Y.router`` is the value).
242280 if isinstance (node , ast .ImportFrom ):
243281 module = _resolve_relative_module (file_module , node .level , node .module )
244282 if module :
245283 for alias in node .names :
246284 local_name = alias .asname or alias .name
247- info .imports [local_name ] = f" { module } . { alias .name } "
285+ info .imports [local_name ] = ImportTarget ( kind = "from" , module = module , name = alias .name )
248286 elif isinstance (node , ast .Import ):
249287 for alias in node .names :
250288 local_name = alias .asname or alias .name .split ("." )[0 ]
251- info .imports [local_name ] = alias .name
289+ info .imports [local_name ] = ImportTarget ( kind = "module" , module = alias .name , name = None )
252290
253- # parent.include_router(child, prefix=...)
291+ # parent.include_router(child, prefix=...). Both parent and child
292+ # may be ``ast.Attribute`` chains (``app.api.include_router(...)``
293+ # or ``include_router(child.api.router, ...)``), so we flatten both
294+ # sides into tuples and let the resolver handle the dotted form.
254295 if (
255296 isinstance (node , ast .Call )
256297 and isinstance (node .func , ast .Attribute )
257298 and node .func .attr == "include_router"
258- and isinstance (node .func .value , ast .Name )
259299 and node .args
260- and isinstance (node .args [0 ], ast .Name )
261300 ):
262- parent_var = node .func .value .id
263- child_var = node .args [0 ].id
264- prefix = _kwarg_string (node , "prefix" )
265- info .include_calls .append (
266- IncludeCall (
267- parent_var = parent_var ,
268- child_var = child_var ,
269- child_prefix = prefix ,
270- lineno = node .lineno ,
301+ parent_chain = _attribute_chain (node .func .value )
302+ child_chain = _attribute_chain (node .args [0 ])
303+ if parent_chain is not None and child_chain is not None :
304+ prefix = _kwarg_string (node , "prefix" )
305+ info .include_calls .append (
306+ IncludeCall (
307+ parent_chain = parent_chain ,
308+ child_chain = child_chain ,
309+ child_prefix = prefix ,
310+ lineno = node .lineno ,
311+ )
271312 )
272- )
273313
274- # @<router>.<method>(...) on a (Async)FunctionDef
314+ # @<router>.<method>(...) on a (Async)FunctionDef. The router
315+ # reference can also be a dotted attribute chain
316+ # (``@child.api.router.post(...)``), so flatten it too.
275317 if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
276318 for deco in node .decorator_list :
277319 if (
278320 isinstance (deco , ast .Call )
279321 and isinstance (deco .func , ast .Attribute )
280322 and deco .func .attr in HTTP_METHODS
281- and isinstance (deco .func .value , ast .Name )
282323 ):
283- router_var = deco .func .value .id
324+ router_chain = _attribute_chain (deco .func .value )
325+ if router_chain is None :
326+ continue
284327 path_str = ""
285328 if deco .args and isinstance (deco .args [0 ], ast .Constant ) and isinstance (deco .args [0 ].value , str ):
286329 path_str = deco .args [0 ].value
287330 info .decorators .append (
288331 DecoratorRef (
289- router_var = router_var ,
332+ router_chain = router_chain ,
290333 method = deco .func .attr ,
291334 path = path_str ,
292335 lineno = deco .lineno ,
@@ -302,28 +345,108 @@ def parse_file(path: Path) -> FileInfo | None:
302345# ---------------------------------------------------------------------------
303346
304347
305- def _resolve_var (file_info : FileInfo , var : str , file_info_map : dict [Path , FileInfo ]) -> RouterId | None :
306- """Map ``var`` (a name in ``file_info``) to the (file, var) where it's defined."""
307- if var in file_info .local_routers :
308- return (file_info .path , var )
309- if var in file_info .imports :
310- full = file_info .imports [var ]
311- parts = full .split ("." )
312- if len (parts ) < _MIN_DOTTED_PARTS :
313- return None
314- module = "." .join (parts [:- 1 ])
315- target_name = parts [- 1 ]
316- target_file = _module_to_file (module )
317- if target_file is None :
318- return None
319- target_info = file_info_map .get (target_file )
320- if target_info is None :
348+ def _resolve_chain (
349+ file_info : FileInfo ,
350+ chain : tuple [str , ...],
351+ file_info_map : dict [Path , FileInfo ],
352+ * ,
353+ seen : frozenset [tuple [Path , tuple [str , ...]]] = frozenset (),
354+ ) -> RouterId | None :
355+ """Map a dotted attribute chain to the (file, var_name) defining the router.
356+
357+ Handles four name shapes (the ``A``/``B``/``C``/``D`` correspond to
358+ cases enumerated in this script's docstring):
359+
360+ A. ``from X.Y import Z`` then ``Z`` -> chain=("Z",)
361+ B. ``from X.Y import Z as alias`` then ``alias`` -> chain=("alias",)
362+ C. ``import X.Y`` then ``X.Y.Z`` -> chain=("X","Y","Z")
363+ D. ``import X.Y as alias`` then ``alias.Z`` -> chain=("alias","Z")
364+
365+ Cycles in re-export chains (``a.py`` re-exports from ``b.py`` which
366+ re-exports from ``a.py``) are bounded via the ``seen`` set; the
367+ resolver returns ``None`` rather than recursing forever.
368+ """
369+ if not chain :
370+ return None
371+ head = chain [0 ]
372+ rest = chain [1 :]
373+
374+ # Bound recursion against re-export cycles.
375+ fingerprint = (file_info .path , chain )
376+ if fingerprint in seen :
377+ return None
378+ seen = seen | {fingerprint }
379+
380+ # Local definition (only meaningful for a single Name, not an
381+ # attribute chain -- ``foo.bar`` cannot resolve to a local variable
382+ # ``foo`` because that would be a method call, not a router lookup).
383+ if not rest and head in file_info .local_routers :
384+ return (file_info .path , head )
385+
386+ if head not in file_info .imports :
387+ return None
388+ imp = file_info .imports [head ]
389+
390+ if imp .kind == "from" :
391+ if not rest :
392+ # `from M import N [as alias]; alias` -> module=M, var=N
393+ module = imp .module
394+ var = imp .name
395+ else :
396+ # `from M import N [as alias]; alias.x.y...`
397+ # Treat alias as a value living at M.N; the chain after head
398+ # walks deeper (rare for routers but handle it cleanly).
399+ module = "." .join ([imp .module , * ([imp .name ] if imp .name else []), * rest [:- 1 ]])
400+ var = rest [- 1 ]
401+ if not var :
321402 return None
322- if target_name in target_info .local_routers :
323- return (target_file , target_name )
324- # Re-export chain: the target file may itself import this name.
325- if target_name in target_info .imports :
326- return _resolve_var (target_info , target_name , file_info_map )
403+ elif imp .kind == "module" :
404+ # ``import M [as alias]``; ``imp.module`` is M.
405+ if head == imp .module :
406+ # Case where head is the literal module path (``import x``;
407+ # alias matches "x" exactly). ``head.x.y`` -> module=M+x, var=y.
408+ if not rest :
409+ return None
410+ module = "." .join ([imp .module , * rest [:- 1 ]])
411+ var = rest [- 1 ]
412+ elif imp .module .startswith (head + "." ):
413+ # Case C: ``import x.y.z`` (no asname). Python's binding rule:
414+ # the local name is the *first* segment ("x"); the rest of the
415+ # dotted path lives under it. Code references must spell out
416+ # the full module path before the var: ``x.y.z.router``.
417+ module_parts = tuple (imp .module .split ("." ))
418+ # Verify chain prefix == module_parts.
419+ if len (chain ) <= len (module_parts ):
420+ return None
421+ for idx , mp in enumerate (module_parts ):
422+ if chain [idx ] != mp :
423+ return None
424+ sub = chain [len (module_parts ) :]
425+ if not sub :
426+ return None
427+ module = "." .join ([imp .module , * sub [:- 1 ]])
428+ var = sub [- 1 ]
429+ else :
430+ # Case D: ``import x.y as alias``; head=alias, rest=(...,var)
431+ if not rest :
432+ return None
433+ module = "." .join ([imp .module , * rest [:- 1 ]])
434+ var = rest [- 1 ]
435+ else :
436+ return None
437+
438+ target_file = _module_to_file (module )
439+ if target_file is None :
440+ return None
441+ target_info = file_info_map .get (target_file )
442+ if target_info is None :
443+ return None
444+ if var in target_info .local_routers :
445+ return (target_file , var )
446+ # Re-export chain: the target file may itself import this name and
447+ # forward it on.
448+ if var in target_info .imports :
449+ return _resolve_chain (target_info , (var ,), file_info_map , seen = seen )
327450 return None
328451
329452
@@ -348,14 +471,14 @@ def compute_in_scope(file_info_map: dict[Path, FileInfo]) -> set[RouterId]:
348471 changed = False
349472 for info in file_info_map .values ():
350473 for call in info .include_calls :
351- child_id = _resolve_var (info , call .child_var , file_info_map )
474+ child_id = _resolve_chain (info , call .child_chain , file_info_map )
352475 if child_id is None or child_id in in_scope :
353476 continue
354477 child_in_scope = False
355478 if call .child_prefix and "/extensions" in call .child_prefix :
356479 child_in_scope = True
357480 else :
358- parent_id = _resolve_var (info , call .parent_var , file_info_map )
481+ parent_id = _resolve_chain (info , call .parent_chain , file_info_map )
359482 if parent_id is not None and parent_id in in_scope :
360483 child_in_scope = True
361484 if child_in_scope :
@@ -384,27 +507,30 @@ def scan_in_scope(
384507 file_info_map : dict [Path , FileInfo ],
385508 in_scope : set [RouterId ],
386509) -> list [str ]:
510+ """Walk every decorator and flag forbidden handlers on in-scope routers.
511+
512+ A decorator's router reference is a dotted chain (e.g.
513+ ``("router",)`` for ``@router.post(...)`` or ``("child", "api", "router")``
514+ for ``@child.api.router.post(...)``). We resolve the chain back to a
515+ ``RouterId`` and check whether that router is in scope.
516+ """
387517 violations : list [str ] = []
388- in_scope_files : dict [Path , set [str ]] = {}
389- for path , var in in_scope :
390- in_scope_files .setdefault (path , set ()).add (var )
391518
392519 for path , info in file_info_map .items ():
393- target_vars = in_scope_files .get (path )
394- # If the file has the explicit marker, every decorator counts -- we
395- # cannot trust ``router_var`` to be a literal name.
396520 check_all = info .has_marker
397- if not target_vars and not check_all :
398- continue
399-
400521 for deco in info .decorators :
401- if not check_all and deco .router_var not in (target_vars or set ()):
522+ in_scope_router = check_all
523+ if not in_scope_router :
524+ deco_router = _resolve_chain (info , deco .router_chain , file_info_map )
525+ in_scope_router = deco_router is not None and deco_router in in_scope
526+ if not in_scope_router :
402527 continue
403528 token = _violation_token (deco .path ) or _violation_token (deco .func_name )
404529 if token is not None :
530+ router_repr = "." .join (deco .router_chain )
405531 violations .append (
406532 f"{ path } :{ deco .lineno } : forbidden token { token !r} in handler "
407- f"@{ deco . router_var } .{ deco .method } ({ deco .path !r} ) def { deco .func_name } (...)"
533+ f"@{ router_repr } .{ deco .method } ({ deco .path !r} ) def { deco .func_name } (...)"
408534 )
409535 return violations
410536
0 commit comments