Skip to content

Commit e62aab6

Browse files
committed
Move impl to C++
1 parent b0794c3 commit e62aab6

3 files changed

Lines changed: 157 additions & 37 deletions

File tree

python/src/specialize.cc

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <functional>
77
#include <pybind11/pybind11.h>
88
#include <string>
9+
#include <string_view>
910
#include <unordered_map>
1011
#include <utility>
1112

@@ -605,9 +606,151 @@ PyObject *specialize_impl(PyObject *self, PyObject *const *args,
605606
return PyTuple_Pack(2, type.ptr(), key.ptr());
606607
}
607608

609+
bool visit_make_tensordesc_args(PyObject *arg, PyObject *sig,
610+
PyObject *relevant_paths,
611+
PyObject *tensordesc_meta,
612+
bool has_tensordesc_meta, PyObject *base_args,
613+
PyObject *make_tensordesc_arg,
614+
Py_ssize_t *tensordesc_idx, PyObject *result) {
615+
assert(PyTuple_Check(sig));
616+
auto arg_fast =
617+
from_new_ref(PySequence_Fast(arg, "Expected iterable args node"));
618+
if (!arg_fast)
619+
return false;
620+
621+
Py_ssize_t arg_len = PySequence_Fast_GET_SIZE(arg_fast.ptr());
622+
Py_ssize_t sig_len = PyTuple_GET_SIZE(sig);
623+
assert(sig_len == arg_len || !"Invalid signature");
624+
Py_ssize_t len = arg_len;
625+
626+
for (Py_ssize_t i = 0; i < len; ++i) {
627+
PyObject *a = PySequence_Fast_GET_ITEM(arg_fast.ptr(), i);
628+
PyObject *s = PyTuple_GET_ITEM(sig, i);
629+
630+
if (PyUnicode_CheckExact(s)) {
631+
Py_ssize_t size;
632+
const char *type_cstr = PyUnicode_AsUTF8AndSize(s, &size);
633+
if (!type_cstr)
634+
return false;
635+
636+
// if not s.startswith("tensordesc")
637+
std::string_view tensordesc = "tensordesc";
638+
std::string_view type_str(type_cstr, size);
639+
if (type_str.substr(0, tensordesc.length()) != tensordesc) {
640+
if (PyList_Append(result, a) < 0)
641+
return false;
642+
continue;
643+
}
644+
645+
PyObject *meta = Py_None;
646+
if (has_tensordesc_meta) {
647+
// Borrowed reference
648+
meta = PyList_GetItem(tensordesc_meta, *tensordesc_idx);
649+
if (!meta)
650+
return false;
651+
}
652+
653+
PyObject *vector_args[] = {a, meta, base_args};
654+
auto desc_args = from_new_ref(
655+
PyObject_Vectorcall(make_tensordesc_arg, vector_args, 3, nullptr));
656+
if (!desc_args)
657+
return false;
658+
// list.extend(desc_args)
659+
if (PyList_SetSlice(result, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX,
660+
desc_args.ptr()) < 0)
661+
return false;
662+
663+
*tensordesc_idx += 1;
664+
continue;
665+
}
666+
667+
auto key = from_new_ref(PyLong_FromSsize_t(i));
668+
if (!key)
669+
return false;
670+
671+
// Borrowed ref
672+
PyObject *inner_relevant_paths =
673+
PyDict_GetItemWithError(relevant_paths, key.ptr());
674+
if (PyErr_Occurred())
675+
return false;
676+
677+
if (!inner_relevant_paths) {
678+
// Short-circuit if tuple doesn't contain any tensordesc args
679+
if (PyList_Append(result, a) < 0)
680+
return false;
681+
continue;
682+
}
683+
684+
// Recurse into tuple
685+
auto inner_res = from_new_ref(PyList_New(0));
686+
if (!inner_res)
687+
return false;
688+
if (!visit_make_tensordesc_args(
689+
a, s, inner_relevant_paths, tensordesc_meta, has_tensordesc_meta,
690+
base_args, make_tensordesc_arg, tensordesc_idx, inner_res.ptr()))
691+
return false;
692+
693+
auto inner_tuple = from_new_ref(PyList_AsTuple(inner_res.ptr()));
694+
if (!inner_tuple)
695+
return false;
696+
if (PyList_Append(result, inner_tuple.ptr()) < 0)
697+
return false;
698+
}
699+
return true;
700+
}
701+
702+
PyObject *make_tensordesc_args(PyObject *self, PyObject *const *args,
703+
Py_ssize_t nargs) {
704+
if (nargs != 6) {
705+
PyErr_SetString(PyExc_TypeError,
706+
"make_tensordesc_args expected 6 arguments");
707+
return nullptr;
708+
}
709+
710+
PyObject *kernel_args = args[0];
711+
PyObject *signature = args[1];
712+
PyObject *relevant_paths = args[2];
713+
PyObject *tensordesc_meta = args[3];
714+
PyObject *base_args = args[4];
715+
PyObject *make_tensordesc_arg = args[5];
716+
717+
bool has_tensordesc_meta = tensordesc_meta != Py_None;
718+
if (has_tensordesc_meta && !PyList_CheckExact(tensordesc_meta)) {
719+
PyErr_SetString(PyExc_TypeError, "Expected tensordesc_meta to be a list");
720+
return nullptr;
721+
}
722+
723+
auto result = from_new_ref(PyList_New(0));
724+
if (!result)
725+
return nullptr;
726+
727+
Py_ssize_t tensordesc_idx = 0;
728+
if (!visit_make_tensordesc_args(kernel_args, signature, relevant_paths,
729+
tensordesc_meta, has_tensordesc_meta,
730+
base_args, make_tensordesc_arg,
731+
&tensordesc_idx, result.ptr()))
732+
return nullptr;
733+
734+
if (has_tensordesc_meta) {
735+
Py_ssize_t meta_len = PySequence_Size(tensordesc_meta);
736+
if (meta_len < 0)
737+
return nullptr;
738+
739+
if (tensordesc_idx != meta_len) {
740+
PyErr_SetString(PyExc_ValueError,
741+
"make_tensordesc_args: tensordesc_idx != meta_len");
742+
return nullptr;
743+
}
744+
}
745+
746+
return result.release().ptr();
747+
}
748+
608749
static PyMethodDef module_methods[] = {
609750
{"native_specialize_impl", (PyCFunction)specialize_impl, METH_FASTCALL,
610751
nullptr},
752+
{"make_tensordesc_args", (PyCFunction)make_tensordesc_args, METH_FASTCALL,
753+
"Helper to translate tensordesc args"},
611754
{nullptr, nullptr, 0, nullptr} // sentinel
612755
};
613756

