@@ -228,6 +228,15 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
228228 CUtensorMapSwizzle swizzle , CUtensorMapL2promotion l2Promotion ,
229229 CUtensorMapFloatOOBfill oobFill );
230230
// Function-pointer type mirroring the CUDA driver entry point
// cuTensorMapEncodeIm2col (resolved dynamically at runtime via
// getCuTensorMapEncodeIm2colHandle). Encodes an im2col-mode TMA tensor-map
// descriptor; parameter order matches the driver API exactly.
231+ typedef CUresult (* cuTensorMapEncodeIm2col_t )(
232+ CUtensorMap * tensorMap , CUtensorMapDataType tensorDataType ,
233+ cuuint32_t tensorRank , void * globalAddress , const cuuint64_t * globalDim ,
234+ const cuuint64_t * globalStrides , const int * pixelBoxLowerCorner ,
235+ const int * pixelBoxUpperCorner , cuuint32_t channelsPerPixel ,
236+ cuuint32_t pixelsPerColumn , const cuuint32_t * elementStrides ,
237+ CUtensorMapInterleave interleave , CUtensorMapSwizzle swizzle ,
238+ CUtensorMapL2promotion l2Promotion , CUtensorMapFloatOOBfill oobFill );
239+
// Function-pointer type mirroring the CUDA driver entry point cuLaunchKernelEx
// (extended kernel launch taking a CUlaunchConfig), resolved dynamically via
// getLaunchKernelExHandle.
231240typedef CUresult (* cuLaunchKernelEx_t )(const CUlaunchConfig * config ,
232241 CUfunction f , void * * kernelParams ,
233242 void * * extra );
@@ -260,6 +269,9 @@ defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
260269defineGetFunctionHandle (getCuTensorMapEncodeTiledHandle ,
261270 cuTensorMapEncodeTiled );
262271
// Defines getCuTensorMapEncodeIm2colHandle(), which resolves the driver's
// cuTensorMapEncodeIm2col symbol (presumably via the defineGetFunctionHandle
// lookup mechanism declared earlier in this file — macro body not shown here).
272+ defineGetFunctionHandle (getCuTensorMapEncodeIm2colHandle ,
273+ cuTensorMapEncodeIm2col );
274+
263275defineGetFunctionHandle (getLaunchKernelExHandle , cuLaunchKernelEx );
264276
265277static PyObject * occupancyMaxActiveClusters (PyObject * self , PyObject * args ) {
@@ -386,7 +398,7 @@ static PyTypeObject PyCUtensorMapType = {
386398};
387399// clang-format on
388400
389- static PyObject * fillTMADescriptor (PyObject * self , PyObject * args ) {
401+ static PyObject * fillTMADescriptorTiled (PyObject * self , PyObject * args ) {
390402 unsigned long long global_address ;
391403 int swizzle ;
392404 int elemSize ;
@@ -555,6 +567,265 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
555567 return NULL ;
556568}
557569
570+ static PyObject * fillTMADescriptorIm2col (PyObject * self , PyObject * args ) {
571+ unsigned long long global_address ;
572+ int swizzle ;
573+ int elemSize ;
574+ int elemType ;
575+ PyObject * blockSize ;
576+ PyObject * shape ;
577+ PyObject * strides ;
578+ int padding ;
579+ PyObject * pixelBoxLower ;
580+ PyObject * pixelBoxUpper ;
581+ int channelsPerPixel ;
582+ int pixelsPerColumn ;
583+ PyObject * elementStrides ;
584+
585+ if (!PyArg_ParseTuple (args , "KiiiOOOiOOiiO" , & global_address , & swizzle ,
586+ & elemSize , & elemType , & blockSize , & shape , & strides ,
587+ & padding , & pixelBoxLower , & pixelBoxUpper ,
588+ & channelsPerPixel , & pixelsPerColumn , & elementStrides )) {
589+ return NULL ;
590+ }
591+
592+ PyCUtensorMapObject * desc = (PyCUtensorMapObject * )PyObject_CallObject (
593+ (PyObject * )& PyCUtensorMapType , NULL );
594+ if (!desc ) {
595+ return NULL ;
596+ }
597+
598+ PyObject * blockSizeFast = NULL ;
599+ PyObject * shapeFast = NULL ;
600+ PyObject * stridesFast = NULL ;
601+ PyObject * pixelBoxLowerFast = NULL ;
602+ PyObject * pixelBoxUpperFast = NULL ;
603+ PyObject * elementStridesFast = NULL ;
604+
605+ uint32_t blockSizeInt [5 ];
606+ uint64_t shapeInt [5 ];
607+ uint64_t stridesLL [5 ];
608+ int pixelBoxLowerInt [5 ] = {0 };
609+ int pixelBoxUpperInt [5 ] = {0 };
610+ uint32_t elementStridesInt [5 ] = {1 , 1 , 1 , 1 , 1 }; // Default to all 1s
611+
612+ // For im2col mode, shape determines the tensor rank, not blockSize
613+ // blockSize is typically 2D [pixelsPerColumn, channelsPerPixel]
614+ // while shape can be 4D or 5D (e.g., NHWC or NDHWC)
615+ shapeFast = PySequence_Fast (shape , "shape must be a sequence" );
616+ if (!shapeFast )
617+ goto cleanup ;
618+ int rank = PySequence_Fast_GET_SIZE (shapeFast );
619+
620+ for (int i = 0 ; i < rank ; ++ i ) {
621+ PyObject * item = PySequence_Fast_GET_ITEM (shapeFast , i );
622+ if (!PyLong_Check (item )) {
623+ PyErr_SetString (PyExc_TypeError , "shape must be an int" );
624+ goto cleanup ;
625+ }
626+ shapeInt [rank - i - 1 ] = PyLong_AsLong (item );
627+ }
628+
629+ blockSizeFast = PySequence_Fast (blockSize , "blockSize must be a sequence" );
630+ if (!blockSizeFast )
631+ goto cleanup ;
632+ int blockRank = PySequence_Fast_GET_SIZE (blockSizeFast );
633+
634+ for (int i = 0 ; i < blockRank ; ++ i ) {
635+ PyObject * item = PySequence_Fast_GET_ITEM (blockSizeFast , i );
636+ if (!PyLong_Check (item )) {
637+ PyErr_SetString (PyExc_TypeError , "block size must be an int" );
638+ goto cleanup ;
639+ }
640+ blockSizeInt [blockRank - i - 1 ] = PyLong_AsLongLong (item );
641+ }
642+
643+ stridesFast = PySequence_Fast (strides , "strides must be a sequence" );
644+ if (!stridesFast )
645+ goto cleanup ;
646+
647+ if (rank != PySequence_Fast_GET_SIZE (stridesFast )) {
648+ PyErr_Format (PyExc_RuntimeError ,
649+ "Rank mismatch for strides in fillTMADescriptorIm2col: shape "
650+ "has rank %d but strides has %zd elements. "
651+ "Expected strides to have %d elements." ,
652+ rank , PySequence_Fast_GET_SIZE (stridesFast ), rank );
653+ goto cleanup ;
654+ }
655+ for (int i = 0 ; i + 1 < rank ; ++ i ) {
656+ PyObject * item = PySequence_Fast_GET_ITEM (stridesFast , i );
657+ if (!PyLong_Check (item )) {
658+ PyErr_SetString (PyExc_TypeError , "strides must be an int" );
659+ goto cleanup ;
660+ }
661+ stridesLL [rank - i - 2 ] = elemSize * PyLong_AsLongLong (item );
662+ }
663+ stridesLL [rank - 1 ] =
664+ shapeInt [rank - 1 ] * (rank == 1 ? elemSize : stridesLL [rank - 2 ]);
665+
666+ // Parse pixel box lower corner
667+ pixelBoxLowerFast =
668+ PySequence_Fast (pixelBoxLower , "pixelBoxLower must be a sequence" );
669+ if (!pixelBoxLowerFast )
670+ goto cleanup ;
671+
672+ int spatialRank = PySequence_Fast_GET_SIZE (pixelBoxLowerFast );
673+ if (spatialRank > 5 ) {
674+ PyErr_SetString (PyExc_RuntimeError , "Pixel box rank too large (max 5)" );
675+ goto cleanup ;
676+ }
677+
678+ for (int i = 0 ; i < spatialRank ; ++ i ) {
679+ PyObject * item = PySequence_Fast_GET_ITEM (pixelBoxLowerFast , i );
680+ if (!PyLong_Check (item )) {
681+ PyErr_SetString (PyExc_TypeError , "pixelBoxLower elements must be int" );
682+ goto cleanup ;
683+ }
684+ pixelBoxLowerInt [spatialRank - i - 1 ] = PyLong_AsLong (item );
685+ }
686+
687+ // Parse pixel box upper corner
688+ pixelBoxUpperFast =
689+ PySequence_Fast (pixelBoxUpper , "pixelBoxUpper must be a sequence" );
690+ if (!pixelBoxUpperFast )
691+ goto cleanup ;
692+
693+ if (spatialRank != PySequence_Fast_GET_SIZE (pixelBoxUpperFast )) {
694+ PyErr_SetString (PyExc_RuntimeError , "Pixel box corner rank mismatch" );
695+ goto cleanup ;
696+ }
697+
698+ for (int i = 0 ; i < spatialRank ; ++ i ) {
699+ PyObject * item = PySequence_Fast_GET_ITEM (pixelBoxUpperFast , i );
700+ if (!PyLong_Check (item )) {
701+ PyErr_SetString (PyExc_TypeError , "pixelBoxUpper elements must be int" );
702+ goto cleanup ;
703+ }
704+ pixelBoxUpperInt [spatialRank - i - 1 ] = PyLong_AsLong (item );
705+ }
706+
707+ // Parse element strides
708+ elementStridesFast =
709+ PySequence_Fast (elementStrides , "elementStrides must be a sequence" );
710+ if (!elementStridesFast )
711+ goto cleanup ;
712+
713+ int elementStridesLen = PySequence_Fast_GET_SIZE (elementStridesFast );
714+ if (elementStridesLen != rank ) {
715+ PyErr_SetString (PyExc_RuntimeError ,
716+ "elementStrides length must match tensor rank" );
717+ goto cleanup ;
718+ }
719+
720+ for (int i = 0 ; i < rank ; ++ i ) {
721+ PyObject * item = PySequence_Fast_GET_ITEM (elementStridesFast , i );
722+ if (!PyLong_Check (item )) {
723+ PyErr_SetString (PyExc_TypeError , "elementStrides elements must be int" );
724+ goto cleanup ;
725+ }
726+ elementStridesInt [rank - i - 1 ] = PyLong_AsLong (item );
727+ }
728+
729+ Py_DECREF (blockSizeFast );
730+ blockSizeFast = NULL ;
731+ Py_DECREF (shapeFast );
732+ shapeFast = NULL ;
733+ Py_DECREF (stridesFast );
734+ stridesFast = NULL ;
735+ Py_DECREF (pixelBoxLowerFast );
736+ pixelBoxLowerFast = NULL ;
737+ Py_DECREF (pixelBoxUpperFast );
738+ pixelBoxUpperFast = NULL ;
739+ Py_DECREF (elementStridesFast );
740+ elementStridesFast = NULL ;
741+
742+ CUtensorMapFloatOOBfill fill =
743+ (padding == 1 ) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
744+ : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE ;
745+
746+ static cuTensorMapEncodeIm2col_t cuTensorMapEncodeIm2col = NULL ;
747+ INITIALIZE_FUNCTION_POINTER_IF_NULL (cuTensorMapEncodeIm2col ,
748+ getCuTensorMapEncodeIm2colHandle );
749+
750+ CUresult res = cuTensorMapEncodeIm2col (
751+ & desc -> tensorMap , elemType , rank , (void * )global_address , shapeInt ,
752+ stridesLL , pixelBoxLowerInt , pixelBoxUpperInt , channelsPerPixel ,
753+ pixelsPerColumn , elementStridesInt , CU_TENSOR_MAP_INTERLEAVE_NONE ,
754+ swizzle , CU_TENSOR_MAP_L2_PROMOTION_L2_128B , fill );
755+
756+ if (res != CUDA_SUCCESS ) {
757+ const char * str ;
758+ cuGetErrorString (res , & str );
759+ char err [4096 ] = {0 };
760+ size_t off = 0 ;
761+ off += snprintf (err + off , sizeof (err ) - off ,
762+ "Triton Error [CUDA]: Failed to create im2col tensor map "
763+ "descriptor: %s\n" ,
764+ str ? str : "Unknown error" );
765+ off +=
766+ snprintf (err + off , sizeof (err ) - off ,
767+ "elemType=%d rank=%d global_address=0x%llx elemSize=%d "
768+ "swizzle=%d padding=%d channelsPerPixel=%d "
769+ "pixelsPerColumn=%d\n" ,
770+ elemType , rank , (unsigned long long )global_address , elemSize ,
771+ swizzle , padding , channelsPerPixel , pixelsPerColumn );
772+ off += snprintf (err + off , sizeof (err ) - off , "shape=[" );
773+ for (int i = 0 ; i < rank ; ++ i ) {
774+ off +=
775+ snprintf (err + off , sizeof (err ) - off , "%llu%s" ,
776+ (unsigned long long )shapeInt [i ], (i + 1 < rank ) ? ", " : "" );
777+ }
778+ off += snprintf (err + off , sizeof (err ) - off , "]\n" );
779+ off += snprintf (err + off , sizeof (err ) - off , "strides=[" );
780+ for (int i = 0 ; i < rank ; ++ i ) {
781+ off += snprintf (err + off , sizeof (err ) - off , "%llu%s" ,
782+ (unsigned long long )stridesLL [i ],
783+ (i + 1 < rank ) ? ", " : "" );
784+ }
785+ off += snprintf (err + off , sizeof (err ) - off , "]\n" );
786+ off += snprintf (err + off , sizeof (err ) - off , "blockSize=[" );
787+ for (int i = 0 ; i < blockRank ; ++ i ) {
788+ off +=
789+ snprintf (err + off , sizeof (err ) - off , "%u%s" ,
790+ (unsigned )blockSizeInt [i ], (i + 1 < blockRank ) ? ", " : "" );
791+ }
792+ off += snprintf (err + off , sizeof (err ) - off , "]\n" );
793+ off += snprintf (err + off , sizeof (err ) - off , "pixelBoxLower=[" );
794+ for (int i = 0 ; i < spatialRank ; ++ i ) {
795+ off += snprintf (err + off , sizeof (err ) - off , "%d%s" , pixelBoxLowerInt [i ],
796+ (i + 1 < spatialRank ) ? ", " : "" );
797+ }
798+ off += snprintf (err + off , sizeof (err ) - off , "] pixelBoxUpper=[" );
799+ for (int i = 0 ; i < spatialRank ; ++ i ) {
800+ off += snprintf (err + off , sizeof (err ) - off , "%d%s" , pixelBoxUpperInt [i ],
801+ (i + 1 < spatialRank ) ? ", " : "" );
802+ }
803+ off += snprintf (err + off , sizeof (err ) - off , "]\n" );
804+ off += snprintf (err + off , sizeof (err ) - off , "elementStrides=[" );
805+ for (int i = 0 ; i < rank ; ++ i ) {
806+ off +=
807+ snprintf (err + off , sizeof (err ) - off , "%u%s" ,
808+ (unsigned )elementStridesInt [i ], (i + 1 < rank ) ? ", " : "" );
809+ }
810+ off += snprintf (err + off , sizeof (err ) - off , "]\n" );
811+ PyErr_SetString (PyExc_RuntimeError , err );
812+
813+ goto cleanup ;
814+ }
815+
816+ return (PyObject * )desc ;
817+
818+ cleanup :
819+ Py_XDECREF (blockSizeFast );
820+ Py_XDECREF (shapeFast );
821+ Py_XDECREF (stridesFast );
822+ Py_XDECREF (pixelBoxLowerFast );
823+ Py_XDECREF (pixelBoxUpperFast );
824+ Py_XDECREF (elementStridesFast );
825+ Py_XDECREF (desc );
826+ return NULL ;
827+ }
828+
558829static void ensureCudaContext () {
559830 CUcontext pctx ;
560831 CUDA_CHECK (cuCtxGetCurrent (& pctx ));
@@ -1128,7 +1399,10 @@ static PyMethodDef ModuleMethods[] = {
11281399 "being dropped. This inherits all the limitations of this call; in "
11291400 "particular it's an error to change this value after launching any kernel "
11301401 "that calls printf()." },
1131- {"fill_tma_descriptor" , fillTMADescriptor , METH_VARARGS , "doc" },
1402+ {"fill_tma_descriptor_tiled" , fillTMADescriptorTiled , METH_VARARGS ,
1403+ "Create TMA descriptor for tiled mode" },
1404+ {"fill_tma_descriptor_im2col" , fillTMADescriptorIm2col , METH_VARARGS ,
1405+ "Create TMA descriptor for im2col mode" },
11321406 {"build_signature_metadata" , buildSignatureMetadata , METH_VARARGS ,
11331407 "Calling it with a signature list (ex: ['*fp32', 'u8', 'nvTmaDesc']), "
11341408 "will return metadata to be passed into 'launch' for quicker "
0 commit comments