each torch.Tensor passed to a Triton kernel is implicitly converted to a pointer to its first element, hence the parameters are named x_ptr, y_ptr, etc.
tl.constexpr designates BLOCK_SIZE as a compile-time constant instead of a runtime variable. this means every distinct value of BLOCK_SIZE is essentially a different kernel, compiled separately.
the compile-time (tl.constexpr) arguments of a kernel are also called meta-parameters.
if the kernel deals with multiple dimensions, tl.program_id(axis=...) gives the program's index along each launch-grid axis (axis=0, 1, or 2).
tl.load() is a memory operation that reads data from GPU DRAM into on-chip SRAM, so we want to keep track of how many times we load data, since memory traffic is usually the bottleneck.
tl.store() will write results from SRAM back to DRAM.
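the pointer/offset/mask logic above can be sketched in plain python (running real Triton needs a GPU, so this is just a stand-in: one call to add_block below mimics what one Triton program instance does in a vector add; the names pid, BLOCK_SIZE, offsets, mask mirror the kernel's):

```python
def add_block(x, y, out, n_elements, pid, BLOCK_SIZE):
    # pid plays the role of tl.program_id(axis=0): which block this instance owns
    block_start = pid * BLOCK_SIZE
    # tl.arange(0, BLOCK_SIZE) -> per-element offsets within the block
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    # the mask guards out-of-bounds accesses on the last, partial block
    mask = [o < n_elements for o in offsets]
    for o, m in zip(offsets, mask):
        if m:
            # stands in for tl.load(x_ptr + offsets, mask=mask) followed by
            # tl.store(out_ptr + offsets, x + y, mask=mask)
            out[o] = x[o] + y[o]

x = [1, 2, 3, 4, 5]
y = [10, 20, 30, 40, 50]
out = [0] * 5
for pid in range(2):  # grid of ceil(5 / 4) = 2 program instances
    add_block(x, y, out, 5, pid, BLOCK_SIZE=4)
# out == [11, 22, 33, 44, 55]
```

note how the second program instance only writes one element: its last three offsets fail the mask, which is exactly what the mask argument to tl.load/tl.store prevents on the GPU.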
the purpose of the helper/wrapper function that calls the kernel is to
allocate memory for the output vector
compute the launch grid and enqueue the kernel calls
triton kernels can't automatically move data between devices, so we have to manually make sure all input and output tensors are on the same GPU before calling the kernel function.
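a rough sketch of the wrapper's responsibilities, again in plain python since there is no GPU here (cdiv mirrors triton.cdiv, and add_block is a hypothetical stand-in for the @triton.jit kernel; a real wrapper would launch with add_kernel[grid](...) instead of a python loop):

```python
def cdiv(n, d):
    # ceiling division, same result as triton.cdiv
    return (n + d - 1) // d

def add_block(x, y, out, n, pid, BLOCK_SIZE):
    # stand-in for the jitted kernel: handles one block of elements
    for i in range(pid * BLOCK_SIZE, min((pid + 1) * BLOCK_SIZE, n)):
        out[i] = x[i] + y[i]

def add(x, y, BLOCK_SIZE=1024):
    # in real Triton, we'd also check x and y live on the same GPU device
    assert len(x) == len(y)
    n = len(x)
    out = [0] * n                  # 1) allocate memory for the output
    grid = (cdiv(n, BLOCK_SIZE),)  # 2) one program instance per block
    for pid in range(grid[0]):     #    real launch: add_kernel[grid](x, y, out, n, BLOCK_SIZE=...)
        add_block(x, y, out, n, pid, BLOCK_SIZE)
    return out

result = add([1, 2, 3], [4, 5, 6], BLOCK_SIZE=2)  # [5, 7, 9]
```

the grid is deliberately a tuple: Triton launch grids can be 1-, 2-, or 3-dimensional, and each entry gives the number of program instances along that axis.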
@triton.testing.perf_report is a decorator from triton's built-in testing utilities that lets us benchmark custom ops. it lets us set the conditions under which the benchmark will be performed (e.g., which axis to sweep and which implementations to compare).
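the real perf_report setup needs the Triton runtime and a GPU, so as a rough stand-in, a plain-python harness below captures the same idea: sweep one axis (vector size), time each "provider" (implementation), and collect results per provider. all names here are illustrative, not Triton API:

```python
import timeit

def benchmark(sizes, providers):
    # sweep the x-axis (input size) and time every provider at each point,
    # analogous to what a perf_report-decorated function does
    results = {name: [] for name in providers}
    for n in sizes:
        x = list(range(n))
        y = list(range(n))
        for name, fn in providers.items():
            elapsed = timeit.timeit(lambda: fn(x, y), number=10)
            results[name].append(elapsed)
    return results

providers = {
    "zip-add": lambda x, y: [a + b for a, b in zip(x, y)],
    "index-add": lambda x, y: [x[i] + y[i] for i in range(len(x))],
}
res = benchmark([2**10, 2**12], providers)
```

in the real decorator, the sizes and providers are declared in a triton.testing.Benchmark object, and the decorated function returns a metric (typically GB/s) rather than raw seconds.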
torch.jit.script takes python code and converts it ahead of time to TorchScript, a static graph representation that can be saved and run outside python (e.g., from C++).
torch.compile() is more modern and flexible than torch.jit.script. it can handle dynamic python code that torch.jit.script can't, and it optimizes code just-in-time as it runs.