Skip to content

Commit b8c7457

Browse files
authored
Allows the caller to pass in argv to hl.main() (#8937)
1 parent ec8654a commit b8c7457

File tree

3 files changed

+45
-30
lines changed

3 files changed

+45
-30
lines changed

python_bindings/src/halide/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ def install_dir():
3232
active_generator_context,
3333
alias,
3434
funcs,
35-
Generator,
3635
generator,
36+
main,
37+
Generator,
3738
GeneratorParam,
3839
InputBuffer,
3940
InputScalar,

python_bindings/src/halide/_generator_helpers.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from enum import Enum
55
from functools import total_ordering
66
from .halide_ import (
7+
_,
8+
_generate_filter_main,
9+
_unique_name,
10+
_UnspecifiedType,
711
ArgInfo,
812
ArgInfoDirection,
913
ArgInfoKind,
@@ -21,9 +25,6 @@
2125
Type,
2226
UInt,
2327
Var,
24-
_,
25-
_UnspecifiedType,
26-
_unique_name,
2728
)
2829
from inspect import isclass
2930
from typing import Any, Optional
@@ -892,3 +893,14 @@ def funcs(names: str) -> tuple[Func, ...]:
892893
def vars(names: str) -> tuple[Var, ...]:
893894
"""Given a space-delimited string, create a Var for each substring and return as a tuple."""
894895
return tuple(Var(n) for n in names.split(" "))
896+
897+
898+
def main(argv: Optional[list[str]] = None):
899+
"""Entrypoint for invoking all registered generators.
900+
901+
Args:
902+
argv: A list of command-line arguments to pass to the generator. If None, uses sys.argv.
903+
"""
904+
if argv is None:
905+
argv = sys.argv
906+
_generate_filter_main(argv)

python_bindings/src/halide/halide_/PyGenerator.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "PyGenerator.h"
22

33
#include <pybind11/embed.h>
4+
#include <string>
5+
#include <vector>
46

57
namespace Halide {
68
namespace PythonBindings {
@@ -15,16 +17,6 @@ using Halide::Parameter;
1517
using Halide::Internal::ArgInfoDirection;
1618
using Halide::Internal::ArgInfoKind;
1719

18-
template<typename T>
19-
std::map<std::string, T> dict_to_map(const py::dict &dict) {
20-
_halide_user_assert(!dict.is(py::none()));
21-
std::map<std::string, T> m;
22-
for (auto it : dict) {
23-
m[it.first.cast<std::string>()] = it.second.cast<T>();
24-
}
25-
return m;
26-
}
27-
2820
class PyGeneratorBase : public AbstractGenerator {
2921
// The name declared in the Python function's decorator
3022
const std::string name_;
@@ -165,22 +157,32 @@ void define_generator(py::module &m) {
165157
return o.str();
166158
});
167159

168-
m.def("main", []() -> void {
169-
py::object argv_object = py::module_::import("sys").attr("argv");
170-
std::vector<std::string> argv_vector = args_to_vector<std::string>(argv_object);
171-
std::vector<char *> argv;
172-
argv.reserve(argv_vector.size());
173-
for (auto &s : argv_vector) {
174-
argv.push_back(const_cast<char *>(s.c_str()));
175-
}
176-
int result = Halide::Internal::generate_filter_main((int)argv.size(), argv.data(), PyGeneratorFactoryProvider());
177-
if (result != 0) {
178-
// Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception
179-
// due to how libHalide is built for python), but some paths just return an error code, so
180-
// be sure to handle both.
181-
throw std::runtime_error("Generator failed: " + std::to_string(result));
182-
}
183-
});
160+
m.def("_generate_filter_main", //
161+
[](const std::vector<std::string> &arguments) -> void {
162+
if (arguments.empty()) {
163+
throw std::invalid_argument("No arguments provided to _generate_filter_main");
164+
}
165+
166+
// POSIX requires argv to be mutable and null-terminated
167+
std::vector<char *> argv;
168+
argv.reserve(arguments.size() + 1);
169+
for (const auto &s : arguments) {
170+
argv.push_back(const_cast<char *>(s.c_str()));
171+
}
172+
argv.push_back(nullptr);
173+
174+
const int result = Halide::Internal::generate_filter_main(
175+
static_cast<int>(argv.size()) - 1, argv.data(), PyGeneratorFactoryProvider());
176+
if (result != 0) {
177+
// Some paths in generate_filter_main() will fail with user_error
178+
// or similar (which throws an exception due to how libHalide is
179+
// built for Python), but other paths just return an error code.
180+
// For consistency, handle both by throwing a C++ exception, which
181+
// PyBind11 turns into a Python exception.
182+
throw std::runtime_error("Generator failed: " + std::to_string(result));
183+
} //
184+
},
185+
py::arg("argv"));
184186

185187
m.def("_unique_name", []() -> std::string {
186188
return ::Halide::Internal::unique_name('p');

0 commit comments

Comments
 (0)