Skip to content

Commit 726ba08

Browse files
committed
use uv
1 parent f257c59 commit 726ba08

File tree

2 files changed

+96
-3
lines changed

2 files changed

+96
-3
lines changed

.github/workflows/ci-paddle.yml

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,29 @@ jobs:
7272
docker exec -t ${{ env.container_name }} /bin/bash -c '
7373
set -e
7474
source ${{ github.workspace }}/../../../proxy
75-
pip install -r requirements-test.txt
76-
pip install -e .
75+
76+
# Install uv
77+
curl -LsSf https://astral.sh/uv/install.sh | sh
78+
source $HOME/.cargo/env
79+
80+
# Create and activate virtual environment
81+
uv venv .venv
82+
source .venv/bin/activate
83+
84+
# Install paddle
85+
uv pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/
86+
87+
# Install project and minimal test runner
88+
uv pip install pytest
89+
uv pip install -e .
7790
'
7891
7992
- name: Run tests
8093
run: |
8194
docker exec -t ${{ env.container_name }} /bin/bash -c '
8295
set -e
83-
pytest testing/
96+
source .venv/bin/activate
97+
pytest tests_paddle/
8498
'
8599
86100
- name: Terminate and delete the container

tests_paddle/test_quick_start.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import paddle
3+
4+
import tilelang
5+
import tilelang.language as T
6+
7+
8+
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a tile-level GEMM+ReLU kernel: C = relu(A @ B).

    Args:
        M, N, K: GEMM problem sizes (A is MxK, B is KxN, C is MxN).
        block_M, block_N, block_K: tile sizes for the shared-memory blocking.
        dtype: element dtype of the global A/B/C tensors.
        accum_dtype: dtype of the local accumulator (wider than `dtype`
            to reduce accumulation error).

    Returns:
        The tilelang prim_func `matmul_relu_kernel`, JIT-compiled by the
        `@tilelang.jit` decorator when first called with concrete tensors.
    """

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context: one program per (block_N, block_M) output
        # tile; bx indexes N-tiles, by indexes M-tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # Accumulator lives in registers (fragment), in the wider accum dtype.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            # March along K in block_K steps; num_stages=3 asks for a
            # software-pipelined (multi-buffered) load/compute overlap.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # relu applied elementwise on the accumulator before writeback
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory (narrowing to `dtype`)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel
51+
52+
53+
def test_quick_start():
    """End-to-end smoke test: compile the tilelang GEMM+ReLU kernel and
    validate it against a Paddle reference on 1024^3 inputs.

    Requires a CUDA GPU and a Paddle build with GPU support.
    """
    M = 1024  # M = T.dynamic("m") if you want to use dynamic shape
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32

    # Define the kernel (matmul) and compile/lower it into an executable module
    matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

    # Create random input tensors on the GPU.
    # NOTE(review): these are torch-style calls (varargs shape, `device=`
    # kwarg); the documented paddle.randn signature is randn(shape, dtype=...).
    # This relies on the nightly Paddle torch-compat API — confirm it holds.
    a = paddle.randn(M, K, device="cuda", dtype=paddle.float16)
    b = paddle.randn(K, N, device="cuda", dtype=paddle.float16)
    c = paddle.empty(M, N, device="cuda", dtype=paddle.float16)

    # Run the kernel: writes the GEMM+ReLU result into c in place.
    matmul_relu_kernel(a, b, c)

    print(c)
    # Reference computation using Paddle
    ref_c = paddle.nn.functional.relu(a @ b)

    # Validate correctness (loose tolerances for fp16 accumulation differences)
    np.testing.assert_allclose(c.numpy(), ref_c.numpy(), rtol=1e-2, atol=1e-2)

0 commit comments

Comments
 (0)