python/test/unit/runtime/test_driver.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import triton
66
import triton.language as tl
7-
from triton.backends.driver import expand_signature, wrap_descriptors
7+
from triton.backends.driver import expand_signature, wrap_handle_tensordesc_impl
88

99

1010
def test_is_lazy():
@@ -69,7 +69,7 @@ def test_expand_signature_with_aggregate_tensordesc():
6969
assert expanded[2:] == ["nvTmaDesc", "i32", "i32", "i32", "i32", "i64", "i64", "i64", "i64"]
7070

7171

72-
def test_wrap_descriptors_handles_aggregate_arguments():
72+
def test_wrap_tensordesc_handles_aggregate_arguments():
7373
signature = {0: ("tensordesc<fp16[16,16]>", "i32"), 1: "i64", 2: "tensordesc<fp16[16,16]>"}
7474
outer_meta = {"tag": "outer"}
7575
launcher_calls = []
@@ -81,24 +81,24 @@ def launcher(*args):
8181
def make_descriptor(arg, meta, base_args):
8282
return [("desc", arg, meta, base_args[0]), ("shape", arg)]
8383

84-
wrapped = wrap_descriptors(launcher, signature, [None, outer_meta], make_descriptor)
84+
wrapped = wrap_handle_tensordesc_impl(launcher, signature, [None, outer_meta], make_descriptor)
8585
assert wrapped("meta0", "meta1", (("A", 7), 9, "B")) == "ok"
8686

