Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 276 additions & 2 deletions third_party/nvidia/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,15 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill);

typedef CUresult (*cuTensorMapEncodeIm2col_t)(
CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
const cuuint64_t *globalStrides, const int *pixelBoxLowerCorner,
const int *pixelBoxUpperCorner, cuuint32_t channelsPerPixel,
cuuint32_t pixelsPerColumn, const cuuint32_t *elementStrides,
CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig *config,
CUfunction f, void **kernelParams,
void **extra);
Expand Down Expand Up @@ -260,6 +269,9 @@ defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
cuTensorMapEncodeTiled);

defineGetFunctionHandle(getCuTensorMapEncodeIm2colHandle,
cuTensorMapEncodeIm2col);

defineGetFunctionHandle(getLaunchKernelExHandle, cuLaunchKernelEx);

static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
Expand Down Expand Up @@ -386,7 +398,7 @@ static PyTypeObject PyCUtensorMapType = {
};
// clang-format on

static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
static PyObject *fillTMADescriptorTiled(PyObject *self, PyObject *args) {
unsigned long long global_address;
int swizzle;
int elemSize;
Expand Down Expand Up @@ -555,6 +567,265 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
return NULL;
}

static PyObject *fillTMADescriptorIm2col(PyObject *self, PyObject *args) {
unsigned long long global_address;
int swizzle;
int elemSize;
int elemType;
PyObject *blockSize;
PyObject *shape;
PyObject *strides;
int padding;
PyObject *pixelBoxLower;
PyObject *pixelBoxUpper;
int channelsPerPixel;
int pixelsPerColumn;
PyObject *elementStrides;

if (!PyArg_ParseTuple(args, "KiiiOOOiOOiiO", &global_address, &swizzle,
&elemSize, &elemType, &blockSize, &shape, &strides,
&padding, &pixelBoxLower, &pixelBoxUpper,
&channelsPerPixel, &pixelsPerColumn, &elementStrides)) {
return NULL;
}

PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
(PyObject *)&PyCUtensorMapType, NULL);
if (!desc) {
return NULL;
}

PyObject *blockSizeFast = NULL;
PyObject *shapeFast = NULL;
PyObject *stridesFast = NULL;
PyObject *pixelBoxLowerFast = NULL;
PyObject *pixelBoxUpperFast = NULL;
PyObject *elementStridesFast = NULL;

uint32_t blockSizeInt[5];
uint64_t shapeInt[5];
uint64_t stridesLL[5];
int pixelBoxLowerInt[5] = {0};
int pixelBoxUpperInt[5] = {0};
uint32_t elementStridesInt[5] = {1, 1, 1, 1, 1}; // Default to all 1s

// For im2col mode, shape determines the tensor rank, not blockSize
// blockSize is typically 2D [pixelsPerColumn, channelsPerPixel]
// while shape can be 4D or 5D (e.g., NHWC or NDHWC)
shapeFast = PySequence_Fast(shape, "shape must be a sequence");
if (!shapeFast)
goto cleanup;
int rank = PySequence_Fast_GET_SIZE(shapeFast);

for (int i = 0; i < rank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "shape must be an int");
goto cleanup;
}
shapeInt[rank - i - 1] = PyLong_AsLong(item);
}

blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
if (!blockSizeFast)
goto cleanup;
int blockRank = PySequence_Fast_GET_SIZE(blockSizeFast);

for (int i = 0; i < blockRank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "block size must be an int");
goto cleanup;
}
blockSizeInt[blockRank - i - 1] = PyLong_AsLongLong(item);
}

stridesFast = PySequence_Fast(strides, "strides must be a sequence");
if (!stridesFast)
goto cleanup;

