Skip to content

Commit ddeaab9

Browse files
authored
Added a .pyi stub file to type inspection with pylance (#9) (#10)
* Added a .pyi stub file to type inspection with pylance * Formatting * Update version to 0.1.1
2 parents 2abd2fb + 9993dc3 commit ddeaab9

4 files changed

Lines changed: 24 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [ "setuptools", "wheel" ]
55

66
[project]
77
name = "cunumpy"
8-
version = "0.1"
8+
version = "0.1.1"
99
description = "Simple wrapper for numpy and cupy. Replace `import numpy as np` with `import cunumpy as xp`."
1010
readme = "README.md"
1111
keywords = [ "python" ]
@@ -48,3 +48,6 @@ urls."Source" = "https://github.com/max-models/cunumpy"
4848

4949
[tool.setuptools.packages.find]
5050
where = [ "src" ]
51+
52+
[tool.setuptools.package-data]
53+
cunumpy = [ "py.typed", "*.pyi" ]

src/cunumpy/__init__.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Stub file for Pylance/mypy: exposes all numpy symbols so that
2+
# `import cunumpy as xp` followed by `xp.<Tab>` shows numpy completions.
3+
# At runtime the real __init__.py dispatches to numpy or cupy via __getattr__.
4+
from numpy import *
5+
from numpy import __config__, __version__
6+
7+
from . import xp

src/cunumpy/py.typed

Whitespace-only changes.

tests/unit/test_app.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
import cunumpy as xp
24

35

@@ -9,5 +11,16 @@ def test_xp_array():
911
print(f"{arr = } {type(arr) = }")
1012

1113

14+
def test_numpy_symbols_accessible():
15+
"""All public numpy symbols must be reachable via cunumpy.
16+
17+
This validates the runtime behaviour that the stub file (__init__.pyi)
18+
declares to Pylance so that `xp.<Tab>` shows numpy completions in VS Code.
19+
"""
20+
missing = [name for name in np.__all__ if not hasattr(xp, name)]
21+
assert missing == [], f"Symbols not accessible via cunumpy: {missing}"
22+
23+
1224
if __name__ == "__main__":
1325
test_xp_array()
26+
test_numpy_symbols_accessible()

0 commit comments

Comments
 (0)