|
| 1 | +import numpy as np |
| 2 | +import paddle |
| 3 | + |
| 4 | +import tilelang |
| 5 | +import tilelang.language as T |
| 6 | + |
| 7 | + |
| 8 | +# @tilelang.jit(target="cuda") |
| 9 | +# target currently can be "cuda" or "hip" or "cpu". |
| 10 | +# if not specified, it will be inferred from the input tensors during compile time |
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build and JIT-compile a tiled GEMM + ReLU kernel computing C = relu(A @ B).

    Args:
        M, N, K: problem sizes — A is (M, K), B is (K, N), C is (M, N).
        block_M, block_N, block_K: per-threadblock tile sizes along M, N, K.
        dtype: element type of A, B and C (default "float16").
        accum_dtype: accumulation type for the local tile (default "float",
            i.e. fp32 accumulation over fp16 inputs).

    Returns:
        The JIT-compiled kernel; called as kernel(A, B, C), writing into C.
    """

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: 2-D launch grid — bx indexes N-tiles,
        # by indexes M-tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Shared-memory staging tiles for A and B, and a per-thread
            # fragment accumulator for the C tile.
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation before the K-loop.
            T.clear(C_local)

            # Software-pipelined loop over K tiles (3 stages overlap the
            # global->shared copies with the tile GEMM).
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Fused ReLU applied element-wise on the accumulator.
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory (implicitly casts the
            # accum_dtype fragment back to C's dtype).
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel
| 51 | + |
| 52 | + |
def test_quick_start():
    """Compile the tiled matmul+ReLU kernel and validate it against a Paddle reference.

    Builds a 1024x1024x1024 fp16 GEMM with 128x128x32 tiles, runs the
    compiled kernel on GPU tensors, and checks the output against
    relu(a @ b) computed by Paddle within fp16-appropriate tolerances.
    """
    M = 1024  # M = T.dynamic("m") if you want to use dynamic shape
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32

    # Define the kernel (matmul) and compile/lower it into an executable module
    matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

    # Place new tensors on the GPU; Paddle has no per-call `device=` argument
    # (unlike torch.randn), so placement is controlled globally.
    paddle.set_device("gpu")

    # Create random input tensors on the GPU.
    # NOTE: Paddle's randn/empty take the shape as a list, not as separate
    # positional dims — the torch-style randn(M, K, device=...) call raises.
    a = paddle.randn([M, K], dtype=paddle.float16)
    b = paddle.randn([K, N], dtype=paddle.float16)
    c = paddle.empty([M, N], dtype=paddle.float16)

    # Run the compiled kernel; the result is written into `c` in place.
    matmul_relu_kernel(a, b, c)

    print(c)
    # Reference computation using Paddle
    ref_c = paddle.nn.functional.relu(a @ b)

    # Validate correctness (loose tolerances to allow fp16 rounding)
    np.testing.assert_allclose(c.numpy(), ref_c.numpy(), rtol=1e-2, atol=1e-2)
0 commit comments