Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions flit_core/flit_core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,7 @@ def get_docstring_and_version_via_ast(target):
with target_path.open('rb') as f:
node = ast.parse(f.read())
for child in node.body:
if sys.version_info >= (3, 8):
target_type = ast.Constant
else:
target_type = ast.Str
# Only use the version from the given module if it's a simple
# string assignment to __version__
is_version_str = (
isinstance(child, ast.Assign)
and any(
isinstance(target, ast.Name)
and target.id == "__version__"
for target in child.targets
)
and isinstance(child.value, target_type)
)
if is_version_str:
if is_version_str_node(child):
if sys.version_info >= (3, 8):
version = child.value.value
else:
Expand All @@ -172,6 +157,20 @@ def get_docstring_and_version_via_ast(target):
return ast.get_docstring(node), version


def is_version_str_node(node, /):
"""Check if *node* is a simple string assignment to __version__"""
if not isinstance(node, (ast.Assign, ast.AnnAssign)):
return False
constant_type = ast.Constant if sys.version_info >= (3, 8) else ast.Str
if not isinstance(node.value, constant_type):
return False
targets = (node.target,) if isinstance(node, ast.AnnAssign) else node.targets
for target in targets:
if isinstance(target, ast.Name) and target.id == "__version__":
return True
return False


# To ensure we're actually loading the specified file, give it a unique name to
# avoid any cached import. In normal use we'll only load one module per process,
# so it should only matter for the tests, but we'll do it anyway.
Expand Down
4 changes: 4 additions & 0 deletions flit_core/tests_core/samples/annotated_version/module1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

"""This module has a __version__ that has a type annotation"""

__version__: str = '0.1'
12 changes: 12 additions & 0 deletions flit_core/tests_core/samples/annotated_version/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[build-system]
requires = ["flit_core >=2,<4"]
build-backend = "flit_core.buildapi"

[tool.flit.metadata]
module = "module1"
author = "Sir Robin"
author-email = "robin@camelot.uk"
home-page = "http://github.com/sirrobin/module1"
requires = [
"numpy >=1.16.0",
]
5 changes: 5 additions & 0 deletions flit_core/tests_core/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def test_get_info_from_module(self):
'version': '0.1'}
)

info = get_info_from_module(Module('module1', samples_dir / 'annotated_version'))
self.assertEqual(info, {'summary': 'This module has a __version__ that has a type annotation',
'version': '0.1'}
)

info = get_info_from_module(Module('module1', samples_dir / 'constructed_version'))
self.assertEqual(info, {'summary': 'This module has a __version__ that requires runtime interpretation',
'version': '1.2.3'}
Expand Down
Loading