Commit 4779f26
[Refactor] fused kernel in forward (#1624)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

Shifts `fused_linear_for_ppo` into `model.forward` for FSDP.

### High-Level Design

Self-explanatory.

### Specific Changes

- Update the monkey patch to return `log_probs` and `entropy` instead of `last_hidden_state`.

### API

No changes.

### Usage Example

```sh
actor_rollout_ref.model.use_fused_kernels=True
```

### Test

![image](https://github.com/user-attachments/assets/c6af68fb-0200-4aee-9596-0b445afdc562)

### Additional Info

- This fixes #1565.
- The original bug arose because we tried to access `model.lm_head.weight` from outside the FSDP-wrapped context.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.
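For readers skimming the diff, here is a minimal sketch of the before/after pattern this PR implements. It assumes a Hugging Face-style causal LM; the helper name `forward_for_ppo_sketch` and the surrounding setup are illustrative only, while `FusedLinearForPPO` and its keyword arguments follow the diffs below.

```python
import torch

# Before (simplified): callers pulled last_hidden_state out of forward() and then
# reached into model.lm_head.weight from *outside* the FSDP-wrapped module, which
# fails because FSDP may have flattened/sharded that parameter (see #1565):
#
#   hidden = model(...).last_hidden_state
#   log_probs, _ = fused_linear_for_ppo(hidden, model.lm_head.weight, rolled_ids)
#
# After (simplified): the monkey-patched forward computes log_probs/entropy itself,
# so lm_head.weight is only touched inside the wrapped forward pass.
def forward_for_ppo_sketch(self, input_ids, attention_mask, temperature=1.0):
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
    rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)  # next-token targets
    log_probs, entropy = FusedLinearForPPO()(
        hidden_states=hidden_states,
        vocab_weights=self.lm_head.weight,  # accessed inside the FSDP-wrapped module
        input_ids=rolled_labels,
        temperature=temperature,
    )
    return log_probs, entropy
```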
1 parent 0286210 commit 4779f26

File tree

13 files changed: +111 −104 lines

recipe/prime/config/prime_trainer.yaml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ reward_model:
   model:
     ref_path: ${reward_model.model.path}
     use_remove_padding: True
-    use_fused_kernels: False
+    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
     tokenizer_path: ${actor_rollout_ref.model.path}
     enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
     ref_type: freeze

recipe/prime/prime_dp_rm.py

Lines changed: 5 additions & 38 deletions
@@ -47,11 +47,6 @@ def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, rewa
 
         self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
 
-        if self.use_fused_kernels:
-            from verl.utils.experimental.torch_functional import FusedLinearForPPO
-
-            self.fused_linear_for_ppo = FusedLinearForPPO()
-
     def _forward_micro_batch(self, micro_batch, prompt_length):
         input_ids = micro_batch["input_ids"]
         batch_size, seqlen = input_ids.shape
@@ -85,14 +80,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
                 )
 
                 if self.use_fused_kernels:
-                    hidden_states = output.last_hidden_state
-                    vocab_weights = self.reward_module.lm_head.weight
-
-                    rm_log_labels, _ = self.fused_linear_for_ppo(
-                        hidden_states=hidden_states.squeeze(0),
-                        vocab_weights=vocab_weights,
-                        input_ids=input_ids_rmpad_rolled,
-                    )
+                    rm_log_labels = output.log_probs.squeeze(0)  # (total_nnz,)
                     rm_log_labels = rm_log_labels.to(torch.float32)
 
                 else:
@@ -115,14 +103,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
                 )
 
                 if self.use_fused_kernels:
-                    hidden_states = output.last_hidden_state
-                    vocab_weights = self.reward_module.lm_head.weight
-
-                    rm_log_labels, _ = self.fused_linear_for_ppo.forward(
-                        hidden_states=hidden_states[:, :-1, :],
-                        vocab_weights=vocab_weights,
-                        input_ids=micro_batch["input_ids"][:, 1:],
-                    )
+                    rm_log_labels = output.log_probs[:, :-1]  # (bsz, seq_length)
                     rm_log_labels = rm_log_labels.to(torch.float32)
 
                 else:
@@ -142,18 +123,11 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
                 )
 
                 if self.use_fused_kernels:
-                    hidden_states = ref_output.last_hidden_state
-                    vocab_weights = self.ref_module.lm_head.weight
-
-                    ref_log_labels, _ = self.fused_linear_for_ppo(
-                        hidden_states=hidden_states.squeeze(0),
-                        vocab_weights=vocab_weights,
-                        input_ids=input_ids_rmpad_rolled,
-                    )
+                    ref_log_labels = ref_output.log_probs.squeeze(0)  # (total_nnz,)
                     ref_log_labels = ref_log_labels.to(torch.float32)
 
                 else:
-                    logits = ref_output.logits.squeeze(0)
+                    ref_output_logits = ref_output.logits.squeeze(0)
                     ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, labels=input_ids_rmpad_rolled)
 
                 ref_log_labels = gather_outpus_and_unpad(ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
@@ -167,14 +141,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
                 )
 
                 if self.use_fused_kernels:
-                    hidden_states = ref_output.last_hidden_state
-                    vocab_weights = self.ref_module.lm_head.weight
-
-                    ref_log_labels, _ = self.fused_linear_for_ppo.forward(
-                        hidden_states=hidden_states[:, :-1, :],
-                        vocab_weights=vocab_weights,
-                        input_ids=micro_batch["input_ids"][:, 1:],
-                    )
+                    ref_log_labels = ref_output.log_probs[:, :-1]  # (batch_size, seq_length)
                     ref_log_labels = ref_log_labels.to(torch.float32)
 
                 else:
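Net effect on callers such as this reward worker: with `use_fused_kernels=True`, the patched forward already returns per-token `log_probs`, so the worker no longer recomputes them from `last_hidden_state` and `lm_head.weight`. A hypothetical helper illustrating both paths (names and shapes are illustrative, not part of the PR):

```python
import torch

def log_probs_from_output(output, input_ids: torch.Tensor, use_fused_kernels: bool) -> torch.Tensor:
    """Illustrative only: per-token log-probs of the next token, for patched vs. unpatched models."""
    if use_fused_kernels:
        # Patched forward already aligned log_probs with the rolled labels; drop the wrap-around slot.
        return output.log_probs[:, :-1].to(torch.float32)  # (bsz, seq_len - 1)
    # Unpatched path: recover log-probs from the full logits tensor (memory-heavy for large vocabularies).
    logits = output.logits[:, :-1, :].float()
    labels = input_ids[:, 1:]
    return torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
```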

tests/e2e/run_dapo.sh

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ python3 -m recipe.dapo.main_dapo \
     actor_rollout_ref.model.path="${MODEL_PATH}" \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
     actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \

tests/e2e/run_prime.sh

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ python3 -m recipe.prime.main_prime \
     actor_rollout_ref.model.path="${MODEL_PATH}" \
     actor_rollout_ref.actor.optim.lr=5e-7 \
     actor_rollout_ref.model.use_remove_padding=True \
-    actor_rollout_ref.model.use_fused_kernels=False \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     actor_rollout_ref.model.enable_gradient_checkpointing=False \

tests/e2e/run_ray_trainer.sh

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     data.return_raw_input_ids=True \
     actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
     actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=128 \
     actor_rollout_ref.actor.entropy_coeff=0 \
     actor_rollout_ref.actor.optim.lr=1e-4 \

tests/e2e/run_ray_trainer_rmpad.sh

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     algorithm.adv_estimator=gae \
     data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
     data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
     actor_rollout_ref.rollout.name=vllm \
@@ -16,4 +17,4 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     critic.model.path=Qwen/Qwen2.5-0.5B \
     critic.model.use_remove_padding=True \
     algorithm.use_kl_in_reward=False \
-    trainer.total_epochs=1
+    trainer.total_epochs=1

tests/e2e/run_sppo.sh

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ python3 -m recipe.sppo.main_sppo \
     actor_rollout_ref.model.path="./models/Qwen2.5-0.5B-Instruct" \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
@@ -42,4 +43,4 @@ python3 -m recipe.sppo.main_sppo \
     trainer.n_gpus_per_node=8 \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
-    trainer.total_epochs=2 $@
+    trainer.total_epochs=2 $@

verl/models/transformers/llama.py

Lines changed: 28 additions & 9 deletions
@@ -233,11 +233,12 @@ def llama_attn_forward(
 
 
 @dataclass
-class CausalLMOutputWithoutLogits(CausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class CausalLMOutputForPPO(CausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -251,14 +252,17 @@ def forward_without_logits(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithoutLogits]:
+) -> Union[Tuple, CausalLMOutputForPPO]:
     r"""
     Copy paste LLaMa's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py
 
     This function should be generic enough for all pure text models.
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -281,13 +285,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return CausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return CausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
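The three model patches in this commit share the same core: roll the labels by one position, then hand the hidden states and `lm_head.weight` to `FusedLinearForPPO`, which returns per-token log-probs and entropy without keeping a full `(batch, seq, vocab)` logits tensor alive. For intuition, here is an unfused reference of what that computation amounts to (a sketch only, assuming temperature scales the logits before the softmax; the real kernel chunks the work to save memory):

```python
import torch
import torch.nn.functional as F

def reference_linear_for_ppo(hidden_states, vocab_weights, input_ids, temperature=1.0):
    """Unfused reference: log-probs of `input_ids` (already rolled) and token-level entropy."""
    logits = hidden_states @ vocab_weights.T / temperature  # the big tensor the fused kernel avoids materializing at once
    logp = F.log_softmax(logits.float(), dim=-1)
    log_probs = logp.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
    entropy = -(logp.exp() * logp).sum(dim=-1)
    return log_probs, entropy
```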

verl/models/transformers/monkey_patch.py

Lines changed: 7 additions & 6 deletions
@@ -25,6 +25,7 @@
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_utils import PreTrainedModel
 
+from verl.models.transformers.llama import forward_for_ppo
 from verl.utils.ulysses import (
     gather_heads_scatter_seq,
     gather_seq_scatter_heads,
@@ -134,9 +135,9 @@ def apply_monkey_patch(
         print("Monkey patch FlashAttention2.forward in Qwen2.5VL")
 
         if use_fused_kernels:
-            from verl.models.transformers.qwen2_5_vl import forward_without_logits
+            from verl.models.transformers.qwen2_5_vl import forward_for_ppo
 
-            Qwen2_5_VLForConditionalGeneration.forward = forward_without_logits
+            Qwen2_5_VLForConditionalGeneration.forward = forward_for_ppo
 
         return
 
@@ -153,9 +154,9 @@ def apply_monkey_patch(
         print("Monkey patch FlashAttention2.forward in Qwen2VL")
 
         if use_fused_kernels:
-            from verl.models.transformers.qwen2_vl import forward_without_logits
+            from verl.models.transformers.qwen2_vl import forward_for_ppo
 
-            Qwen2VLForConditionalGeneration.forward = forward_without_logits
+            Qwen2VLForConditionalGeneration.forward = forward_for_ppo
 
         return
 
@@ -172,9 +173,9 @@ def apply_monkey_patch(
         print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
 
     if use_fused_kernels:
-        from verl.models.transformers.llama import forward_without_logits
+        from verl.models.transformers.llama import forward_for_ppo
 
-        model.__class__.forward = forward_without_logits
+        model.__class__.forward = forward_for_ppo
 
 
 @lru_cache
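Putting the patch together, a hypothetical end-to-end check (the exact `apply_monkey_patch` signature may differ from this sketch; in verl the patched forward is normally invoked from inside the FSDP-wrapped worker):

```python
import torch
from transformers import AutoModelForCausalLM
from verl.models.transformers.monkey_patch import apply_monkey_patch

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=torch.bfloat16)
apply_monkey_patch(model, use_fused_kernels=True)  # model.__class__.forward is now forward_for_ppo

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), temperature=1.0)
print(out.log_probs.shape, out.entropy.shape)  # (2, 16) each; no (2, 16, vocab_size) logits returned
```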

verl/models/transformers/qwen2_5_vl.py

Lines changed: 28 additions & 9 deletions
@@ -23,11 +23,12 @@
 
 
 @dataclass
-class Qwen2_5_VLCausalLMOutputWithoutLogits(Qwen2_5_VLCausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self: Qwen2_5_VLForConditionalGeneration,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -46,12 +47,15 @@ def forward_without_logits(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithoutLogits]:
+) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
     r"""
     Copy paste Qwen2_5_VL's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -137,13 +141,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return Qwen2_5_VLCausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return Qwen2_5_VLCausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,

verl/models/transformers/qwen2_vl.py

Lines changed: 28 additions & 9 deletions
@@ -293,11 +293,12 @@ def ulysses_flash_attn_forward(
 
 
 @dataclass
-class Qwen2VLCausalLMOutputWithoutLogits(Qwen2VLCausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self: Qwen2VLForConditionalGeneration,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -315,12 +316,15 @@ def forward_without_logits(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, Qwen2VLCausalLMOutputWithoutLogits]:
+) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
     r"""
     Copy paste Qwen2VL's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -399,13 +403,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return Qwen2VLCausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return Qwen2VLCausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,

verl/trainer/config/ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ reward_model:
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
     external_lib: ${actor_rollout_ref.model.external_lib}
     use_remove_padding: False
+    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
     trust_remote_code: False
     fsdp_config:
       wrap_policy:
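The `${actor_rollout_ref.model.use_fused_kernels}` entries added to `ppo_trainer.yaml` and `prime_trainer.yaml` are OmegaConf interpolations, so the reward and reference model configs now follow the actor's flag instead of carrying an independent default. A small illustrative check (the config fragment is made up for the example):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    actor_rollout_ref:
      model:
        use_fused_kernels: true
    reward_model:
      model:
        use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
    """
)
assert cfg.reward_model.model.use_fused_kernels is True  # resolves to the actor's setting
```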
