Skip to content

Commit 954dc05

Browse files
authored
Merge branch 'main' into main
2 parents a80220e + 9415a9a commit 954dc05

File tree

7 files changed

+140
-6
lines changed

7 files changed

+140
-6
lines changed

extensions/pyo3/private/pyo3.bzl

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,43 @@ def _py_pyo3_library_impl(ctx):
8787
is_windows = extension.basename.endswith(".dll")
8888

8989
# https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
90-
ext = ctx.actions.declare_file("{}{}".format(
91-
ctx.label.name,
92-
".pyd" if is_windows else ".so",
93-
))
90+
#
91+
# Determine the on-disk and logical Python module layout.
92+
#
93+
# `module` is a full dotted module path (e.g. "foo.bar"). We split on the
94+
# last "." such that:
95+
# - module_prefix == "foo"
96+
# - module_name == "bar"
97+
#
98+
# `module_name` must match the `#[pymodule] fn <name>(...)` in the Rust code
99+
# and is also what we pass to the stub generator.
100+
module_path = ctx.attr.module_name if ctx.attr.module_name else ctx.label.name.replace("/", ".")
101+
102+
if module_path.startswith(".") or module_path.endswith(".") or ".." in module_path:
103+
fail("Invalid `module` value '{}': expected a dotted module path like 'foo.bar'.".format(module_path))
104+
105+
last_dot = module_path.rfind(".")
106+
if last_dot == -1:
107+
module_prefix = None
108+
module_name = module_path
109+
else:
110+
module_prefix = module_path[:last_dot]
111+
module_name = module_path[last_dot + 1:]
112+
113+
if not module_name:
114+
fail("Invalid `module` value '{}': module name may not be empty.".format(module_path))
115+
116+
# Convert module_prefix (e.g. "foo.bar") into a path ("foo/bar") and place
117+
# the extension and stubs in the corresponding directory.
118+
if module_prefix:
119+
module_prefix_path = module_prefix.replace(".", "/")
120+
module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
121+
stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
122+
else:
123+
module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
124+
stub_relpath = "{}.pyi".format(module_name)
125+
126+
ext = ctx.actions.declare_file(module_relpath)
94127
ctx.actions.symlink(
95128
output = ext,
96129
target_file = extension,
@@ -99,10 +132,10 @@ def _py_pyo3_library_impl(ctx):
99132

100133
stub = None
101134
if _stubs_enabled(ctx.attr.stubs, toolchain):
102-
stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
135+
stub = ctx.actions.declare_file(stub_relpath)
103136

104137
args = ctx.actions.args()
105-
args.add(ctx.label.name, format = "--module_name=%s")
138+
args.add(module_name, format = "--module_name=%s")
106139
args.add(ext, format = "--module_path=%s")
107140
args.add(stub, format = "--output=%s")
108141
ctx.actions.run(
@@ -180,6 +213,9 @@ py_pyo3_library = rule(
180213
"imports": attr.string_list(
181214
doc = "List of import directories to be added to the `PYTHONPATH`.",
182215
),
216+
"module_name": attr.string(
217+
doc = "A full dotted Python module path implemented by this extension (e.g. `foo.bar`).",
218+
),
183219
"stubs": attr.int(
184220
doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
185221
default = -1,
@@ -218,6 +254,7 @@ def pyo3_extension(
218254
stubs = None,
219255
version = None,
220256
compilation_mode = "opt",
257+
module_name = None,
221258
**kwargs):
222259
"""Define a PyO3 python extension module.
223260
@@ -259,6 +296,7 @@ def pyo3_extension(
259296
For more details see [rust_shared_library][rsl].
260297
compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
261298
value to build the extension for. If set to `"current"`, the current configuration will be used.
299+
module_name (str, optional): A full dotted Python module path implemented by this extension (e.g. `foo.bar`).
262300
**kwargs (dict): Additional keyword arguments.
263301
"""
264302
tags = kwargs.pop("tags", [])
@@ -318,6 +356,7 @@ def pyo3_extension(
318356
compilation_mode = compilation_mode,
319357
stubs = stubs_int,
320358
imports = imports,
359+
module_name = module_name,
321360
tags = tags,
322361
visibility = visibility,
323362
**kwargs
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@rules_python//python:defs.bzl", "py_test")
2+
load("//:defs.bzl", "pyo3_extension")
3+
4+
pyo3_extension(
5+
name = "module_prefix",
6+
srcs = ["bar.rs"],
7+
edition = "2021",
8+
module_name = "foo.bar",
9+
)
10+
11+
py_test(
12+
name = "module_prefix_import_test",
13+
srcs = ["module_prefix_import_test.py"],
14+
deps = [":module_prefix"],
15+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use pyo3::prelude::*;
2+
3+
#[pyfunction]
4+
fn thing() -> PyResult<&'static str> {
5+
Ok("hello from rust")
6+
}
7+
8+
#[pymodule]
9+
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
10+
m.add_function(wrap_pyfunction!(thing, m)?)?;
11+
Ok(())
12+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Tests that a pyo3 extension can be imported via a module prefix."""
2+
3+
import unittest
4+
from test.module_prefix.foo import bar
5+
6+
7+
class ModulePrefixImportTest(unittest.TestCase):
8+
"""Test Class."""
9+
10+
def test_import_and_call(self) -> None:
11+
"""Test that a pyo3 extension can be imported via a module prefix."""
12+
13+
result = bar.thing()
14+
self.assertEqual("hello from rust", result)
15+
16+
17+
if __name__ == "__main__":
18+
unittest.main()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@rules_python//python:defs.bzl", "py_test")
2+
load("//:defs.bzl", "pyo3_extension")
3+
4+
# Variant that relies on the `imports` attribute to put this package root on
5+
# `PYTHONPATH` (so you can import `foo.bar` directly).
6+
pyo3_extension(
7+
name = "module_prefix_imports",
8+
srcs = ["bar.rs"],
9+
edition = "2021",
10+
imports = ["."],
11+
module_name = "foo.bar",
12+
)
13+
14+
py_test(
15+
name = "module_prefix_imports_test",
16+
srcs = ["module_prefix_imports_test.py"],
17+
deps = [":module_prefix_imports"],
18+
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use pyo3::prelude::*;
2+
3+
#[pyfunction]
4+
fn thing() -> PyResult<&'static str> {
5+
Ok("hello from rust")
6+
}
7+
8+
#[pymodule]
9+
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
10+
m.add_function(wrap_pyfunction!(thing, m)?)?;
11+
Ok(())
12+
}
13+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Tests importing a pyo3 extension via `imports = ["."]`."""
2+
3+
import unittest
4+
5+
from foo import bar # type: ignore[import-untyped]
6+
7+
8+
class ModulePrefixImportsTest(unittest.TestCase):
9+
"""Test Class."""
10+
11+
def test_import_and_call(self) -> None:
12+
"""Test that a pyo3 extension can be imported via a module prefix."""
13+
14+
result = bar.thing()
15+
self.assertEqual("hello from rust", result)
16+
17+
18+
if __name__ == "__main__":
19+
unittest.main()

0 commit comments

Comments
 (0)