@@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16(
2205
2205
}
2206
2206
}
2207
2207
2208
- template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat (char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
2209
- {
2210
-
2211
- // 0. Load data into 32*32 shared memory tiles
2212
- // 1. transpose / reorder in shared memory
2213
- // 2. store
2214
-
2215
- // COL32 FORMAT:
2216
- // rows*32 tiles
2217
-
2218
- // TURING FORMAT:
2219
- // 8*32 tiles with 4*4 subtiles
2220
- // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
2221
- // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
2222
- // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
2223
- // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
2224
- // index increases by 32
2225
-
2226
- // AMPERE FORMAT:
2227
- // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
2228
- // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
2229
- // the tiles are column-major ordered, so after 32*32 = 1024 values we process: A[32:64, 0:32]
2230
-
2231
-
2232
- // To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
2233
- // As such we need:
2234
- // at least 32*4 shared memory tiles for col32; preferably 32*32
2235
- // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
2236
- // at least 32*8 shared memory tiles for col4_turing: preferably 32*32
2237
- // for efficient loading of row major we need to load 128 elements and repeat this 32 times
2238
- // this would imply a 32x128 shared memory tile -> 4kb
2239
- // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
2240
- // we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
2241
- // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
2242
- // register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
2243
- //
2244
- // to make the shared memory work with that occupancy we might need to union the block loads/stores
2245
-
2246
- // each block loads TILE_COLS columns and TILE_ROWS rows
2247
- // after reading a tile the row counter increase by TILE_ROWS
2248
- // the col counter reset after reading TILE_COL elements
2249
- const int base_row = ((blockIdx .x *TILE_COLS)/tiledCols)*TILE_ROWS;
2250
- // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
2251
- const int base_col = (blockIdx .x *TILE_COLS) % tiledCols;
2252
- const int base_idx = (base_row*cols) + base_col;
2253
-
2254
- // we load 128 bytes per warp with
2255
- // 32 rows for transposes that fill col32 types
2256
- // so that we can have contiguous stores
2257
- __shared__ char smem_data[32 *33 *ITEMS_PER_THREAD];
2258
- char local_data[ITEMS_PER_THREAD];
2259
- typedef cub::BlockExchange<char , THREADS, ITEMS_PER_THREAD> BlockExchange;
2260
-
2261
- // we load row after row from the base_position
2262
- // Load data row by row
2263
- int warps = blockDim .x /32 ;
2264
- int warp_id = threadIdx .x /32 ;
2265
- int warp_lane = threadIdx .x % 32 ;
2266
- int offset = 0 ;
2267
-
2268
- int smem_row = 0 ;
2269
- // each warp loads one row of 128 bytes
2270
- for (int row = warp_id; row < TILE_ROWS; row+=warps)
2271
- {
2272
- int i = base_idx + (row*cols);
2273
- // we load up to 128 bytes/items per load
2274
- int valid_items = cols - base_col > 32 *ITEMS_PER_THREAD ? 32 *ITEMS_PER_THREAD : cols - base_col;
2275
-
2276
- // 0. Load data into 32*32 shared memory tiles
2277
- if (base_row + row < rows)
2278
- {
2279
- #pragma unroll ITEMS_PER_THREAD
2280
- for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2281
- {
2282
- int col_idx = warp_lane+(j*32 );
2283
- if (col_idx < valid_items)
2284
- local_data[j] = A[i+col_idx];
2285
- else
2286
- local_data[j] = 0 ;
2287
- }
2288
- }
2289
- else
2290
- {
2291
- #pragma unroll ITEMS_PER_THREAD
2292
- for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2293
- local_data[j] = 0 ;
2294
- }
2295
-
2296
- if (TRANSPOSE)
2297
- {
2298
- #pragma unroll ITEMS_PER_THREAD
2299
- for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2300
- {
2301
- int local_col = (32 *j)+warp_lane;
2302
- // int local_row = row;
2303
- // store as 256x32
2304
- smem_data[(local_col*33 ) + row] = local_data[j];
2305
- }
2306
- }
2307
- else
2308
- {
2309
- // treat smem as 32x256, that is 32 rows and 256 columns
2310
- #pragma unroll ITEMS_PER_THREAD
2311
- for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2312
- smem_data[row*32 *ITEMS_PER_THREAD + (warp_lane) + (j*32 )] = local_data[j];
2313
- }
2314
-
2315
-
2316
-
2317
- smem_row += warps;
2318
-
2319
- // 1. transpose / reorder in shared memory
2320
- if (smem_row % 32 == 0 )
2321
- {
2322
- smem_row = 0 ;
2323
- __syncthreads ();
2324
-
2325
- for (int subrow = warp_id; subrow < 32 ; subrow+=warps)
2326
- {
2327
- for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2328
- {
2329
-
2330
- switch (FORMAT)
2331
- {
2332
- case COL32:
2333
- if (TRANSPOSE)
2334
- {
2335
- // data lies in shared memory in the following way:
2336
- // row0 [col0 col1 ... col31]
2337
- // row1 [col0 col1 ... col31]
2338
- // ...
2339
- //
2340
- // As such we read consecutive entries with 256 threads (8rows x 32 columns)
2341
- // as j increase, the row increase by a factor of 8
2342
- // We load 8 rows per subrow loop, and subrow increase by 8 per loop
2343
- // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
2344
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
2345
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
2346
- // const int local_row = warp_id; // each warp_id is one row
2347
- // const int block_row = base_col; // block offset for row
2348
- // const int local_col = warp_lane
2349
- // const int global_col = base_row; // block offset for col
2350
- if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
2351
- {
2352
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
2353
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
2354
-
2355
- // each 32 columns we have new tile
2356
- // each tile has size outRows*32 and base_row is done in increments of 32
2357
- offset = base_row*outRows;
2358
- out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx .x ] = data;
2359
- }
2360
- }
2361
- else
2362
- {
2363
- if (((base_row+subrow) < rows) && (base_col+(j*32 )+warp_lane < outCols))
2364
- {
2365
- offset = (base_col/32 )*(32 *rows);
2366
- char data = smem_data[(subrow*32 *ITEMS_PER_THREAD) + (j*32 ) + warp_lane];
2367
- out[offset+(base_row+subrow)*32 + ((j)*rows*32 )+warp_lane] = data;
2368
- }
2369
- }
2370
- break ;
2371
- case COL_TURING:
2372
- // TURING FORMAT:
2373
- // 8*32 tiles with 4*4 subtiles
2374
- // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
2375
- // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
2376
- // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
2377
- // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
2378
- // index increases by 32
2379
- //
2380
- // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
2381
- if (TRANSPOSE)
2382
- {
2383
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
2384
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
2385
- // const int local_row = warp_id; // each warp_id is one row
2386
- // const int block_row = base_col; // block offset for row
2387
- // const int local_col = warp_lane
2388
- // const int global_col = base_row; // block offset for col
2389
- if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
2390
- {
2391
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
2392
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
2393
-
2394
- // each 32 columns we have new tile
2395
- // each tile has size 8*32 = 256 elements offset
2396
- // for each row offset of 8 we increase the tile first
2397
- // after all rows are exhausted, we increase the col
2398
- int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8 )*256 ; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
2399
-
2400
- // we increase by row_tile_column every 32 columns
2401
- // base_row increase in increments of 32
2402
- // int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
2403
- // int col_offset = (base_row/32)*row_tile_column;
2404
- // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
2405
- // 256*outRows/8*base_row/32 = outRows*base_row
2406
- int col_offset = outRows*base_row;
2407
-
2408
- offset = row_offset+col_offset;
2409
-
2410
- // since we process even number of rows with each j (8) and with each subrow (8j) we can determine
2411
- // odd or even rows with the warp_id (each warp processes one row)
2412
- // the col is warp_lane (max 32 columns per row) and the row warp_id
2413
- if (warp_id % 2 == 1 )
2414
- // odd
2415
- offset += 128 + (warp_lane/4 )*16 + (warp_lane%4 ) + (((warp_id%8 )-1 )*2 );
2416
- else
2417
- // even
2418
- offset += 0 + (warp_lane/4 )*16 + (warp_lane%4 ) + ((warp_id%8 )*2 );
2419
-
2420
- out[offset] = data;
2421
- }
2422
- }
2423
- else
2424
- {
2425
- if (((base_row+subrow) < rows) && (base_col+(j*32 )+warp_lane < outCols))
2426
- {
2427
- char data = smem_data[(subrow*32 *ITEMS_PER_THREAD) + (j*32 ) + warp_lane];
2428
- // set offset designates the tile offset among the 8*32 tiles
2429
- // we first increase rows and then columns. Since we load 128 columns at once
2430
- // we increase the offset by outRows*32 every 32 columns
2431
- // additionally, we increase the offset by 8*32=256 every 8 rows
2432
- offset = ((base_col+(j*32 ))/32 )*outRows*32 + (((base_row+subrow)/8 )*256 ); // global offset (8x32 tile)
2433
- // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
2434
- // each of these has 32 values in total for 32*4 = 128 as offset if odd
2435
- // every set of 4 columns increases the total offset by 16
2436
- // each even row increases the offset by 4, for example row 2 is offset by 4, row 4 by 8, etc., so: subrow/2*4 = subrow*2
2437
- // this happens every 8 rows anew (subrow % 8)
2438
- // one writes 4 columns at once that is (col % 4) for the particular index in the subtile
2439
- int subcol = warp_lane;
2440
-
2441
- // add local offset (4x4 sub-tile)
2442
- if (subrow % 2 == 1 )
2443
- // odd
2444
- offset += 128 + (subcol/4 )*16 + (subcol%4 ) + (((subrow%8 )-1 )*2 );
2445
- else
2446
- // even
2447
- offset += 0 + (subcol/4 )*16 + (subcol%4 ) + ((subrow%8 )*2 );
2448
-
2449
- out[offset] = data;
2450
- }
2451
- }
2452
- break ;
2453
- case COL_AMPERE:
2454
- // AMPERE FORMAT:
2455
- // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
2456
- // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
2457
- // the tiles are column-major ordered, so after 32*32 = 1024 values we process: A[32:64, 0:32]
2458
- if (TRANSPOSE)
2459
- {
2460
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
2461
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
2462
- // const int local_row = warp_id; // each warp_id is one row
2463
- // const int block_row = base_col; // block offset for row
2464
- // const int local_col = warp_lane
2465
- // const int global_col = base_row; // block offset for col
2466
- if ((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
2467
- {
2468
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
2469
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
2470
-
2471
- // each 32 columns we have new tile
2472
- // each tile has size 32*32 = 1024 elements offset
2473
- // for each row offset of 32 we increase the tile first
2474
- // after all rows are exhausted, we increase the col
2475
- int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32 )*1024 ; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
2476
-
2477
- // we increase by row_tile_column every 32 columns
2478
- // base_row increase in increments of 32
2479
- // int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
2480
- // int col_offset = (base_row/32)*row_tile_column;
2481
- // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
2482
- // 1024*outRows/32*base_row/32 = outRows*base_row
2483
- int col_offset = outRows*base_row;
2484
-
2485
- offset = row_offset+col_offset;
2486
-
2487
-
2488
- // same as in the non-transpose case (see below)
2489
- // the difference is that now rows = cols
2490
- // in this case warp_id = subrow
2491
-
2492
- // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
2493
- // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
2494
- // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
2495
- // every 2 rows, the offset increases by two [0, 1, 8, 9...]
2496
- // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
2497
- int local_row = (jrow + warp_id) % 32 ; // offset for row > 32 is already calculated into row_offset
2498
- int ampere_row = ((local_row % 8 )/2 )*8 + (local_row/8 )*2 + (local_row % 2 );
2499
-
2500
- // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
2501
- out[offset + (ampere_row*32 ) + warp_lane] = data;
2502
- }
2503
- }
2504
- else
2505
- {
2506
- if (((base_row+subrow) < rows) && (base_col+(j*32 )+warp_lane < outCols))
2507
- {
2508
- char data = smem_data[(subrow*32 *ITEMS_PER_THREAD) + (j*32 ) + warp_lane];
2509
-
2510
- // set offset designates the tile offset among the 32*32 tiles
2511
- // we first increase rows and then columns. Since we load 128 columns at once
2512
- // we increase the offset by outRows*32 every 32 columns
2513
- // additionally, we increase the offset by 32*32=1024 every 32 rows
2514
- offset = ((base_col+(j*32 ))/32 )*outRows*32 + (((base_row+subrow)/32 )*1024 ); // global offset (32x32 tile)
2515
-
2516
- // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
2517
- // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
2518
- // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
2519
- // every 2 rows, the offset increases by two [0, 1, 8, 9...]
2520
- // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
2521
- int local_row = ((subrow % 8 )/2 )*8 + (subrow/8 )*2 + (subrow % 2 );
2522
-
2523
- // global offset + row with 32 cols each + 32 cols per j + col_idx
2524
- out[offset + (local_row*32 ) + warp_lane] = data;
2525
- }
2526
- }
2527
- break ;
2528
- }
2529
- }
2530
- }
2531
- }
2532
- }
2533
- }
2534
-
2535
2208
#define DENORM 1 .0f /127 .0f
2536
2209
#define MAX_SPARSE_COUNT 32
2537
2210
#define SMEM_SIZE 8 *256
@@ -3386,13 +3059,6 @@ template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max
3386
3059
template __global__ void kspmm_coo_very_sparse_naive<signed char , 16 , 8 >(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3387
3060
template __global__ void kspmm_coo_very_sparse_naive<signed char , 32 , 8 >(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
3388
3061
3389
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 0 , COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3390
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 1 , COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3391
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 0 , COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3392
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 1 , COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3393
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 0 , COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3394
- template __global__ void kTransformRowToFormat <256 , 8 , 32 , 32 *8 , 1 , COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
3395
-
3396
3062
template __global__ void kdequant_mm_int32_fp16<4 , 512 >(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
3397
3063
3398
3064
template __device__ unsigned char dQuantize<0 >(float * smem_code, const float rand, float x);
0 commit comments