Release Notes – Release 2.2
Key Features and Enhancements
- [PyTorch] Added support for per-tensor current scaling recipe.
- [PyTorch] Implemented cross-entropy loss with support for splitting computation across multiple devices.
- [PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.
- [PyTorch] Added support for KV cache for FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.
- [PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower-priority streams.
- [C/PyTorch] Improved performance for P2P-based Tensor Parallel (TP) communication overlap.
- [Jax] Added support for THD format with ring attention.
- [Jax] Improved performance and memory usage for causal mask in the cuDNN attention backend.
- [C] Added multi-node support for NVIDIA® NVLink for TP overlap with userbuffers.
Fixed Issues
- [PyTorch] Fixed a convergence issue when using context parallelism with the fused attention backend.
- [PyTorch] Fixed a crash in GroupedLinear when the last input has no tokens.
- [PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.
- [PyTorch] Reintroduced support for the return_bias argument in all modules; it was silently ignored in v2.0 and v2.1.
- [PyTorch] Reintroduced support for FP8 communication for overlapping reduce-scatter and GEMM when using TP overlap with userbuffers.
- [PyTorch] Fixed gradient accumulation fusion in the LayerNormMLP module.
- [C/PyTorch] Made miscellaneous numerical fixes to the fused attention backend.
- [C] Stopped creating a new cublasLtHandle for every GEMM call, which caused memory leaks.
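The handle fix follows a standard caching pattern: create the expensive handle once and reuse it on every subsequent call. A minimal sketch of that pattern in Python, where the Handle class and gemm function are hypothetical stand-ins for the cublasLt handle and GEMM call, not the actual C implementation:

```python
# Cache an expensive handle instead of allocating a new one per call.
# Handle and gemm are illustrative stand-ins, not cublasLt itself.

class Handle:
    created = 0  # track how many handles are ever allocated

    def __init__(self):
        Handle.created += 1

_cached_handle = None

def get_handle():
    """Return the shared handle, creating it only on first use."""
    global _cached_handle
    if _cached_handle is None:
        _cached_handle = Handle()
    return _cached_handle

def gemm(a, b):
    get_handle()  # reuses the cached handle; no per-call allocation
    return [[sum(x * y for x, y in zip(row, col))
             for col in zip(*b)] for row in a]

gemm([[1, 2]], [[3], [4]])
gemm([[1, 2]], [[3], [4]])
print(Handle.created)  # 1 — only one handle for both GEMM calls
```

Reusing one handle avoids both the allocation cost and the leak that results when per-call handles are never destroyed.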
- [Jax] Fixed shape and sharding inference in the fused-attention C++ extension.
- [Jax] Fixed an import error in the encoder example.
Known Issues in This Release
- RTX 5090 is currently unsupported for FP8 execution. Support will be added in v2.3.0.
- Transformer Engine may crash when it is installed via the PyPI registry but is run in an environment with CUDA version < 12.8. A temporary workaround is to install from source until the issue is fixed.
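For the workaround above, one way to install from source is via pip's Git support; the repository URL is the public Transformer Engine GitHub location, and the @stable ref is an assumption (pin whichever tag matches your target release):

```shell
# Build Transformer Engine from source instead of the PyPI wheel
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
```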
Breaking Changes in This Release
- [PyTorch] The deprecated interval argument for the DelayedScaling recipe has been removed.
- [PyTorch] There are multiple breaking changes in the InferenceParams class:
  - New arguments num_heads_kv, head_dim_k, and dtype are required during initialization.
  - The user must call a pre_step method to update the InferenceParams state.
  - The swap_key_value_dict method has been removed, as the step method now automatically reorders the key/value sequences according to their batch indices.
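The reordering that step now performs can be sketched independently of Transformer Engine. The snippet below is a minimal pure-Python illustration of reordering cached key/value sequences by batch index; the function name and data layout are hypothetical, not the InferenceParams API:

```python
# Illustrative sketch of batch-index KV-cache reordering (the behavior
# that previously required a manual swap_key_value_dict call).

def reorder_kv_cache(key_cache, value_cache, batch_indices):
    """Reorder per-sequence KV entries in place so that row i holds
    the sequence now sitting at batch position i."""
    key_cache[:] = [key_cache[j] for j in batch_indices]
    value_cache[:] = [value_cache[j] for j in batch_indices]

keys = [["k_a"], ["k_b"], ["k_c"]]
values = [["v_a"], ["v_b"], ["v_c"]]
reorder_kv_cache(keys, values, [2, 0, 1])
print(keys)  # [['k_c'], ['k_a'], ['k_b']]
```

With step handling this automatically, callers only supply the current batch ordering rather than managing cache swaps themselves.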
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
- [PyTorch] The minimum required PyTorch version has been raised to 2.1.