
[RFC] Liger FlexChunkLoss: Alignment and Distillation loss #371


🚀 The feature, motivation and pitch

We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362

Progress

- Alignment
- Distillation

Design

Approach Overview:

The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:

  1. Modular Optimization Process:
     - Every alignment algorithm's optimization can be broken into three key steps:
       - Linear layer computation
       - Loss computation
       - Gradient calculation
  2. Fused Linear and Loss Computation:
     - As in FLCE, we fuse the linear layer with the loss computation for efficiency.
  3. Chunking & Forward Optimization:
     - Since this is the final step of the model's forward pass, we can also compute gradients directly during the forward pass instead of waiting for a separate backward pass.
     - We also chunk the input within the forward pass, significantly reducing the peak GPU memory required.
  4. Torch Compile for Kernel Optimization:
     - Instead of manually handling kernel-level optimizations, we let torch.compile optimize kernel execution automatically. This reduces the need for hand-written low-level kernels while still achieving performance gains.

By combining these strategies, we can optimize alignment algorithms efficiently while also simplifying development; a rough sketch of the pattern is shown below.
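To make this concrete, here is an illustrative sketch (not the actual Liger implementation; the helper name `chunked_fused_linear_loss` and its signature are made up) of fusing the linear projection with the loss per chunk, computing gradients during the forward pass, and letting torch.compile optimize the per-chunk function:

```python
import torch

def chunked_fused_linear_loss(hidden, weight, target, loss_fn, num_chunks=4):
    # Assumes `weight` has requires_grad=True (e.g. an nn.Parameter) and that
    # `loss_fn` returns a mean-reduced scalar over its chunk.
    grad_hidden = torch.zeros_like(hidden)
    grad_weight = torch.zeros_like(weight)
    total_loss = hidden.new_zeros(())

    def chunk_forward(h_chunk, w, t_chunk):
        logits = h_chunk @ w.t()         # linear layer on one chunk only
        return loss_fn(logits, t_chunk)  # fused with the loss computation

    # Let torch.compile optimize the per-chunk computation instead of
    # writing a Triton kernel by hand.
    chunk_forward = torch.compile(chunk_forward)

    for h, t, g in zip(hidden.chunk(num_chunks),
                       target.chunk(num_chunks),
                       grad_hidden.chunk(num_chunks)):
        h = h.detach().requires_grad_(True)
        loss = chunk_forward(h, weight, t)
        # Gradients are computed inside the forward pass, one chunk at a time,
        # so the full logits tensor is never materialized.
        g_h, g_w = torch.autograd.grad(loss, (h, weight))
        g.copy_(g_h)
        grad_weight += g_w
        total_loss += loss.detach() / num_chunks  # equal-sized chunks assumed
    return total_loss, grad_hidden, grad_weight
```

The returned `grad_hidden`/`grad_weight` would then be stashed for the backward pass (e.g. inside a `torch.autograd.Function`), so backward only has to hand them back.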

Key Findings

By leveraging torch.compile alongside optimization techniques such as chunking and online softmax, we observed performance close to that of custom Triton kernels while significantly reducing development time. This is why we want to introduce torch.compile as a key component of Liger.
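For context, "online softmax" here means computing the softmax normalizer in a streaming fashion so the full logits row never has to be held in memory at once. A minimal, illustrative log-sum-exp variant (not Liger's kernel) looks like this:

```python
import torch

def online_logsumexp(logits, chunk_size=1024):
    # Streaming log-sum-exp over the vocabulary dimension: keep a running max
    # and a rescaled running sum so only one chunk of logits is live at a time.
    running_max = logits.new_full(logits.shape[:-1], float("-inf"))
    running_sum = torch.zeros_like(running_max)
    for chunk in logits.split(chunk_size, dim=-1):
        chunk_max = chunk.max(dim=-1).values
        new_max = torch.maximum(running_max, chunk_max)
        running_sum = (running_sum * torch.exp(running_max - new_max)
                       + torch.exp(chunk - new_max.unsqueeze(-1)).sum(dim=-1))
        running_max = new_max
    return running_max + torch.log(running_sum)
```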
References:

  1. Torch-compiled FLCE is 2x faster than the current FLCE: #227
  2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py

Interface

Have a base class `FlexChunkLoss` that handles the chunking, accumulation, and compilation strategies.
A custom loss class wraps `FlexChunkLoss` and implements the loss function that operates on a given chunk:

```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, *chunk_inputs):
        # Per-chunk loss; chunking, accumulation, and compilation are
        # handled by the FlexChunkLoss base class.
        ...
```
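
As a fuller, hypothetical illustration (the class name, `loss_fn` signature, and hyperparameters below are made up, not the final Liger API), a chunk-level distillation loss could look like this, with `FlexChunkLoss` responsible for slicing the inputs into chunks and accumulating the per-chunk results:

```python
import torch.nn.functional as F

class ChunkKLDistillLoss(FlexChunkLoss):  # hypothetical example class
    def loss_fn(self, student_logits, teacher_logits, temperature=2.0):
        # KL divergence between softened teacher and student distributions,
        # computed on one chunk only; FlexChunkLoss accumulates over chunks.
        return F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)
```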

Alternatives

No response

Additional context

No response
