
Commit a2d8896

[Nvidia] Enable TMA im2col mode - driver support (#9305)
# Summary

This is the fourth PR in a series that enables TMA im2col mode (in addition to the existing tiled mode) for NVIDIA GPUs. The goal of the series is to support TMA im2col mode in the Gluon DSL.

- First PR: #9202
- Second PR: #9225
- Third PR: #9303
- -> Fourth PR: #9305

PTX ISA documentation for TMA im2col mode: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode

TMA tensor descriptor documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

# Summary of changes

This PR adds the driver function for creating the tensor descriptor for TMA im2col mode. The driver API is documented at: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
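For orientation before reading the diff: below is a hedged sketch of how the new `fill_tma_descriptor_im2col` binding might be called from Python, following the argument order parsed by `fillTMADescriptorIm2col` (`"KiiiOOOiOOiiO"`). The access path mirrors how the tiled descriptor is created in `driver.py`; the tensor shape, the element-type enum value, and the im2col parameters are illustrative placeholders, not values taken from this PR.

```python
# Illustrative only: build an im2col TMA descriptor for a hypothetical
# contiguous NHWC fp16 tensor. The argument order follows the
# PyArg_ParseTuple format string "KiiiOOOiOOiiO" in driver.c.
import torch
import triton

x = torch.empty((8, 56, 56, 64), dtype=torch.float16, device="cuda")  # NHWC
utils = triton.runtime.driver.active.utils

desc = utils.fill_tma_descriptor_im2col(
    x.data_ptr(),      # global_address
    0,                 # swizzle (0 -> CU_TENSOR_MAP_SWIZZLE_NONE)
    x.element_size(),  # elemSize in bytes
    6,                 # elemType: CUtensorMapDataType value (6 assumed to be FLOAT16)
    [128, 64],         # blockSize, typically [pixelsPerColumn, channelsPerPixel]
    list(x.shape),     # shape, outermost-first (NHWC)
    list(x.stride()),  # strides in elements, outermost-first
    0,                 # padding: 0 -> CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    [0, 0],            # pixelBoxLower corner offsets for H/W (placeholders)
    [0, 0],            # pixelBoxUpper corner offsets (placeholders)
    64,                # channelsPerPixel
    128,               # pixelsPerColumn
    [1, 1, 1, 1],      # elementStrides, one entry per tensor dimension
)
```

On the driver side this maps onto a single `cuTensorMapEncodeIm2col` call; if the encode fails, the error message echoes all of these parameters back, as the diff below shows.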
1 parent d156ef5 commit a2d8896

2 files changed

Lines changed: 279 additions & 4 deletions


third_party/nvidia/backend/driver.c

Lines changed: 276 additions & 2 deletions
@@ -228,6 +228,15 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
     CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
     CUtensorMapFloatOOBfill oobFill);
 
+typedef CUresult (*cuTensorMapEncodeIm2col_t)(
+    CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
+    cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
+    const cuuint64_t *globalStrides, const int *pixelBoxLowerCorner,
+    const int *pixelBoxUpperCorner, cuuint32_t channelsPerPixel,
+    cuuint32_t pixelsPerColumn, const cuuint32_t *elementStrides,
+    CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
+    CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
+
 typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig *config,
                                        CUfunction f, void **kernelParams,
                                        void **extra);
@@ -260,6 +269,9 @@ defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
                         cuTensorMapEncodeTiled);
 
+defineGetFunctionHandle(getCuTensorMapEncodeIm2colHandle,
+                        cuTensorMapEncodeIm2col);
+
 defineGetFunctionHandle(getLaunchKernelExHandle, cuLaunchKernelEx);
 
 static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
@@ -386,7 +398,7 @@ static PyTypeObject PyCUtensorMapType = {
 };
 // clang-format on
 
-static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
+static PyObject *fillTMADescriptorTiled(PyObject *self, PyObject *args) {
   unsigned long long global_address;
   int swizzle;
   int elemSize;
@@ -555,6 +567,265 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
   return NULL;
 }
 
+static PyObject *fillTMADescriptorIm2col(PyObject *self, PyObject *args) {
+  unsigned long long global_address;
+  int swizzle;
+  int elemSize;
+  int elemType;
+  PyObject *blockSize;
+  PyObject *shape;
+  PyObject *strides;
+  int padding;
+  PyObject *pixelBoxLower;
+  PyObject *pixelBoxUpper;
+  int channelsPerPixel;
+  int pixelsPerColumn;
+  PyObject *elementStrides;
+
+  if (!PyArg_ParseTuple(args, "KiiiOOOiOOiiO", &global_address, &swizzle,
+                        &elemSize, &elemType, &blockSize, &shape, &strides,
+                        &padding, &pixelBoxLower, &pixelBoxUpper,
+                        &channelsPerPixel, &pixelsPerColumn, &elementStrides)) {
+    return NULL;
+  }
+
+  PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
+      (PyObject *)&PyCUtensorMapType, NULL);
+  if (!desc) {
+    return NULL;
+  }
+
+  PyObject *blockSizeFast = NULL;
+  PyObject *shapeFast = NULL;
+  PyObject *stridesFast = NULL;
+  PyObject *pixelBoxLowerFast = NULL;
+  PyObject *pixelBoxUpperFast = NULL;
+  PyObject *elementStridesFast = NULL;
+
+  uint32_t blockSizeInt[5];
+  uint64_t shapeInt[5];
+  uint64_t stridesLL[5];
+  int pixelBoxLowerInt[5] = {0};
+  int pixelBoxUpperInt[5] = {0};
+  uint32_t elementStridesInt[5] = {1, 1, 1, 1, 1}; // Default to all 1s
+
+  // For im2col mode, shape determines the tensor rank, not blockSize
+  // blockSize is typically 2D [pixelsPerColumn, channelsPerPixel]
+  // while shape can be 4D or 5D (e.g., NHWC or NDHWC)
+  shapeFast = PySequence_Fast(shape, "shape must be a sequence");
+  if (!shapeFast)
+    goto cleanup;
+  int rank = PySequence_Fast_GET_SIZE(shapeFast);
+
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "shape must be an int");
+      goto cleanup;
+    }
+    shapeInt[rank - i - 1] = PyLong_AsLong(item);
+  }
+
+  blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
+  if (!blockSizeFast)
+    goto cleanup;
+  int blockRank = PySequence_Fast_GET_SIZE(blockSizeFast);
+
+  for (int i = 0; i < blockRank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "block size must be an int");
+      goto cleanup;
+    }
+    blockSizeInt[blockRank - i - 1] = PyLong_AsLongLong(item);
+  }
+
+  stridesFast = PySequence_Fast(strides, "strides must be a sequence");
+  if (!stridesFast)
+    goto cleanup;
+
+  if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
+    PyErr_Format(PyExc_RuntimeError,
+                 "Rank mismatch for strides in fillTMADescriptorIm2col: shape "
+                 "has rank %d but strides has %zd elements. "
+                 "Expected strides to have %d elements.",
+                 rank, PySequence_Fast_GET_SIZE(stridesFast), rank);
+    goto cleanup;
+  }
+  for (int i = 0; i + 1 < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "strides must be an int");
+      goto cleanup;
+    }
+    stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
+  }
+  stridesLL[rank - 1] =
+      shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
+
+  // Parse pixel box lower corner
+  pixelBoxLowerFast =
+      PySequence_Fast(pixelBoxLower, "pixelBoxLower must be a sequence");
+  if (!pixelBoxLowerFast)
+    goto cleanup;
+
+  int spatialRank = PySequence_Fast_GET_SIZE(pixelBoxLowerFast);
+  if (spatialRank > 5) {
+    PyErr_SetString(PyExc_RuntimeError, "Pixel box rank too large (max 5)");
+    goto cleanup;
+  }
+
+  for (int i = 0; i < spatialRank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(pixelBoxLowerFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "pixelBoxLower elements must be int");
+      goto cleanup;
+    }
+    pixelBoxLowerInt[spatialRank - i - 1] = PyLong_AsLong(item);
+  }
+
+  // Parse pixel box upper corner
+  pixelBoxUpperFast =
+      PySequence_Fast(pixelBoxUpper, "pixelBoxUpper must be a sequence");
+  if (!pixelBoxUpperFast)
+    goto cleanup;
+
+  if (spatialRank != PySequence_Fast_GET_SIZE(pixelBoxUpperFast)) {
+    PyErr_SetString(PyExc_RuntimeError, "Pixel box corner rank mismatch");
+    goto cleanup;
+  }
+
+  for (int i = 0; i < spatialRank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(pixelBoxUpperFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "pixelBoxUpper elements must be int");
+      goto cleanup;
+    }
+    pixelBoxUpperInt[spatialRank - i - 1] = PyLong_AsLong(item);
+  }
+
+  // Parse element strides
+  elementStridesFast =
+      PySequence_Fast(elementStrides, "elementStrides must be a sequence");
+  if (!elementStridesFast)
+    goto cleanup;
+
+  int elementStridesLen = PySequence_Fast_GET_SIZE(elementStridesFast);
+  if (elementStridesLen != rank) {
+    PyErr_SetString(PyExc_RuntimeError,
+                    "elementStrides length must match tensor rank");
+    goto cleanup;
+  }
+
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(elementStridesFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "elementStrides elements must be int");
+      goto cleanup;
+    }
+    elementStridesInt[rank - i - 1] = PyLong_AsLong(item);
+  }
+
+  Py_DECREF(blockSizeFast);
+  blockSizeFast = NULL;
+  Py_DECREF(shapeFast);
+  shapeFast = NULL;
+  Py_DECREF(stridesFast);
+  stridesFast = NULL;
+  Py_DECREF(pixelBoxLowerFast);
+  pixelBoxLowerFast = NULL;
+  Py_DECREF(pixelBoxUpperFast);
+  pixelBoxUpperFast = NULL;
+  Py_DECREF(elementStridesFast);
+  elementStridesFast = NULL;
+
+  CUtensorMapFloatOOBfill fill =
+      (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
+                     : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
+
+  static cuTensorMapEncodeIm2col_t cuTensorMapEncodeIm2col = NULL;
+  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeIm2col,
+                                      getCuTensorMapEncodeIm2colHandle);
+
+  CUresult res = cuTensorMapEncodeIm2col(
+      &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
+      stridesLL, pixelBoxLowerInt, pixelBoxUpperInt, channelsPerPixel,
+      pixelsPerColumn, elementStridesInt, CU_TENSOR_MAP_INTERLEAVE_NONE,
+      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);
+
+  if (res != CUDA_SUCCESS) {
+    const char *str;
+    cuGetErrorString(res, &str);
+    char err[4096] = {0};
+    size_t off = 0;
+    off += snprintf(err + off, sizeof(err) - off,
+                    "Triton Error [CUDA]: Failed to create im2col tensor map "
+                    "descriptor: %s\n",
+                    str ? str : "Unknown error");
+    off +=
+        snprintf(err + off, sizeof(err) - off,
+                 "elemType=%d rank=%d global_address=0x%llx elemSize=%d "
+                 "swizzle=%d padding=%d channelsPerPixel=%d "
+                 "pixelsPerColumn=%d\n",
+                 elemType, rank, (unsigned long long)global_address, elemSize,
+                 swizzle, padding, channelsPerPixel, pixelsPerColumn);
+    off += snprintf(err + off, sizeof(err) - off, "shape=[");
+    for (int i = 0; i < rank; ++i) {
+      off +=
+          snprintf(err + off, sizeof(err) - off, "%llu%s",
+                   (unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "]\n");
+    off += snprintf(err + off, sizeof(err) - off, "strides=[");
+    for (int i = 0; i < rank; ++i) {
+      off += snprintf(err + off, sizeof(err) - off, "%llu%s",
+                      (unsigned long long)stridesLL[i],
+                      (i + 1 < rank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "]\n");
+    off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
+    for (int i = 0; i < blockRank; ++i) {
+      off +=
+          snprintf(err + off, sizeof(err) - off, "%u%s",
+                   (unsigned)blockSizeInt[i], (i + 1 < blockRank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "]\n");
+    off += snprintf(err + off, sizeof(err) - off, "pixelBoxLower=[");
+    for (int i = 0; i < spatialRank; ++i) {
+      off += snprintf(err + off, sizeof(err) - off, "%d%s", pixelBoxLowerInt[i],
+                      (i + 1 < spatialRank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "] pixelBoxUpper=[");
+    for (int i = 0; i < spatialRank; ++i) {
+      off += snprintf(err + off, sizeof(err) - off, "%d%s", pixelBoxUpperInt[i],
+                      (i + 1 < spatialRank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "]\n");
+    off += snprintf(err + off, sizeof(err) - off, "elementStrides=[");
+    for (int i = 0; i < rank; ++i) {
+      off +=
+          snprintf(err + off, sizeof(err) - off, "%u%s",
+                   (unsigned)elementStridesInt[i], (i + 1 < rank) ? ", " : "");
+    }
+    off += snprintf(err + off, sizeof(err) - off, "]\n");
+    PyErr_SetString(PyExc_RuntimeError, err);
+
+    goto cleanup;
+  }
+
+  return (PyObject *)desc;
+
+cleanup:
+  Py_XDECREF(blockSizeFast);
+  Py_XDECREF(shapeFast);
+  Py_XDECREF(stridesFast);
+  Py_XDECREF(pixelBoxLowerFast);
+  Py_XDECREF(pixelBoxUpperFast);
+  Py_XDECREF(elementStridesFast);
+  Py_XDECREF(desc);
+  return NULL;
+}
+
 static void ensureCudaContext() {
   CUcontext pctx;
   CUDA_CHECK(cuCtxGetCurrent(&pctx));
@@ -1128,7 +1399,10 @@ static PyMethodDef ModuleMethods[] = {
      "being dropped. This inherits all the limitations of this call; in "
      "particular it's an error to change this value after launching any kernel "
      "that calls printf()."},
-    {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
+    {"fill_tma_descriptor_tiled", fillTMADescriptorTiled, METH_VARARGS,
+     "Create TMA descriptor for tiled mode"},
+    {"fill_tma_descriptor_im2col", fillTMADescriptorIm2col, METH_VARARGS,
+     "Create TMA descriptor for im2col mode"},
     {"build_signature_metadata", buildSignatureMetadata, METH_VARARGS,
      "Calling it with a signature list (ex: ['*fp32', 'u8', 'nvTmaDesc']), "
      "will return metadata to be passed into 'launch' for quicker "

third_party/nvidia/backend/driver.py

Lines changed: 3 additions & 2 deletions
@@ -84,7 +84,8 @@ def __init__(self):
         self.get_device_properties = mod.get_device_properties
         self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
         self.set_printf_fifo_size = mod.set_printf_fifo_size
-        self.fill_tma_descriptor = mod.fill_tma_descriptor
+        self.fill_tma_descriptor_tiled = mod.fill_tma_descriptor_tiled
+        self.fill_tma_descriptor_im2col = mod.fill_tma_descriptor_im2col
         self.launch = mod.launch
         self.build_signature_metadata = mod.build_signature_metadata
 
@@ -231,7 +232,7 @@ def make_tensordesc_arg(arg, metadata):
     else:
         expanded_shape = shape
 
-    cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
+    cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor_tiled(
         arg.base.data_ptr(),
         swizzle,
         elem_size,
