CUTLASS 4.3.1 will be released very soon to fix some CuTe DSL issues.
CuTe DSL
New features:
Supported Apache TVM-FFI, further reducing host runtime overhead for JIT functions and improving interoperability with PyTorch and other ML frameworks.
Added fake tensors and streams to decouple JIT function compilation from the "from_dlpack" flow. Users no longer need a real tensor to compile a JIT function.
Added FastDivmodDivisor with Python operator overloads, new APIs and CuTe dialect integration, and optimized static tile scheduler performance for faster index mapping.
Added L2 cache eviction priority for TMA-related ops, enabling fine-grained L2 cache control.
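FastDivmodDivisor targets division by a divisor that stays fixed while many indices are mapped, which is the usual situation in a tile scheduler. A widely used technique for this precomputes a magic multiplier and a shift so each division becomes a multiply plus a shift (Granlund and Montgomery). The Python sketch below illustrates that general technique for unsigned 32-bit dividends. The class FastDivmodSketch and its methods are hypothetical and are not the CuTe DSL API, and we do not claim CUTLASS uses exactly this multiplier formula.

```python
from typing import Tuple


class FastDivmodSketch:
    """Hypothetical helper (not the CuTe DSL API): divmod by a fixed divisor
    using a precomputed multiplier and shift instead of a hardware divide."""

    WIDTH = 32  # assumption: dividends are unsigned 32-bit integers

    def __init__(self, divisor: int) -> None:
        if divisor < 1:
            raise ValueError("divisor must be a positive integer")
        self.divisor = divisor
        # shift = ceil(log2(divisor)); 0 when divisor == 1
        self.shift = (divisor - 1).bit_length()
        # multiplier = floor(2^(WIDTH + shift) / divisor) + 1, which needs
        # WIDTH + 1 bits; Python ints are unbounded, so nothing overflows here
        self.multiplier = (1 << (self.WIDTH + self.shift)) // divisor + 1

    def div_mod(self, n: int) -> Tuple[int, int]:
        if not 0 <= n < (1 << self.WIDTH):
            raise ValueError("dividend must fit in WIDTH unsigned bits")
        # Exact for every n in range: the rounding error of the multiplier
        # stays below 1/divisor, so the floor never lands on the wrong integer
        q = (n * self.multiplier) >> (self.WIDTH + self.shift)
        return q, n - q * self.divisor


# Usage: map a linear tile index to (row, col) in a grid with 7 columns
cols = FastDivmodSketch(7)
row, col = cols.div_mod(100)
```

On a GPU a 33-bit multiplier does not fit in a 32-bit register, so production implementations typically use variants of this scheme; the Python version sidesteps that detail.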
Debuggability improvements:
Supported source location tracking for DSL APIs, which allows tools such as Nsight profilers to correlate performance metrics with Python source code.
Demonstrated usage of the new Pipeline APIs PipelineProducer and PipelineConsumer, which simplify code by removing explicit pipeline state management (existing APIs are still maintained).
Separated the epilogue code for the non-TMA and TMA implementations.
Fixed TensorSSA.__getitem__ indexing to match CuTe's indexing convention.
Fixed an issue with cutlass.max and cutlass.min
Fixed an issue with mark_compact_shape_dynamic
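With explicit pipeline state management, kernel code carries a stage index and a phase bit for each side of a multi-stage buffer and advances them by hand; the phase flips every time the index wraps around the ring of stages. PipelineProducer and PipelineConsumer are described as removing that bookkeeping from call sites. The single-threaded Python sketch below shows the idea: each side owns its own state object, so callers only put and get. All class names are hypothetical (this is not the CuTe DSL API), and a None check stands in for the barrier waits a real GPU pipeline performs using the phase bit.

```python
class PipelineStateSketch:
    """Hypothetical stage index + phase bookkeeping for a ring of stages."""

    def __init__(self, num_stages: int) -> None:
        self.num_stages = num_stages
        self.index = 0  # current stage in the circular buffer
        self.phase = 0  # flips each time index wraps back to 0
        self.count = 0  # total number of advances so far

    def advance(self) -> None:
        self.index += 1
        self.count += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


class ProducerSketch:
    """Owns its own state, so callers never touch index or phase."""

    def __init__(self, stages: list) -> None:
        self.stages = stages
        self.state = PipelineStateSketch(len(stages))

    def put(self, value) -> None:
        slot = self.state.index
        if self.stages[slot] is not None:
            # A real producer would wait on the stage's "empty" barrier here.
            raise RuntimeError("stage still full")
        self.stages[slot] = value  # acquire + write + commit, collapsed
        self.state.advance()


class ConsumerSketch:
    """Mirror image of the producer, with its own independent state."""

    def __init__(self, stages: list) -> None:
        self.stages = stages
        self.state = PipelineStateSketch(len(stages))

    def get(self):
        slot = self.state.index
        value = self.stages[slot]
        if value is None:
            # A real consumer would wait on the stage's "full" barrier here.
            raise RuntimeError("stage empty")
        self.stages[slot] = None  # release the stage back to the producer
        self.state.advance()
        return value


stages = [None, None]  # two-stage buffer shared by both sides
producer, consumer = ProducerSketch(stages), ConsumerSketch(stages)
producer.put("tile0")
producer.put("tile1")  # producer index wraps to 0 and its phase flips to 1
first = consumer.get()
producer.put("tile2")  # reuses stage 0 after the consumer released it
rest = [consumer.get(), consumer.get()]
```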
CUTLASS C++
Further enhance Blackwell SM100 Attention kernels in example 77.
Add softmax skip correction.
Fix a shared memory allocation bug: a kernel must explicitly opt in to the maximum dynamic shared memory size once its usage exceeds 48 KB.
Fix a hang caused by a warp returning early.
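Attention kernels usually compute softmax online, block by block: they keep a running maximum, a running sum of exponentials and a running weighted accumulator, and rescale ("correct") the partial results by exp(m_old - m_new) whenever a new block raises the maximum. We assume here that softmax skip correction means skipping that rescale when it is unnecessary; the release note does not spell out the exact condition used in example 77. The Python sketch below skips the rescale whenever the running maximum does not grow, because the correction factor is then exactly 1. The function name is hypothetical and scalar values stand in for value vectors.

```python
import math
from typing import Sequence, Tuple


def online_softmax_weighted_sum(
    score_blocks: Sequence[Sequence[float]],
    value_blocks: Sequence[Sequence[float]],
) -> Tuple[float, int]:
    """Return (sum_i softmax(s)_i * v_i, number of blocks that ran the rescale)."""
    running_max = -math.inf
    running_sum = 0.0
    acc = 0.0
    rescales = 0
    for scores, values in zip(score_blocks, value_blocks):
        block_max = max(scores)
        if block_max > running_max:
            # The maximum grew: correct earlier partial results.
            # On the first block exp(-inf) == 0.0 and both partials are 0 anyway.
            scale = math.exp(running_max - block_max)
            running_sum *= scale
            acc *= scale
            running_max = block_max
            rescales += 1
        # Otherwise the factor exp(running_max - running_max) is 1: skip it.
        for s, v in zip(scores, values):
            p = math.exp(s - running_max)
            running_sum += p
            acc += p * v
    return acc / running_sum, rescales
```

In the usage below, the second block's maximum (2.0) is below the running maximum (3.0), so only the first and third blocks run the rescale.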
Add command-line arguments for batch, no_verif, cluster_shape and cluster_shape_fallback in example 89.
Add a Ragged Contiguous Grouped GEMM kernel in example 92.
This kernel uses a TMA 3D load for the weights matrix and the tensormap update method to load activations.
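In a ragged contiguous layout, the activation rows of all groups are packed back to back, so each group is located by a row offset computed as a prefix sum of the per-group row counts, while the weights form a stack with one matrix per group (the kind of 3D tensor a single TMA descriptor can address). The pure-Python reference below spells out those semantics under these layout assumptions. The function names are hypothetical, and the sketch says nothing about how example 92 schedules tiles or updates tensormaps.

```python
from itertools import accumulate
from typing import List, Sequence

Matrix = List[List[float]]


def group_offsets(counts: Sequence[int]) -> List[int]:
    """Starting row of each group in the packed activation matrix."""
    return [0] + list(accumulate(counts))[:-1]


def ragged_contiguous_grouped_gemm(
    a_packed: Matrix, weights: Sequence[Matrix], counts: Sequence[int]
) -> Matrix:
    """Reference semantics: output rows of group g equal A_g @ W_g, where A_g is
    a_packed[offset_g : offset_g + counts[g]] and W_g is a K x N matrix."""
    if len(a_packed) != sum(counts) or len(weights) != len(counts):
        raise ValueError("counts must match the packed rows and the weight stack")
    out: Matrix = []
    for g, (start, rows) in enumerate(zip(group_offsets(counts), counts)):
        w = weights[g]
        k_dim, n_dim = len(w), len(w[0])
        for row in a_packed[start:start + rows]:
            out.append([sum(row[k] * w[k][j] for k in range(k_dim))
                        for j in range(n_dim)])
    return out
```

Groups with a count of 0 simply contribute no rows, which is why the offsets, not fixed strides, locate each group.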
Add 256x128 tile size support for Hopper SM90 DeepGEMM in example 67.
Performance is optimized to match the DeepSeek implementation.
Simplify the API for MoE GEMMs.
Instead of requiring users to call several CuTe utilities to set up strides, a new moe_stride_utils API helps set up strides in the kernel.
Instead of requiring users to set vectors like problem_shapes_device and problem_shapes_hosts, a new problem shape struct, MoEProblemShape, takes max_m, max_n, max_k and a counts vector as input and deduces problem shapes internally whenever required.
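The idea behind MoEProblemShape can be illustrated in Python: store a few maxima plus the per-group counts, and derive each group's (M, N, K) on demand instead of keeping separate host and device shape vectors in sync. This is only a sketch. It assumes each count is the number of rows (M) routed to one expert while N and K are shared by all groups, which the release note does not state, and the class and method names are hypothetical rather than the CUTLASS C++ interface.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MoEProblemShapeSketch:
    """Hypothetical analogue: counts[g] is assumed to be the M extent of
    group g; N and K are assumed to be shared by every group."""

    max_m: int
    max_n: int
    max_k: int
    counts: Tuple[int, ...]

    def __post_init__(self) -> None:
        if any(c < 0 or c > self.max_m for c in self.counts):
            raise ValueError("each count must lie in [0, max_m]")

    def group_count(self) -> int:
        return len(self.counts)

    def get_problem_shape(self, g: int) -> Tuple[int, int, int]:
        # Deduced when asked for, instead of being stored per group.
        return (self.counts[g], self.max_n, self.max_k)


shape = MoEProblemShapeSketch(max_m=128, max_n=256, max_k=512, counts=(3, 0, 128))
```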
Enable GEMM_K = 0 in grouped GEMM.
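Assuming the standard linear-combination epilogue D = alpha * (A @ B) + beta * C, a group with GEMM_K = 0 has an empty reduction, so its result is simply beta * C. The small reference function below (a hypothetical helper, not a CUTLASS API) makes that edge case explicit; K is passed separately because an M x 0 operand has no column to read it from.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def reference_gemm(alpha: float, a: Sequence[Sequence[float]],
                   b: Sequence[Sequence[float]], beta: float,
                   c: Sequence[Sequence[float]], k: int) -> Matrix:
    """D = alpha * (A @ B) + beta * C with A: M x K, B: K x N, C: M x N."""
    m = len(c)
    n = len(c[0]) if m else 0
    # When k == 0 the inner sum is empty (== 0), leaving D = beta * C.
    return [[alpha * sum(a[i][p] * b[p][j] for p in range(k)) + beta * c[i][j]
             for j in range(n)] for i in range(m)]
```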
Optimize grouped GEMM kernels by enabling asynchronous TMA descriptor updates.
Support Blackwell SM100 convolution stream-K kernel.
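Stream-K balances work by flattening the MAC-loop iterations of every output tile into one range and giving each persistent worker (CTA) an equal contiguous share, so a tile can be split across workers and its partial accumulators must be combined in a fixup step. For convolution, the loop being split is the K dimension of the implicit GEMM. The Python sketch below computes only that basic even split. The function name is hypothetical, and CUTLASS's scheduler for the SM100 convolution kernels may use more elaborate policies (for example mixing data-parallel and stream-K tiles).

```python
from typing import List, Tuple

Segment = Tuple[int, int, int]  # (tile, k_begin, k_end) in MAC-loop iterations


def stream_k_partition(num_tiles: int, iters_per_tile: int,
                       num_workers: int) -> List[List[Segment]]:
    """Split num_tiles * iters_per_tile iterations as evenly as possible
    across workers; each worker gets one contiguous range of iterations."""
    total = num_tiles * iters_per_tile
    base, extra = divmod(total, num_workers)
    plan: List[List[Segment]] = []
    start = 0
    for w in range(num_workers):
        end = start + base + (1 if w < extra else 0)
        segments: List[Segment] = []
        it = start
        while it < end:
            tile, k = divmod(it, iters_per_tile)
            # Stop at this worker's end or at the tile boundary, whichever is first
            seg_end = min(end, (tile + 1) * iters_per_tile)
            segments.append((tile, k, k + (seg_end - it)))
            it = seg_end
        plan.append(segments)
        start = end
    return plan
```

For example, three tiles of four iterations on two workers yield [[(0, 0, 4), (1, 0, 2)], [(1, 2, 4), (2, 0, 4)]]: tile 1 is split between both workers, which is exactly the case that needs the fixup.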
CuTe DSL
Baseline + X
TensorSSA.reduce to support a static value as the initial value
make_layout_tv
is_static
PipelineAsync
SmemAllocator
pipeline, utils and cute.math
CUTLASS C++
cutlass::int8_t and replace it with int8_t.
wait_on_dependent_grids for the PDL use case.
bytes_with_problem_shape of the block-scaled profiler.
This discussion was created from the release CUTLASS 4.3.0.