Utility functions for OpenAI Triton
Writing fast GPU kernels is easier with Triton than with CUDA, but there is still a lot of tedious index juggling. That is not necessary.
Triton-util provides simple higher-level abstractions for frequent but repetitive tasks. This allows you to write code that is closer to how you actually think.
Example: Say you have a 2d matrix of shape (max0,max1) and stride
(stride0,stride1), which you have chunked along both axes. Each
chunk is size (sz0,sz1), and you want to get the (n0,n1)th chunk.
With triton-util, you write
load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0) # stride1 defaults to 1
instead of
offs0 = n0 * sz0 + tl.arange(0, sz0)
offs1 = n1 * sz1 + tl.arange(0, sz1)
offs = offs0[:,None] * stride0 + offs1[None,:] * stride1
mask = (offs0[:,None] < max0) & (offs1[None,:] < max1)
tl.load(ptr + offs, mask)
Additionally, triton-util provides handy utility functions to make
debugging easier. Want to print txt only on the 1st kernel? Write
print_once(txt) - that's it!
Finally, triton-util is progressive, ie you can use as little or as much as you want. It's fully interoperable with triton. (It is, in fact, pure triton.)
pip install triton-util
print_once(txt)
- Print txt, only on 1st kernel (ie all pids = 0)
breakpoint_once()
- Enter breakpoint, only on 1st kernel (ie all pids = 0)
print_if(txt, conds)
- Print txt, if condition on pids is fulfilled
- Eg print_if(txt, '=0,>1') prints if pid_0 = 0, pid_1 > 1 and pid_2 is arbitrary
breakpoint_if(conds)
- Enter breakpoint, if condition on pids is fulfilled
- Eg breakpoint_if('=0,>1') stops if pid_0 = 0, pid_1 > 1 and pid_2 is arbitrary
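To make the condition strings concrete, here is a minimal pure-Python sketch of how a string like '=0,>1' could be checked against a tuple of pids. It is not triton-util's implementation: the name conds_match is hypothetical, and only the '=' and '>' operators appear in the text above, so the other operators are assumptions.

```python
import operator

# '=' and '>' appear in the text; the remaining operators are assumptions.
OPS = {
    ">=": operator.ge, "<=": operator.le, "!=": operator.ne,
    "=": operator.eq, ">": operator.gt, "<": operator.lt,
}

def conds_match(conds, pids):
    """Check a condition string such as '=0,>1' against a tuple of pids."""
    for cond, pid in zip(conds.split(","), pids):
        cond = cond.strip()
        if not cond:
            continue  # empty slot: this pid is unconstrained
        # Dict order puts two-character symbols first, so '>=' wins over '>'.
        for symbol, op in OPS.items():
            if cond.startswith(symbol):
                if not op(pid, int(cond[len(symbol):])):
                    return False
                break
        else:
            raise ValueError(f"unrecognised condition: {cond!r}")
    return True

print(conds_match("=0,>1", (0, 2, 7)))  # True: pid_2 is not constrained
print(conds_match("=0,>1", (0, 1, 7)))  # False: pid_1 > 1 fails
```

Because zip stops at the shorter sequence, any pid without a matching condition is left arbitrary, which mirrors the pid_2 behaviour described above.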
assert_tensors_gpu_ready(*tensors)
- Assert that all tensors are contiguous and on the GPU (unless the
environment variable TRITON_INTERPRET is set to '1')
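The check described above can be sketched as follows. This is an illustration under assumptions, not the library's code: check_tensors_ready and FakeTensor are hypothetical names, and FakeTensor only mimics the is_contiguous() method and is_cuda attribute of a PyTorch tensor so the example runs without PyTorch or a GPU.

```python
import os

class FakeTensor:
    """Stand-in exposing the two PyTorch tensor members the check uses."""
    def __init__(self, contiguous=True, is_cuda=True):
        self._contiguous = contiguous
        self.is_cuda = is_cuda

    def is_contiguous(self):
        return self._contiguous

def check_tensors_ready(*tensors):
    # In interpreter mode kernels run on the CPU, so the GPU check is skipped.
    interpret = os.environ.get("TRITON_INTERPRET") == "1"
    for i, t in enumerate(tensors):
        assert t.is_contiguous(), f"tensor {i} is not contiguous"
        if not interpret:
            assert t.is_cuda, f"tensor {i} is not on the GPU"

check_tensors_ready(FakeTensor(), FakeTensor())  # passes silently
```

Calling such a check at the top of the Python wrapper, before launching the kernel, surfaces layout and device mistakes as a clear assertion instead of wrong results.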
cdiv(a,b)
- ceiling division
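Ceiling division typically answers the question "how many chunks do I need to cover all elements?". The sketch below is a plain-Python stand-in with the same name, not the library's own definition, and assumes positive integers.

```python
def cdiv(a, b):
    """Ceiling division for positive integers: smallest k with k * b >= a."""
    return (a + b - 1) // b

n_elements, block_size = 1000, 256
n_chunks = cdiv(n_elements, block_size)
print(n_chunks)  # 4, since 3 * 256 = 768 < 1000 <= 1024 = 4 * 256
```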
offset_1d(sz, n_prev_chunks=0)
- Return 1d offsets to the (n_prev_chunks+1)th chunk of size sz
offset_2d(offs_0, offs_1, stride_0, stride_1=1)
- Create 2d offsets from two 1d offsets
mask_1d(offs, max)
- Create a 1d mask from a 1d offset and a max value
mask_2d(offs_0, offs_1, max_0, max_1)
- Create a 2d mask from two 1d offsets and max values
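The arithmetic behind the four offset and mask helpers can be illustrated with plain Python lists standing in for Triton tensors. The functions below are hypothetical reference versions (note the _ref suffix) that follow the formulas from the introductory example; they are not the library's Triton code.

```python
def offset_1d_ref(sz, n_prev_chunks=0):
    # Skip n_prev_chunks full chunks, then count sz consecutive indices.
    return [n_prev_chunks * sz + i for i in range(sz)]

def offset_2d_ref(offs_0, offs_1, stride_0, stride_1=1):
    # Row offsets scale by stride_0, column offsets by stride_1.
    return [[o0 * stride_0 + o1 * stride_1 for o1 in offs_1] for o0 in offs_0]

def mask_1d_ref(offs, max):
    return [o < max for o in offs]

def mask_2d_ref(offs_0, offs_1, max_0, max_1):
    return [[o0 < max_0 and o1 < max_1 for o1 in offs_1] for o0 in offs_0]

# 3 x 5 row-major matrix (strides 5 and 1), chunk size 2 x 2, chunk (1, 2).
offs_0 = offset_1d_ref(2, 1)  # rows [2, 3]
offs_1 = offset_1d_ref(2, 2)  # cols [4, 5]
print(offset_2d_ref(offs_0, offs_1, 5))    # [[14, 15], [19, 20]]
print(mask_2d_ref(offs_0, offs_1, 3, 5))   # [[True, False], [False, False]]
```

Only row 2 and column 4 lie inside the 3 x 5 matrix, so the mask keeps a single element of this edge chunk.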
load_1d(ptr, sz, n, max, stride=1)
- Chunk 1d vector (defined by ptr) into 1d grid, where
each chunk has size sz, and load the nth chunk.
load_full_1d(ptr, sz, stride=1)
- Load 1d block of size sz
load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1)
- Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk
has size (sz0,sz1), and load the (n0,n1)th chunk.
load_full_2d(ptr, sz0, sz1, stride0, stride1=1)
- Load 2d block of size sz0 x sz1
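To see what a chunked, masked load returns, the sketch below emulates it on a flat Python list. load_2d_ref is a hypothetical reference function, not the library's kernel code, and the other fill value for masked-out positions is an assumption made so the output is well defined.

```python
def load_2d_ref(buf, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1, other=0):
    """Emulate a masked chunk load from a flat buffer; returns nested lists."""
    rows = [n0 * sz0 + i for i in range(sz0)]
    cols = [n1 * sz1 + j for j in range(sz1)]
    return [
        [buf[r * stride0 + c * stride1] if (r < max0 and c < max1) else other
         for c in cols]
        for r in rows
    ]

max0, max1 = 3, 5
buf = list(range(max0 * max1))  # row-major 3 x 5 matrix, strides (5, 1)

print(load_2d_ref(buf, 2, 2, 0, 1, max0, max1, stride0=5))  # [[2, 3], [7, 8]]
print(load_2d_ref(buf, 2, 2, 1, 2, max0, max1, stride0=5))  # [[14, 0], [0, 0]]
```

The second call hits the bottom-right edge of the matrix: only element (2, 4) exists, and the mask prevents the other three positions from being read.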
store_1d(vals, ptr, sz, n, max, stride=1)
- Store 1d block into the nth chunk of vector (defined by ptr), where each chunk has size sz
store_full_1d(vals, ptr, sz, stride=1)
- Store 1d block into vector (defined by ptr)
store_2d(vals, ptr, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1)
- Store 2d block into the (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size (sz0, sz1)
store_full_2d(vals, ptr, sz0, sz1, stride0, stride1=1)
- Store 2d block into matrix (defined by ptr)
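The store helpers mirror the loads: the same offsets and mask decide where values are written. The following pure-Python sketch emulates a chunked, masked 2d store on a flat list. store_2d_ref is a hypothetical reference function used for illustration, not the library's implementation.

```python
def store_2d_ref(vals, buf, sz0, sz1, n0, n1, max0, max1, stride0, stride1=1):
    """Emulate a masked chunk store into a flat buffer, in place."""
    for i in range(sz0):
        for j in range(sz1):
            r, c = n0 * sz0 + i, n1 * sz1 + j
            if r < max0 and c < max1:  # mask: skip out-of-bounds positions
                buf[r * stride0 + c * stride1] = vals[i][j]

max0, max1 = 3, 5
buf = [0] * (max0 * max1)  # row-major 3 x 5 matrix, strides (5, 1)

store_2d_ref([[1, 2], [3, 4]], buf, 2, 2, 0, 1, max0, max1, stride0=5)
store_2d_ref([[9, 9], [9, 9]], buf, 2, 2, 1, 2, max0, max1, stride0=5)
print(buf)  # [0, 0, 1, 2, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 9]
```

The second store targets an edge chunk, so only one of its four values lands in the buffer; without the mask it would write past the end of the matrix.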
- a gentle introduction to Triton? - See A Practitioner's Guide to Triton and its accompanying notebook
- world-class real-life triton kernels? - See triton-index
- a crazy competent and kind community where you can ask questions (beginner or advanced!)? - See cuda mode discord, which has a triton channel
Brought to you by Umer ❤️