Description
🚀 The feature, motivation and pitch
We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362
Progress
Alignment
- ORPO: Add Chunked ORPO Loss #362
- CPO: Adds the CPO Alignment Loss Function #382
- DPO: Support Chunked DPO Loss Kernel #378
- SimPO: Add Chunked SimPO Loss #386
- IRPO
- KTO: Add KTO Loss #475
- f-PO
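The alignment losses above all score a chosen response against a rejected one. The following is a minimal PyTorch sketch of the standard DPO objective, computed from per-sequence log-probabilities. The function name `dpo_loss`, its argument names and the default `beta=0.1` are illustrative assumptions, not the interface used in #378.

```python
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Sketch of the DPO objective (hypothetical helper, not the Liger API).

    Each argument is a 1-D tensor of summed token log-probs, one entry per
    preference pair. beta scales the implicit reward.
    """
    # Implicit rewards of the policy relative to the frozen reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # -log(sigmoid(margin)), averaged over the pairs in the batch
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```

When the policy matches the reference, the margin is zero and the loss is log 2; it shrinks as the policy prefers the chosen response more strongly than the reference does.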
Distillation
- KL divergence
- cosine_similarity
- earth_mover_distance
- JSD: Add JSD Loss for Distillation #425
- KVD
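The distillation losses compare a student's output distribution with a teacher's. A minimal PyTorch sketch of two of them, KL divergence and the standard (unweighted) Jensen-Shannon divergence, computed from raw logits. The helper names are hypothetical, and the version in #425 may be parameterized differently.

```python
import math

import torch
import torch.nn.functional as F


def kl_divergence(student_logits, teacher_logits):
    """KL(teacher || student), summed over the vocab and averaged over rows."""
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    # With log_target=True, F.kl_div computes sum exp(t) * (t - s) per row
    return F.kl_div(s_logp, t_logp, reduction="batchmean", log_target=True)


def jsd(student_logits, teacher_logits):
    """Jensen-Shannon divergence: 0.5*KL(P||M) + 0.5*KL(Q||M), M = (P+Q)/2."""
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    # log M, computed stably in log space: log((exp(s) + exp(t)) / 2)
    m_logp = torch.logsumexp(torch.stack([s_logp, t_logp]), dim=0) - math.log(2)
    kl_s = F.kl_div(m_logp, s_logp, reduction="batchmean", log_target=True)
    kl_t = F.kl_div(m_logp, t_logp, reduction="batchmean", log_target=True)
    return 0.5 * (kl_s + kl_t)
```

Unlike KL, JSD is symmetric in its two arguments and bounded above by log 2.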
Design
Approach Overview:
The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:
- Modular Optimization Process:
  - Every alignment algorithm's optimization can be broken into three key steps:
    - Linear layer computation
    - Loss computation
    - Gradient calculation
- Fused Linear and Loss Computation:
  - Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
- Chunking & Forward Optimization:
  - Since this is the final step in the model's forward pass, we can compute gradients directly during the forward pass instead of waiting for a separate backward pass.
  - We also chunk the input within the forward pass, so only one chunk's logits are materialized at a time, which significantly reduces peak GPU memory.
- Torch Compile for Kernel Optimization:
  - Instead of hand-writing kernel-level optimizations, we let torch.compile optimize kernel execution automatically. This reduces the need for low-level work while still delivering performance gains.
By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
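The three steps (linear layer, loss, gradients) can be sketched for plain cross entropy in PyTorch. This is a minimal illustration under stated assumptions, not Liger's implementation: the function name and `chunk_size` parameter are hypothetical, and cross entropy stands in for an alignment loss. Gradients for each chunk are computed immediately during the "forward", so the full (N, V) logits tensor is never materialized.

```python
import torch
import torch.nn.functional as F


def chunked_linear_ce_with_grads(x, weight, target, chunk_size=2):
    """Hypothetical sketch: fused linear + cross entropy, chunk by chunk.

    x: (N, H) hidden states, weight: (V, H) LM-head weight, target: (N,).
    Returns the mean loss plus gradients w.r.t. x and weight.
    """
    n = x.shape[0]
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(weight)
    total_loss = torch.zeros((), dtype=x.dtype, device=x.device)
    # Detached leaf so each chunk gets its own small autograd graph
    w = weight.detach().requires_grad_(True)

    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        x_chunk = x[start:end].detach().requires_grad_(True)
        # Step 1: linear layer for this chunk only -> (chunk, V) logits
        logits = x_chunk @ w.t()
        # Step 2: summed loss for the chunk; normalized by N below
        loss = F.cross_entropy(logits, target[start:end], reduction="sum")
        # Step 3: gradients right away, so this chunk's logits can be freed
        gx, gw = torch.autograd.grad(loss, (x_chunk, w))
        grad_x[start:end] = gx / n
        grad_w += gw / n
        total_loss += loss.detach()

    return total_loss / n, grad_x, grad_w
```

In a real module, the stored gradients would typically be returned from the `backward` of a `torch.autograd.Function`, scaled by the incoming `grad_output`, and the per-chunk step is the natural unit to hand to `torch.compile`; both are omitted here to keep the sketch short.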
Key Findings
By leveraging torch.compile alongside optimization techniques such as chunking and online softmax, we observed performance close to that of custom Triton kernels while reducing development time. This is why we want to introduce torch.compile as a key component of Liger.
References:
- Torch compiled FLCE is 2x faster than the current FLCE #227
- https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
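Online softmax, one of the techniques mentioned in the findings, lets the vocabulary dimension be processed in pieces as well. A plain PyTorch sketch of the idea (not Liger's kernel; the function name and `chunk_size` are hypothetical): keep a running max and a running sum of exponentials, rescaling the sum whenever a new chunk raises the max.

```python
import torch


def online_logsumexp(logits, chunk_size=3):
    """Hypothetical sketch: row-wise logsumexp over vocab chunks.

    Keeps a running max m and running sum s of exp(logit - m); when a new
    chunk raises the max, the old sum is rescaled by exp(m_old - m_new).
    """
    rows = logits.shape[0]
    m = torch.full((rows,), float("-inf"), dtype=logits.dtype, device=logits.device)
    s = torch.zeros(rows, dtype=logits.dtype, device=logits.device)
    for start in range(0, logits.shape[1], chunk_size):
        chunk = logits[:, start:start + chunk_size]
        m_new = torch.maximum(m, chunk.max(dim=1).values)
        # Rescale the old partial sum to the new max, then add this chunk
        s = s * torch.exp(m - m_new) + torch.exp(chunk - m_new[:, None]).sum(dim=1)
        m = m_new
    return m + torch.log(s)
```

Cross entropy for a row then follows as `lse - logits[row, target]`, without ever holding the normalized softmax for the whole vocabulary.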
Interface
Have a base class, FlexChunkLoss, that handles the chunking, accumulation, and compilation strategies.
Each custom loss subclasses FlexChunkLoss and implements loss_fn, which operates on a single chunk:
```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, *args, **kwargs):
        # compute and return the loss for a single chunk here
        ...
```
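The proposal leaves the base class itself open. The following is a minimal PyTorch sketch of one way FlexChunkLoss could look, assuming a `loss_fn(logits_chunk, target_chunk)` signature, a `chunk_size` constructor argument, and a `compile_loss` flag; all three are hypothetical. The base class owns chunking, accumulation and the optional `torch.compile` step, and the subclass only defines the per-chunk loss.

```python
import torch
import torch.nn.functional as F


class FlexChunkLoss(torch.nn.Module):
    """Hypothetical base class: chunking + accumulation, loss_fn left abstract."""

    def __init__(self, chunk_size=1024, compile_loss=False):
        super().__init__()
        self.chunk_size = chunk_size
        # Optionally let torch.compile optimize the per-chunk computation
        self._chunk_fn = torch.compile(self._chunk_loss) if compile_loss else self._chunk_loss

    def loss_fn(self, logits_chunk, target_chunk):
        raise NotImplementedError  # implemented by each concrete loss

    def _chunk_loss(self, x_chunk, weight, target_chunk):
        # Fused linear + loss for one chunk of rows
        return self.loss_fn(x_chunk @ weight.t(), target_chunk)

    def forward(self, x, weight, target):
        # Accumulate summed per-chunk losses, then normalize by row count
        total = x.new_zeros(())
        for start in range(0, x.shape[0], self.chunk_size):
            end = start + self.chunk_size
            total = total + self._chunk_fn(x[start:end], weight, target[start:end])
        return total / x.shape[0]


class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, logits_chunk, target_chunk):
        # Summed cross entropy; the base class divides by the number of rows
        return F.cross_entropy(logits_chunk, target_chunk, reduction="sum")
```

For brevity this sketch relies on ordinary autograd, so every chunk's graph stays alive until backward; the memory savings described in the Design section additionally require computing each chunk's gradients during the forward pass.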