|
6 | 6 | #include <functional> |
7 | 7 | #include <pybind11/pybind11.h> |
8 | 8 | #include <string> |
| 9 | +#include <string_view> |
9 | 10 | #include <unordered_map> |
10 | 11 | #include <utility> |
11 | 12 |
|
@@ -605,9 +606,151 @@ PyObject *specialize_impl(PyObject *self, PyObject *const *args, |
605 | 606 | return PyTuple_Pack(2, type.ptr(), key.ptr()); |
606 | 607 | } |
607 | 608 |
|
| 609 | +bool visit_make_tensordesc_args(PyObject *arg, PyObject *sig, |
| 610 | + PyObject *relevant_paths, |
| 611 | + PyObject *tensordesc_meta, |
| 612 | + bool has_tensordesc_meta, PyObject *base_args, |
| 613 | + PyObject *make_tensordesc_arg, |
| 614 | + Py_ssize_t *tensordesc_idx, PyObject *result) { |
| 615 | + assert(PyTuple_Check(sig)); |
| 616 | + auto arg_fast = |
| 617 | + from_new_ref(PySequence_Fast(arg, "Expected iterable args node")); |
| 618 | + if (!arg_fast) |
| 619 | + return false; |
| 620 | + |
| 621 | + Py_ssize_t arg_len = PySequence_Fast_GET_SIZE(arg_fast.ptr()); |
| 622 | + Py_ssize_t sig_len = PyTuple_GET_SIZE(sig); |
| 623 | + assert(sig_len == arg_len || !"Invalid signature"); |
| 624 | + Py_ssize_t len = arg_len; |
| 625 | + |
| 626 | + for (Py_ssize_t i = 0; i < len; ++i) { |
| 627 | + PyObject *a = PySequence_Fast_GET_ITEM(arg_fast.ptr(), i); |
| 628 | + PyObject *s = PyTuple_GET_ITEM(sig, i); |
| 629 | + |
| 630 | + if (PyUnicode_CheckExact(s)) { |
| 631 | + Py_ssize_t size; |
| 632 | + const char *type_cstr = PyUnicode_AsUTF8AndSize(s, &size); |
| 633 | + if (!type_cstr) |
| 634 | + return false; |
| 635 | + |
| 636 | + // if not s.startswith("tensordesc") |
| 637 | + std::string_view tensordesc = "tensordesc"; |
| 638 | + std::string_view type_str(type_cstr, size); |
| 639 | + if (type_str.substr(0, tensordesc.length()) != tensordesc) { |
| 640 | + if (PyList_Append(result, a) < 0) |
| 641 | + return false; |
| 642 | + continue; |
| 643 | + } |
| 644 | + |
| 645 | + PyObject *meta = Py_None; |
| 646 | + if (has_tensordesc_meta) { |
| 647 | + // Borrowed reference |
| 648 | + meta = PyList_GetItem(tensordesc_meta, *tensordesc_idx); |
| 649 | + if (!meta) |
| 650 | + return false; |
| 651 | + } |
| 652 | + |
| 653 | + PyObject *vector_args[] = {a, meta, base_args}; |
| 654 | + auto desc_args = from_new_ref( |
| 655 | + PyObject_Vectorcall(make_tensordesc_arg, vector_args, 3, nullptr)); |
| 656 | + if (!desc_args) |
| 657 | + return false; |
| 658 | + // list.extend(desc_args) |
| 659 | + if (PyList_SetSlice(result, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, |
| 660 | + desc_args.ptr()) < 0) |
| 661 | + return false; |
| 662 | + |
| 663 | + *tensordesc_idx += 1; |
| 664 | + continue; |
| 665 | + } |
| 666 | + |
| 667 | + auto key = from_new_ref(PyLong_FromSsize_t(i)); |
| 668 | + if (!key) |
| 669 | + return false; |
| 670 | + |
| 671 | + // Borrowed ref |
| 672 | + PyObject *inner_relevant_paths = |
| 673 | + PyDict_GetItemWithError(relevant_paths, key.ptr()); |
| 674 | + if (PyErr_Occurred()) |
| 675 | + return false; |
| 676 | + |
| 677 | + if (!inner_relevant_paths) { |
| 678 | + // Short-circuit if tuple doesn't contain any tensordesc args |
| 679 | + if (PyList_Append(result, a) < 0) |
| 680 | + return false; |
| 681 | + continue; |
| 682 | + } |
| 683 | + |
| 684 | + // Recurse into tuple |
| 685 | + auto inner_res = from_new_ref(PyList_New(0)); |
| 686 | + if (!inner_res) |
| 687 | + return false; |
| 688 | + if (!visit_make_tensordesc_args( |
| 689 | + a, s, inner_relevant_paths, tensordesc_meta, has_tensordesc_meta, |
| 690 | + base_args, make_tensordesc_arg, tensordesc_idx, inner_res.ptr())) |
| 691 | + return false; |
| 692 | + |
| 693 | + auto inner_tuple = from_new_ref(PyList_AsTuple(inner_res.ptr())); |
| 694 | + if (!inner_tuple) |
| 695 | + return false; |
| 696 | + if (PyList_Append(result, inner_tuple.ptr()) < 0) |
| 697 | + return false; |
| 698 | + } |
| 699 | + return true; |
| 700 | +} |
| 701 | + |
| 702 | +PyObject *make_tensordesc_args(PyObject *self, PyObject *const *args, |
| 703 | + Py_ssize_t nargs) { |
| 704 | + if (nargs != 6) { |
| 705 | + PyErr_SetString(PyExc_TypeError, |
| 706 | + "make_tensordesc_args expected 6 arguments"); |
| 707 | + return nullptr; |
| 708 | + } |
| 709 | + |
| 710 | + PyObject *kernel_args = args[0]; |
| 711 | + PyObject *signature = args[1]; |
| 712 | + PyObject *relevant_paths = args[2]; |
| 713 | + PyObject *tensordesc_meta = args[3]; |
| 714 | + PyObject *base_args = args[4]; |
| 715 | + PyObject *make_tensordesc_arg = args[5]; |
| 716 | + |
| 717 | + bool has_tensordesc_meta = tensordesc_meta != Py_None; |
| 718 | + if (has_tensordesc_meta && !PyList_CheckExact(tensordesc_meta)) { |
| 719 | + PyErr_SetString(PyExc_TypeError, "Expected tensordesc_meta to be a list"); |
| 720 | + return nullptr; |
| 721 | + } |
| 722 | + |
| 723 | + auto result = from_new_ref(PyList_New(0)); |
| 724 | + if (!result) |
| 725 | + return nullptr; |
| 726 | + |
| 727 | + Py_ssize_t tensordesc_idx = 0; |
| 728 | + if (!visit_make_tensordesc_args(kernel_args, signature, relevant_paths, |
| 729 | + tensordesc_meta, has_tensordesc_meta, |
| 730 | + base_args, make_tensordesc_arg, |
| 731 | + &tensordesc_idx, result.ptr())) |
| 732 | + return nullptr; |
| 733 | + |
| 734 | + if (has_tensordesc_meta) { |
| 735 | + Py_ssize_t meta_len = PySequence_Size(tensordesc_meta); |
| 736 | + if (meta_len < 0) |
| 737 | + return nullptr; |
| 738 | + |
| 739 | + if (tensordesc_idx != meta_len) { |
| 740 | + PyErr_SetString(PyExc_ValueError, |
| 741 | + "make_tensordesc_args: tensordesc_idx != meta_len"); |
| 742 | + return nullptr; |
| 743 | + } |
| 744 | + } |
| 745 | + |
| 746 | + return result.release().ptr(); |
| 747 | +} |
| 748 | + |
608 | 749 | static PyMethodDef module_methods[] = { |
609 | 750 | {"native_specialize_impl", (PyCFunction)specialize_impl, METH_FASTCALL, |
610 | 751 | nullptr}, |
| 752 | + {"make_tensordesc_args", (PyCFunction)make_tensordesc_args, METH_FASTCALL, |
| 753 | + "Helper to translate tensordesc args"}, |
611 | 754 | {nullptr, nullptr, 0, nullptr} // sentinel |
612 | 755 | }; |
613 | 756 |
|
|
0 commit comments