Skip to content

Commit bf2c67b

Browse files
shivam15s and mRSun15 authored
Fix GRPO to conform with TRL: Fix loss, make tests accurate, correct metrics computation (#628)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Tries to address #626 and other correctness improvements <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: mRSun15 <3150105645@zju.edu.cn>
1 parent f248529 commit bf2c67b

8 files changed

Lines changed: 750 additions & 415 deletions

File tree

dev/modal/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
1515

1616

17-
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
17+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
1818
def liger_tests():
1919
import subprocess
2020

dev/modal/tests_bwd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
1515

1616

17-
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
17+
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
1818
def liger_bwd_tests():
1919
import subprocess
2020

src/liger_kernel/chunked_loss/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
22
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
34
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
45
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
56
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@
1112
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
1213
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
1314
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
from abc import abstractmethod
2+
from functools import partial
3+
4+
import torch
5+
import torch._dynamo.config
6+
import torch.nn.functional as F
7+
8+
9+
class LigerFusedLinearPPOBase(torch.autograd.Function):
10+
@abstractmethod
11+
def ppo_loss_fn(*args, **kwargs):
12+
"""
13+
To be extended by subclasses.
14+
"""
15+
raise NotImplementedError("PPO loss function must be implemented.")
16+
17+
@staticmethod
18+
def forward(
19+
cls,
20+
ctx,
21+
_input,
22+
weight,
23+
selected_token_ids,
24+
attention_mask,
25+
advantages,
26+
bias=None,
27+
ref_per_token_logps=None,
28+
old_per_token_logps=None,
29+
ref_input=None,
30+
ref_weight=None,
31+
ref_bias=None,
32+
epsilon_low=0.2,
33+
epsilon_high=0.2,
34+
beta=0.04,
35+
temperature=1.0,
36+
compiled=True,
37+
use_ref_model=False,
38+
chunk_size=1,
39+
):
40+
"""Chunked forward pass for PPO loss computation.
41+
42+
Args:
43+
cls: The class
44+
ctx: Context for backward
45+
_input: Input tensor
46+
weight: Weight tensor
47+
selected_token_ids: Selected token ids tensor
48+
attention_mask: Attention mask tensor
49+
advantages: Advantages tensor
50+
bias: Bias tensor
51+
ref_per_token_logps: Reference model log probs per token tensor
52+
old_per_token_logps: Old per token log probabilities tensor
53+
ref_input: Reference model input tensor
54+
ref_weight: Reference model weight tensor
55+
ref_bias: Reference model bias tensor
56+
epsilon_low: Lower bound for clipping the importance sampling ratio
57+
epsilon_high: Upper bound for clipping the importance sampling ratio
58+
beta: Weight for the KL penalty
59+
temperature: Temperature for the logits
60+
compiled: Whether to use torch compile
61+
use_ref_model: Whether to use a reference model
62+
chunk_size: Size of chunks for processing in other loss modules
63+
"""
64+
if use_ref_model:
65+
assert ref_per_token_logps is not None or ref_input is not None, (
66+
"If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
67+
)
68+
if ref_per_token_logps is not None and ref_input is not None:
69+
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
70+
# Initialize accumulators
71+
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
72+
grad_weight = torch.zeros_like(weight) # [V, H]
73+
grad_inputs = []
74+
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
75+
aggregated_metrics = []
76+
77+
# Create a partial function with fixed arguments
78+
compute_loss = partial(
79+
LigerFusedLinearPPOBase._compute_chunk_loss,
80+
ref_weight=ref_weight,
81+
ref_bias=ref_bias,
82+
full_attention_mask=attention_mask,
83+
epsilon_low=epsilon_low,
84+
epsilon_high=epsilon_high,
85+
beta=beta,
86+
temperature=temperature,
87+
use_ref_model=use_ref_model,
88+
ppo_loss_fn=cls.ppo_loss_fn,
89+
)
90+
91+
def fused_fwd_bwd(
92+
input_chunk,
93+
selected_token_ids_chunk,
94+
attention_mask_chunk,
95+
advantages_chunk,
96+
ref_per_token_logps_chunk,
97+
old_per_token_logps_chunk,
98+
ref_input_chunk,
99+
):
100+
"""Fused forward and backward for a chunk."""
101+
argnums = (0, 1, 5) if bias is not None else (0, 1)
102+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
103+
input_chunk, # arg 0
104+
weight, # arg 1
105+
selected_token_ids_chunk, # arg 2
106+
attention_mask_chunk, # arg 3
107+
advantages_chunk, # arg 4
108+
bias, # arg 5
109+
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
110+
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
111+
ref_input_chunk=ref_input_chunk, # arg 8
112+
)
113+
114+
def accumulate_chunk(
115+
input_chunk,
116+
selected_token_ids_chunk,
117+
attention_mask_chunk,
118+
advantages_chunk,
119+
ref_per_token_logps_chunk=None,
120+
old_per_token_logps_chunk=None,
121+
ref_input_chunk=None,
122+
):
123+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
124+
input_chunk,
125+
selected_token_ids_chunk,
126+
attention_mask_chunk,
127+
advantages_chunk,
128+
ref_per_token_logps_chunk,
129+
old_per_token_logps_chunk,
130+
ref_input_chunk,
131+
)
132+
if bias is not None:
133+
grad_bias.add_(chunk_grad_bias[0])
134+
135+
# Accumulate gradients and loss
136+
grad_weight.add_(chunk_grad_weight)
137+
grad_inputs.append(chunk_grad_input)
138+
loss_acc.add_(chunk_loss)
139+
# Initialize storage for metrics on first chunk
140+
if len(aggregated_metrics) == 0:
141+
for metric in chunk_metrics:
142+
if metric.ndim == 0:
143+
aggregated_metrics.append(torch.zeros((), device=metric.device))
144+
else:
145+
aggregated_metrics.append([])
146+
147+
# Accumulate metrics
148+
for i, metric in enumerate(chunk_metrics):
149+
if metric.ndim == 0:
150+
aggregated_metrics[i].add_(metric)
151+
else:
152+
aggregated_metrics[i].append(metric)
153+
154+
if compiled:
155+
# TODO: Figure out what is better to compile here
156+
# accumulate_chunk = torch.compile(accumulate_chunk)
157+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
158+
159+
# Process input in chunks based on chunk_size
160+
chunks = max(1, _input.shape[0] // chunk_size)
161+
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
162+
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
163+
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
164+
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
165+
_ref_per_token_logps_chunks = (
166+
torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
167+
if use_ref_model and ref_per_token_logps is not None
168+
else [None] * chunks
169+
)
170+
_old_per_token_logps_chunks = (
171+
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
172+
if old_per_token_logps is not None
173+
else [None] * chunks
174+
)
175+
# if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
176+
_ref_input_chunks = (
177+
torch.chunk(ref_input, chunks=chunks, dim=0)
178+
if use_ref_model and ref_per_token_logps is None
179+
else [None] * chunks
180+
)
181+
182+
for (
183+
input_chunk,
184+
selected_token_ids_chunk,
185+
attention_mask_chunk,
186+
advantages_chunk,
187+
ref_per_token_logps_chunk,
188+
old_per_token_logps_chunk,
189+
ref_input_chunk,
190+
) in zip(
191+
_input_chunks,
192+
_selected_token_ids_chunks,
193+
_attention_mask_chunks,
194+
_advantages_chunks,
195+
_ref_per_token_logps_chunks,
196+
_old_per_token_logps_chunks,
197+
_ref_input_chunks,
198+
):
199+
# Mark dynamic dimensions
200+
torch._dynamo.mark_dynamic(input_chunk, 1)
201+
torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
202+
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
203+
if ref_per_token_logps_chunk is not None:
204+
torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
205+
if ref_input_chunk is not None:
206+
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
207+
if old_per_token_logps_chunk is not None:
208+
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
209+
210+
accumulate_chunk(
211+
input_chunk,
212+
selected_token_ids_chunk,
213+
attention_mask_chunk,
214+
advantages_chunk,
215+
ref_per_token_logps_chunk,
216+
old_per_token_logps_chunk,
217+
ref_input_chunk,
218+
)
219+
220+
# Combine gradients
221+
grad_input = torch.cat(grad_inputs, dim=0)
222+
223+
# Save for backward
224+
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
225+
226+
# Finalize metrics
227+
final_metrics = []
228+
for metric in aggregated_metrics:
229+
if isinstance(metric, list):
230+
final_metrics.append(torch.cat(metric, dim=0))
231+
else:
232+
final_metrics.append(metric)
233+
234+
return loss_acc, tuple(final_metrics)
235+
236+
@staticmethod
237+
def _compute_chunk_loss(
238+
input_chunk,
239+
weight,
240+
selected_token_ids_chunk,
241+
attention_mask_chunk,
242+
advantages_chunk,
243+
bias=None,
244+
ref_per_token_logps_chunk=None,
245+
old_per_token_logps_chunk=None,
246+
ref_input_chunk=None,
247+
ref_weight=None,
248+
ref_bias=None,
249+
full_attention_mask=None,
250+
epsilon_low=0.2,
251+
epsilon_high=0.2,
252+
beta=0.04,
253+
temperature=1.0,
254+
use_ref_model=False,
255+
ppo_loss_fn=None,
256+
):
257+
"""Compute loss for a single chunk."""
258+
# Get policy log probabilities using chunk_forward
259+
log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
260+
261+
# Get reference log probabilities if needed
262+
ref_log_probs = None
263+
if use_ref_model and ref_per_token_logps_chunk is None:
264+
with torch.no_grad():
265+
ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
266+
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
267+
)
268+
269+
# Compute chunk loss and metrics using the provided loss function
270+
chunk_loss, chunk_metrics = ppo_loss_fn(
271+
log_probs=log_probs,
272+
selected_token_ids=selected_token_ids_chunk,
273+
attention_mask=attention_mask_chunk,
274+
advantages=advantages_chunk,
275+
full_attention_mask=full_attention_mask,
276+
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
277+
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
278+
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
279+
epsilon_low=epsilon_low,
280+
epsilon_high=epsilon_high,
281+
beta=beta,
282+
)
283+
284+
return chunk_loss, chunk_metrics
285+
286+
@staticmethod
287+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
288+
"""Forward pass computation for a single chunk without explicit reshaping."""
289+
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
290+
logits = torch.matmul(input_chunk, weight.t())
291+
if bias is not None:
292+
logits = logits + bias # Broadcasts bias to [B, T, V]
293+
if temperature != 1.0:
294+
logits = logits / temperature
295+
296+
# Compute log probabilities using softmax over the last dimension
297+
log_probs = F.log_softmax(logits.float(), dim=-1)
298+
299+
return log_probs, logits
300+
301+
@staticmethod
302+
def backward(ctx, grad_output, *grad_metrics):
303+
"""Backward pass for PPO loss."""
304+
grad_input, grad_weight, grad_bias = ctx.saved_tensors
305+
if grad_output != 1.0:
306+
grad_input = grad_input * grad_output
307+
grad_weight = grad_weight * grad_output
308+
if grad_bias is not None:
309+
grad_bias = grad_bias * grad_output
310+
311+
return (
312+
grad_input,
313+
grad_weight,
314+
None, # grad_selected_token_ids
315+
None, # grad_attention_mask
316+
None, # grad_advantages
317+
grad_bias,
318+
None, # grad_ref_per_token_logps
319+
None, # grad_old_per_token_logps
320+
None, # grad_ref_input
321+
None, # grad_ref_weight
322+
None, # grad_ref_bias
323+
None, # grad_epsilon_low
324+
None, # grad_epsilon_high
325+
None, # grad_beta
326+
None, # grad_temperature
327+
None, # grad_compiled
328+
None, # grad_use_ref_model
329+
None, # grad_chunk_size
330+
)

0 commit comments

Comments
 (0)