if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
PyErr_Format(PyExc_RuntimeError,
"Rank mismatch for strides in fillTMADescriptorIm2col: shape "
"has rank %d but strides has %zd elements. "
"Expected strides to have %d elements.",
rank, PySequence_Fast_GET_SIZE(stridesFast), rank);
goto cleanup;
}
for (int i = 0; i + 1 < rank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "strides must be an int");
goto cleanup;
}
stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
}
stridesLL[rank - 1] =
shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);

// Parse pixel box lower corner
pixelBoxLowerFast =
PySequence_Fast(pixelBoxLower, "pixelBoxLower must be a sequence");
if (!pixelBoxLowerFast)
goto cleanup;

int spatialRank = PySequence_Fast_GET_SIZE(pixelBoxLowerFast);
if (spatialRank > 5) {
PyErr_SetString(PyExc_RuntimeError, "Pixel box rank too large (max 5)");
goto cleanup;
}

for (int i = 0; i < spatialRank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(pixelBoxLowerFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "pixelBoxLower elements must be int");
goto cleanup;
}
pixelBoxLowerInt[spatialRank - i - 1] = PyLong_AsLong(item);
}

// Parse pixel box upper corner
pixelBoxUpperFast =
PySequence_Fast(pixelBoxUpper, "pixelBoxUpper must be a sequence");
if (!pixelBoxUpperFast)
goto cleanup;

if (spatialRank != PySequence_Fast_GET_SIZE(pixelBoxUpperFast)) {
PyErr_SetString(PyExc_RuntimeError, "Pixel box corner rank mismatch");
goto cleanup;
}

for (int i = 0; i < spatialRank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(pixelBoxUpperFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "pixelBoxUpper elements must be int");
goto cleanup;
}
pixelBoxUpperInt[spatialRank - i - 1] = PyLong_AsLong(item);
}

// Parse element strides
elementStridesFast =
PySequence_Fast(elementStrides, "elementStrides must be a sequence");
if (!elementStridesFast)
goto cleanup;

int elementStridesLen = PySequence_Fast_GET_SIZE(elementStridesFast);
if (elementStridesLen != rank) {
PyErr_SetString(PyExc_RuntimeError,
"elementStrides length must match tensor rank");
goto cleanup;
}

for (int i = 0; i < rank; ++i) {
PyObject *item = PySequence_Fast_GET_ITEM(elementStridesFast, i);
if (!PyLong_Check(item)) {
PyErr_SetString(PyExc_TypeError, "elementStrides elements must be int");
goto cleanup;
}
elementStridesInt[rank - i - 1] = PyLong_AsLong(item);
}

Py_DECREF(blockSizeFast);
blockSizeFast = NULL;
Py_DECREF(shapeFast);
shapeFast = NULL;
Py_DECREF(stridesFast);
stridesFast = NULL;
Py_DECREF(pixelBoxLowerFast);
pixelBoxLowerFast = NULL;
Py_DECREF(pixelBoxUpperFast);
pixelBoxUpperFast = NULL;
Py_DECREF(elementStridesFast);
elementStridesFast = NULL;

CUtensorMapFloatOOBfill fill =
(padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
: CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

static cuTensorMapEncodeIm2col_t cuTensorMapEncodeIm2col = NULL;
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeIm2col,
getCuTensorMapEncodeIm2colHandle);

CUresult res = cuTensorMapEncodeIm2col(
&desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
stridesLL, pixelBoxLowerInt, pixelBoxUpperInt, channelsPerPixel,
pixelsPerColumn, elementStridesInt, CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);

