Skip to content

Commit 39ee998

Browse files
authored
[llvmpy] from_capsule API (#70)
1 parent 1170192 commit 39ee998

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

build_tools/build_llvm.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ echo "LLVM_SOURCE_DIR=${LLVM_SOURCE_DIR}"
2424
echo "LLVM_BUILD_DIR=${LLVM_BUILD_DIR}"
2525
echo "LLVM_INSTALL_DIR=${LLVM_INSTALL_DIR}"
2626

27-
python3_command=""
27+
python3_command="python"
2828
if (command -v python3 &> /dev/null); then
2929
python3_command="python3"
30-
elif (command -v python &> /dev/null); then
31-
python3_command="python"
3230
fi
3331

3432
Python3_EXECUTABLE="${Python3_EXECUTABLE:-$(which $python3_command)}"

projects/eudsl-llvmpy/eudsl-llvmpy-generate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def postprocess(code: str) -> str:
7171
code = code.replace('.value("llvm', '.value("')
7272
code = code.replace('m.def("llvm_', 'm.def("')
7373
code = code.replace('m.def("llvm', 'm.def("')
74+
pattern = r'\.def_rw\("ptr", &(\w+)::ptr, ""\)'
75+
repl = r'.def_rw("ptr", &\1::ptr, "").def_static("from_capsule", [](nb::capsule caps) -> \1 { void *ptr = PyCapsule_GetPointer(caps.ptr(), "nb_handle"); return {ptr}; })'
76+
code = re.sub(pattern, repl, code)
7477

7578
return code
7679

projects/eudsl-llvmpy/tests/test_bindings.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Copyright (c) 2024.
55
from textwrap import dedent
66

7-
from llvm import types_ as T
7+
from llvm import types_ as T, ModuleRef, print_module_to_string
88
from llvm.context import context
99
from llvm.function import function
1010
from llvm.instructions import add, ret
@@ -80,7 +80,44 @@ def sum(a: T.int32, b: T.int32, c: T.float) -> T.int32:
8080
assert correct == mod_str
8181

8282

83+
def test_from_capsule():
84+
src = dedent(
85+
"""
86+
; ModuleID = 'test_smoke'
87+
source_filename = "test_smoke"
88+
89+
declare i32 @foo()
90+
91+
declare i32 @bar()
92+
93+
define i32 @entry(i32 %argc) {
94+
entry:
95+
%and = and i32 %argc, 1
96+
%tobool = icmp eq i32 %and, 0
97+
br i1 %tobool, label %if.end, label %if.then
98+
99+
if.then: ; preds = %entry
100+
%call = tail call i32 @foo()
101+
br label %return
102+
103+
if.end: ; preds = %entry
104+
%call1 = tail call i32 @bar()
105+
br label %return
106+
107+
return: ; preds = %if.end, %if.then
108+
%retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.end ]
109+
ret i32 %retval.0
110+
}
111+
"""
112+
)
113+
with context(src=src, buffer_name="test_smoke") as ctx:
114+
copied_mod = ModuleRef.from_capsule(ctx.module.ptr)
115+
mod_str = print_module_to_string(copied_mod)
116+
assert src.strip() == mod_str.strip()
117+
118+
83119
if __name__ == "__main__":
84120
test_smoke()
85121
test_builder()
86122
test_symbol_collision()
123+
test_from_capsule()

0 commit comments

Comments
 (0)