Skip to content

Commit cdefc78

Browse files
authored
Generate stubs properly for extensions with lazy loaded modules (#1765)
1 parent ef7cb9d commit cdefc78

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

metaflow/cmd/develop/stub_generator.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from metaflow import FlowSpec, step
3131
from metaflow.debug import debug
3232
from metaflow.decorators import Decorator, FlowDecorator
33+
from metaflow.extension_support import get_aliased_modules
3334
from metaflow.graph import deindent_docstring
3435
from metaflow.metaflow_version import get_version
3536

@@ -116,6 +117,7 @@ def __init__(self, output_dir: str, include_generated_for: bool = True):
116117

117118
self._write_generated_for = include_generated_for
118119
self._pending_modules = ["metaflow"] # type: List[str]
120+
self._pending_modules.extend(get_aliased_modules())
119121
self._root_module = "metaflow."
120122
self._safe_modules = ["metaflow.", "metaflow_extensions."]
121123

metaflow/cmd/develop/stubs.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,26 @@ def get_packages_for_stubs() -> Tuple[List[Tuple[str, str]], List[str]]:
323323
return [], []
324324

325325
dist_list = []
326-
for dist in _metadata_package.distributions():
326+
327+
# We check the type because if the user has multiple importlib metadata, for
328+
# some reason it shows up multiple times.
329+
interesting_dists = [
330+
d
331+
for d in _metadata_package.distributions()
327332
if any(
328333
[
329-
pkg == "metaflow-stubs"
330-
for pkg in (dist.read_text("top_level.txt") or "").split()
334+
p == "metaflow-stubs"
335+
for p in (d.read_text("top_level.txt") or "").split()
331336
]
332-
):
333-
# This is a package we care about
334-
root_path = dist.locate_file("metaflow-stubs").as_posix()
335-
dist_list.append((dist.metadata["Name"], root_path))
336-
all_paths.discard(root_path)
337+
)
338+
and isinstance(d, _metadata_package.PathDistribution)
339+
]
340+
341+
for dist in interesting_dists:
342+
# This is a package we care about
343+
root_path = dist.locate_file("metaflow-stubs").as_posix()
344+
dist_list.append((dist.metadata["Name"], root_path))
345+
all_paths.discard(root_path)
337346
return dist_list, list(all_paths)
338347

339348

0 commit comments

Comments
 (0)