88#include <stdio.h>
99#include <stdlib.h>
1010
11+ // Include shared TDM utilities
12+ #include "TDMCommon.h"
13+
// Host-side TDM descriptor: 20 dwords split into four hardware "groups"
// (group0: 4 dwords, group1: 8 dwords, group2: 4 dwords, group3: 4 dwords).
// See encodeTDMDescriptor for the bit-level layout of each field.
// NOTE(review): fields group0_2..group1_4 were elided by the diff hunk in
// this view and are reconstructed from their uses in encodeTDMDescriptor —
// confirm field order against the original file.
typedef struct {
  uint32_t group0_0;
  uint32_t group0_1;
  uint32_t group0_2;
  uint32_t group0_3;
  uint32_t group1_0;
  uint32_t group1_1;
  uint32_t group1_2;
  uint32_t group1_3;
  uint32_t group1_4;
  uint32_t group1_5;
  uint32_t group1_6;
  uint32_t group1_7;
  uint32_t group2_0;
  uint32_t group2_1;
  uint32_t group2_2;
  uint32_t group2_3;
  uint32_t group3_0;
  uint32_t group3_1;
  uint32_t group3_2;
  uint32_t group3_3;
} TDMDescriptor;
2536
2637typedef struct {
@@ -54,36 +65,39 @@ static PyTypeObject PyTDMDescriptorType = {
5465 .tp_dealloc = (destructor )PyTDMDescriptor_dealloc ,
5566};
5667
57- // TODO: Both host-side and device-side TDM descriptor follow the same encoding
58- // format. Consider to add a common utility to remove duplicate code .
68+ // Encodes a TDM descriptor. Supports 1D-5D tensors.
69+ // Uses the same encoding format as createTDMDescriptor in TDMUtility.cpp .
5970static bool encodeTDMDescriptor (TDMDescriptor * desc , int elementBitWidth ,
6071 uint32_t * blockSize , int numWarps ,
6172 int padInterval , int padAmount , uint32_t * shape ,
6273 uint32_t * strides , uint64_t globalAddress ,
6374 int rank ) {
64- // NYI: TDM > 2D cases
65- if (rank != 2 )
75+ if (rank < 1 || rank > 5 )
6676 return false;
6777
68- // Get warp distribution
69- uint32_t numWarpsDim0 = numWarps ;
70- for (; numWarpsDim0 > blockSize [0 ]; numWarpsDim0 /= 2 )
71- ;
72- uint32_t numWarpsDim1 = numWarps / numWarpsDim0 ;
73- if (!(numWarpsDim0 > 0 && blockSize [1 ] % numWarpsDim1 == 0 ))
74- return false;
78+ memset (desc , 0 , sizeof (TDMDescriptor ));
7579
76- uint32_t blockSize0 = (blockSize [0 ] + numWarpsDim0 - 1 ) / numWarpsDim0 ;
77- uint32_t blockSize1 = (blockSize [1 ] + numWarpsDim1 - 1 ) / numWarpsDim1 ;
80+ // Convert to int64_t for shared function and get adjusted block sizes
81+ int64_t blockShape64 [5 ], adjustedBlockSize64 [5 ];
82+ for (int i = 0 ; i < rank ; ++ i )
83+ blockShape64 [i ] = blockSize [i ];
84+ tdmGetAdjustedBlockShape (blockShape64 , rank , numWarps , adjustedBlockSize64 );
85+
86+ // Convert back to uint32_t
87+ uint32_t adjustedBlockSize [5 ];
88+ for (int i = 0 ; i < rank ; ++ i )
89+ adjustedBlockSize [i ] = (uint32_t )adjustedBlockSize64 [i ];
7890
7991 // group0 (128 bits / 4 dwords) effective bit encoding:
92+ // [1:0]: pred (to be filled later)
93+ // [63:32]: lds address (to be filled later)
8094 // [120:64]: global address
8195 // [127:126]: type - currently always set to 0x2
8296 desc -> group0_2 = (uint32_t )(globalAddress & 0xFFFFFFFF );
83- desc -> group0_3 = (uint32_t )((globalAddress >> 32 ) & 0x01FFFFFF );
84- desc -> group0_3 |= (0x1 << 31 );
97+ desc -> group0_3 = (uint32_t )((globalAddress >> 32 ) & 0x7FFFFFFF ) | (0x1 << 31 );
8598
8699 // group1 (256 bits / 8 dwords) effective bit encoding:
100+ // [15:0]: multicast mask
87101 // [17:16]: data size - log2(element size in bytes)
88102 // [20]: enable padding
89103 // [24:22]: pad interval - log2(pad interval in dwords) - 1
@@ -92,26 +106,72 @@ static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
92106 // [111:80]: tensor shape dim outer
93107 // [127:112]: block shape dim inner
94108 // [143:128]: block shape dim outer
109+ // [159:144]: tile_dim2
95110 // [207:160]: tensor stride dim outer (we only use 32 bits)
111+ // [255:208]: tensor stride dim 2 (48 bits)
96112 int elementSizeInBytes = elementBitWidth / 8 ;
97- int dataSize = log2 (elementSizeInBytes );
98- desc -> group1_0 = (dataSize << 16 );
113+ int dataSize = (int )log2 (elementSizeInBytes );
99114 int dwordSize = 32 ;
100115 int padIntervalInDwords = padInterval * elementBitWidth / dwordSize ;
101116 int padAmountInDwords = padAmount * elementBitWidth / dwordSize ;
117+
118+ desc -> group1_0 = (dataSize << 16 );
102119 if (padIntervalInDwords > 0 && padAmountInDwords > 0 ) {
103- int log2PadInterval = log2 (padIntervalInDwords );
120+ int log2PadInterval = ( int ) log2 (padIntervalInDwords );
104121 desc -> group1_0 |= (1 << 20 );
105122 desc -> group1_0 |= ((log2PadInterval - 1 ) << 22 );
106123 desc -> group1_0 |= ((padAmountInDwords - 1 ) << 25 );
107124 }
108- desc -> group1_1 = (shape [1 ] << 16 );
109- desc -> group1_2 = (shape [1 ] >> 16 );
110- desc -> group1_2 |= (shape [0 ] << 16 );
111- desc -> group1_3 = (shape [0 ] >> 16 );
112- desc -> group1_3 |= (blockSize1 << 16 );
113- desc -> group1_4 = (blockSize0 & 0xFFFF );
114- desc -> group1_5 = strides [0 ];
125+
126+ // Encode tensor shapes (48-bit encoding, indices from end: rank-1 is inner)
127+ desc -> group1_1 = (shape [rank - 1 ] << 16 );
128+ desc -> group1_2 = (shape [rank - 1 ] >> 16 );
129+
130+ if (rank >= 2 ) {
131+ desc -> group1_2 |= (shape [rank - 2 ] << 16 );
132+ desc -> group1_3 = (shape [rank - 2 ] >> 16 );
133+ }
134+
135+ // Block shapes
136+ desc -> group1_3 |= (adjustedBlockSize [rank - 1 ] << 16 );
137+ if (rank >= 2 )
138+ desc -> group1_4 = (adjustedBlockSize [rank - 2 ] & 0xFFFF );
139+ if (rank >= 3 )
140+ desc -> group1_4 |= (adjustedBlockSize [rank - 3 ] << 16 );
141+
142+ // Strides
143+ if (rank >= 2 )
144+ desc -> group1_5 = strides [rank - 2 ];
145+ if (rank >= 3 ) {
146+ desc -> group1_6 = (strides [rank - 3 ] << 16 );
147+ desc -> group1_7 = (strides [rank - 3 ] >> 16 );
148+ }
149+
150+ // group2 (128 bits / 4 dwords) for 3D-5D tensors:
151+ // [31:0]: tensor_dim2 (3rd dimension from end)
152+ // [63:32]: tensor_dim3 (4th dimension from end)
153+ // [111:64]: tensor_dim2_stride (48 bits, we use 32 bits)
154+ // [127:112]: tile_dim3
155+ if (rank >= 3 ) {
156+ desc -> group2_0 = shape [rank - 3 ];
157+ if (rank >= 4 ) {
158+ desc -> group2_1 = shape [rank - 4 ];
159+ desc -> group2_2 = strides [rank - 4 ];
160+ desc -> group2_3 = (adjustedBlockSize [rank - 4 ] << 16 );
161+ }
162+ }
163+
164+ // group3 (128 bits / 4 dwords) for 4D-5D tensors:
165+ // [47:0]: tensor_dim3_stride (48 bits, we use 32 bits)
166+ // [79:48]: tensor_dim4 (5th dimension from end)
167+ // [95:80]: tile_dim4
168+ // [127:96]: reserved
169+ if (rank == 5 ) {
170+ desc -> group3_0 = strides [rank - 5 ];
171+ desc -> group3_1 = (shape [rank - 5 ] << 16 );
172+ desc -> group3_2 = (shape [rank - 5 ] >> 16 );
173+ desc -> group3_2 |= (adjustedBlockSize [rank - 5 ] << 16 );
174+ }
115175
116176 return true;
117177}
@@ -388,16 +448,16 @@ static PyObject *createTDMDescriptor(PyObject *self, PyObject *args) {
388448 PyObject * shapeFast = NULL ;
389449 PyObject * stridesFast = NULL ;
390450
391- uint32_t blockSizeInt [2 ];
392- uint32_t shapeInt [2 ];
393- uint32_t stridesInt [2 ];
451+ uint32_t blockSizeInt [5 ];
452+ uint32_t shapeInt [5 ];
453+ uint32_t stridesInt [5 ];
394454
395455 blockSizeFast = PySequence_Fast (blockSize , "blockSize must be a sequence" );
396456 if (!blockSizeFast )
397457 goto cleanup ;
398458 int rank = PySequence_Fast_GET_SIZE (blockSizeFast );
399- if (rank != 2 ) {
400- PyErr_SetString (PyExc_RuntimeError , "rank must be 2 " );
459+ if (rank == 0 || rank > 5 ) {
460+ PyErr_SetString (PyExc_RuntimeError , "rank must be between 1 and 5 " );
401461 goto cleanup ;
402462 }
403463
0 commit comments