if (res != CUDA_SUCCESS) {
const char *str;
cuGetErrorString(res, &str);
char err[4096] = {0};
size_t off = 0;
off += snprintf(err + off, sizeof(err) - off,
"Triton Error [CUDA]: Failed to create im2col tensor map "
"descriptor: %s\n",
str ? str : "Unknown error");
off +=
snprintf(err + off, sizeof(err) - off,
"elemType=%d rank=%d global_address=0x%llx elemSize=%d "
"swizzle=%d padding=%d channelsPerPixel=%d "
"pixelsPerColumn=%d\n",
elemType, rank, (unsigned long long)global_address, elemSize,
swizzle, padding, channelsPerPixel, pixelsPerColumn);
off += snprintf(err + off, sizeof(err) - off, "shape=[");
for (int i = 0; i < rank; ++i) {
off +=
snprintf(err + off, sizeof(err) - off, "%llu%s",
(unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "]\n");
off += snprintf(err + off, sizeof(err) - off, "strides=[");
for (int i = 0; i < rank; ++i) {
off += snprintf(err + off, sizeof(err) - off, "%llu%s",
(unsigned long long)stridesLL[i],
(i + 1 < rank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "]\n");
off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
for (int i = 0; i < blockRank; ++i) {
off +=
snprintf(err + off, sizeof(err) - off, "%u%s",
(unsigned)blockSizeInt[i], (i + 1 < blockRank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "]\n");
off += snprintf(err + off, sizeof(err) - off, "pixelBoxLower=[");
for (int i = 0; i < spatialRank; ++i) {
off += snprintf(err + off, sizeof(err) - off, "%d%s", pixelBoxLowerInt[i],
(i + 1 < spatialRank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "] pixelBoxUpper=[");
for (int i = 0; i < spatialRank; ++i) {
off += snprintf(err + off, sizeof(err) - off, "%d%s", pixelBoxUpperInt[i],
(i + 1 < spatialRank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "]\n");
off += snprintf(err + off, sizeof(err) - off, "elementStrides=[");
for (int i = 0; i < rank; ++i) {
off +=
snprintf(err + off, sizeof(err) - off, "%u%s",
(unsigned)elementStridesInt[i], (i + 1 < rank) ? ", " : "");
}
off += snprintf(err + off, sizeof(err) - off, "]\n");
PyErr_SetString(PyExc_RuntimeError, err);

goto cleanup;
}

return (PyObject *)desc;

cleanup:
Py_XDECREF(blockSizeFast);
Py_XDECREF(shapeFast);
Py_XDECREF(stridesFast);
Py_XDECREF(pixelBoxLowerFast);
Py_XDECREF(pixelBoxUpperFast);
Py_XDECREF(elementStridesFast);
Py_XDECREF(desc);
return NULL;
}

static void ensureCudaContext() {
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
Expand Down Expand Up @@ -1128,7 +1399,10 @@ static PyMethodDef ModuleMethods[] = {
"being dropped. This inherits all the limitations of this call; in "
"particular it's an error to change this value after launching any kernel "
"that calls printf()."},
{"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
{"fill_tma_descriptor_tiled", fillTMADescriptorTiled, METH_VARARGS,
"Create TMA descriptor for tiled mode"},
{"fill_tma_descriptor_im2col", fillTMADescriptorIm2col, METH_VARARGS,
"Create TMA descriptor for im2col mode"},
{"build_signature_metadata", buildSignatureMetadata, METH_VARARGS,
"Calling it with a signature list (ex: ['*fp32', 'u8', 'nvTmaDesc']), "
"will return metadata to be passed into 'launch' for quicker "
Expand Down
5 changes: 3 additions & 2 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __init__(self):
self.get_device_properties = mod.get_device_properties
self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
self.set_printf_fifo_size = mod.set_printf_fifo_size
self.fill_tma_descriptor = mod.fill_tma_descriptor
self.fill_tma_descriptor_tiled = mod.fill_tma_descriptor_tiled
self.fill_tma_descriptor_im2col = mod.fill_tma_descriptor_im2col
self.launch = mod.launch
self.build_signature_metadata = mod.build_signature_metadata

Expand Down Expand Up @@ -231,7 +232,7 @@ def make_tensordesc_arg(arg, metadata):
else:
expanded_shape = shape

cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor(
cu_tensor_map = triton.runtime.driver.active.utils.fill_tma_descriptor_tiled(
arg.base.data_ptr(),
swizzle,
elem_size,
Expand Down
Loading