8787
assert len(launcher_calls) == 1
8888
assert launcher_calls[0][0] == "meta0"
8989
assert launcher_calls[0][1] == "meta1"
90-
assert launcher_calls[0][2] == (
90+
assert launcher_calls[0][2] == [
9191
(("desc", "A", None, "meta0"), ("shape", "A"), 7),
9292
9,
9393
("desc", "B", outer_meta, "meta0"),
9494
("shape", "B"),
95-
)
95+
]
9696

9797

98-
def test_wrap_descriptors_is_noop_without_tensordesc():
98+
def test_wrap_tensordesc_is_noop_without_tensordesc():
9999

100100
def launcher(*args):
101101
return args
102102

103-
wrapped = wrap_descriptors(launcher, {0: "i32", 1: ("i64", "constexpr")}, None, lambda *_: [])
103+
wrapped = wrap_handle_tensordesc_impl(launcher, {0: "i32", 1: ("i64", "constexpr")}, None, lambda *_: [])
104104
assert wrapped is launcher

python/triton/backends/driver.py

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABCMeta, abstractmethod
22
import re
33
from typing import Callable, List, Protocol, Sequence
4-
from collections import defaultdict
54

65
from triton._utils import find_paths_if
6+
from triton._C.libtriton import make_tensordesc_args
77

88

99
def decompose_descriptor(arg):
@@ -16,48 +16,25 @@ def _is_descriptor(arg):
1616
return isinstance(arg, str) and arg.startswith("tensordesc")
1717

1818

19-
def _make_tensordesc_args(args, signature, relevant_paths, tensordesc_meta, base_args, make_tensordesc_arg):
20-
21-
def visit(arg, sig, relevant_paths, tensordesc_idx, result):
22-
for i, (a, s) in enumerate(zip(arg, sig)):
23-
rel_paths = relevant_paths.get(i, None)
24-
if rel_paths is None:
25-
result.append(a)
26-
elif len(rel_paths) == 0:
27-
meta = tensordesc_meta[tensordesc_idx] if tensordesc_meta else None
28-
result.extend(make_tensordesc_arg(a, meta, base_args))
29-
tensordesc_idx += 1
30-
else:
31-
inner_res = []
32-
tensordesc_idx = visit(a, s, rel_paths, tensordesc_idx, inner_res)
33-
result.append(tuple(inner_res))
34-
return tensordesc_idx
35-
36-
result = []
37-
tensordesc_idx = visit(args, signature, relevant_paths, 0, result)
38-
assert not tensordesc_meta or tensordesc_idx == len(tensordesc_meta)
39-
return result
40-
41-
4219
def wrap_handle_tensordesc_impl(launcher, signature, tensordesc_meta, make_tensordesc_arg):
43-
signature = tuple(signature.values()) if hasattr(signature, "values") else tuple(signature)
20+
signature = tuple(signature.values())
4421
tensordesc_paths = find_paths_if(signature, lambda _, x: _is_descriptor(x))
4522
if len(tensordesc_paths) == 0:
4623
return launcher
4724

48-
# Build a tree to speed up checking, it will look like:
49-
# signature = ['tensordesc', 'i32', ('i32', 'tensordesc')]
25+
# Build a tree to speed up tensordesc type checking, e.g.
26+
# signature = ('tensordesc', 'i32', ('i32', 'tensordesc'))
5027
# relevant_paths = {0: {}, 2: {1: {}}}
51-
relevant_paths = defaultdict(defaultdict)
28+
relevant_paths = {}
5229
for path in tensordesc_paths:
5330
cur = relevant_paths
5431
for step in path:
55-
cur = cur[step]
32+
cur = cur.setdefault(step, {})
5633

5734
def inner(*args):
5835
base_args = args[:-1]
5936
kernel_args = args[-1]
60-
wrapped = _make_tensordesc_args(
37+
wrapped = make_tensordesc_args(
6138
kernel_args,
6239
signature,
6340
relevant_paths,

0 commit comments

Comments
 (0)