|
| 1 | +#include <multipy/runtime/deploy.h> |
| 2 | +#include <pybind11/pybind11.h> |
| 3 | +#include <torch/csrc/jit/python/pybind_utils.h> |
| 4 | +#include <torch/csrc/lazy/core/debug_util.h> |
| 5 | + |
| 6 | +namespace py = pybind11; |
| 7 | + |
| 8 | +using namespace torch::deploy; |
| 9 | + |
| 10 | +namespace { |
| 11 | +at::IValue detachIValue(at::IValue&& iv) { |
| 12 | + if (iv.isTensor()) { |
| 13 | + // detach tensors to avoid cross interpreter autograd state |
| 14 | + return std::move(iv).toTensor().detach(); |
| 15 | + } |
| 16 | + return iv; |
| 17 | +} |
| 18 | +at::IValue toIValue(const py::handle& obj) { |
| 19 | + return detachIValue(torch::jit::toTypeInferredIValue(obj)); |
| 20 | +} |
| 21 | +} // namespace |
| 22 | + |
| 23 | +PYBIND11_MODULE(multipy_pybind, m) { |
| 24 | + m.doc() = "multipy python bindings"; |
| 25 | + |
| 26 | + py::class_<InterpreterManager>(m, "InterpreterManager") |
| 27 | + .def(py::init<size_t>()) |
| 28 | + .def("acquire_one", &InterpreterManager::acquireOne) |
| 29 | + .def( |
| 30 | + "__len__", |
| 31 | + [](InterpreterManager& self) -> int { |
| 32 | + return self.allInstances().size(); |
| 33 | + }) |
| 34 | + .def( |
| 35 | + "__getitem__", |
| 36 | + [](InterpreterManager& self, int i) -> InterpreterSession { |
| 37 | + return self.allInstances().at(i).acquireSession(); |
| 38 | + }); |
| 39 | + |
| 40 | + py::class_<Interpreter>(m, "Interpreter") |
| 41 | + .def("acquire_session", &Interpreter::acquireSession); |
| 42 | + |
| 43 | + py::class_<InterpreterSession>(m, "InterpreterSession") |
| 44 | + .def("global_", &InterpreterSession::global); |
| 45 | + |
| 46 | + py::class_<Obj>(m, "Obj") |
| 47 | + .def( |
| 48 | + "__call__", |
| 49 | + [](Obj& self, py::args args, const py::kwargs& kwargs) -> Obj { |
| 50 | + std::vector<at::IValue> iargs; |
| 51 | + std::unordered_map<std::string, at::IValue> ikwargs; |
| 52 | + |
| 53 | + for (auto& arg : args) { |
| 54 | + iargs.emplace_back(toIValue(arg)); |
| 55 | + } |
| 56 | + for (auto& arg : kwargs) { |
| 57 | + ikwargs.emplace( |
| 58 | + arg.first.cast<std::string>(), toIValue(arg.second)); |
| 59 | + } |
| 60 | + |
| 61 | + return self.callKwargs(iargs, ikwargs); |
| 62 | + }) |
| 63 | + .def( |
| 64 | + "__getattr__", |
| 65 | + [](Obj& self, std::string attr) -> Obj { |
| 66 | + return self.attr(attr.c_str()); |
| 67 | + }) |
| 68 | + .def("deref", [](Obj& self) -> py::object { |
| 69 | + return ::torch::jit::toPyObject(detachIValue(self.toIValue())); |
| 70 | + }); |
| 71 | +} |
0 commit comments