[Refactor] fused kernel in forward #1624

Merged
2 changes: 1 addition & 1 deletion recipe/prime/config/prime_trainer.yaml
@@ -32,7 +32,7 @@ reward_model:
   model:
     ref_path: ${reward_model.model.path}
     use_remove_padding: True
-    use_fused_kernels: False
+    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
     tokenizer_path: ${actor_rollout_ref.model.path}
     enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
     ref_type: freeze
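
Note: the reward model previously pinned `use_fused_kernels: False`; it now inherits the actor's setting through OmegaConf interpolation, so the PRIME reward worker and the actor can never disagree about whether the forward pass returns log-probs or logits. A minimal sketch of how such an interpolation resolves (hypothetical values, assuming standard OmegaConf semantics):

```python
from omegaconf import OmegaConf

# Hypothetical config mirroring the structure above: the reward model's flag
# is an interpolation pointing at the actor's flag.
cfg = OmegaConf.create(
    """
    actor_rollout_ref:
      model:
        use_fused_kernels: true
    reward_model:
      model:
        use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
    """
)

# Interpolations resolve lazily, so overriding the actor flag (e.g. from the
# command line) automatically flips the reward model flag as well.
assert cfg.reward_model.model.use_fused_kernels is True
cfg.actor_rollout_ref.model.use_fused_kernels = False
assert cfg.reward_model.model.use_fused_kernels is False
```
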
43 changes: 5 additions & 38 deletions recipe/prime/prime_dp_rm.py
@@ -47,11 +47,6 @@ def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, rewa
 
         self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
 
-        if self.use_fused_kernels:
-            from verl.utils.experimental.torch_functional import FusedLinearForPPO
-
-            self.fused_linear_for_ppo = FusedLinearForPPO()
-
     def _forward_micro_batch(self, micro_batch, prompt_length):
         input_ids = micro_batch["input_ids"]
         batch_size, seqlen = input_ids.shape
@@ -85,14 +80,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
             )
 
             if self.use_fused_kernels:
-                hidden_states = output.last_hidden_state
-                vocab_weights = self.reward_module.lm_head.weight
-
-                rm_log_labels, _ = self.fused_linear_for_ppo(
-                    hidden_states=hidden_states.squeeze(0),
-                    vocab_weights=vocab_weights,
-                    input_ids=input_ids_rmpad_rolled,
-                )
+                rm_log_labels = output.log_probs.squeeze(0)  # (total_nnz,)
                 rm_log_labels = rm_log_labels.to(torch.float32)
 
             else:
@@ -115,14 +103,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
             )
 
             if self.use_fused_kernels:
-                hidden_states = output.last_hidden_state
-                vocab_weights = self.reward_module.lm_head.weight
-
-                rm_log_labels, _ = self.fused_linear_for_ppo.forward(
-                    hidden_states=hidden_states[:, :-1, :],
-                    vocab_weights=vocab_weights,
-                    input_ids=micro_batch["input_ids"][:, 1:],
-                )
+                rm_log_labels = output.log_probs[:, :-1]  # (bsz, seq_length)
                 rm_log_labels = rm_log_labels.to(torch.float32)
 
             else:
@@ -142,18 +123,11 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
             )
 
             if self.use_fused_kernels:
-                hidden_states = ref_output.last_hidden_state
-                vocab_weights = self.ref_module.lm_head.weight
-
-                ref_log_labels, _ = self.fused_linear_for_ppo(
-                    hidden_states=hidden_states.squeeze(0),
-                    vocab_weights=vocab_weights,
-                    input_ids=input_ids_rmpad_rolled,
-                )
+                ref_log_labels = ref_output.log_probs.squeeze(0)  # (total_nnz,)
                 ref_log_labels = ref_log_labels.to(torch.float32)
 
             else:
-                logits = ref_output.logits.squeeze(0)
+                ref_output_logits = ref_output.logits.squeeze(0)
                 ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, labels=input_ids_rmpad_rolled)
 
             ref_log_labels = gather_outpus_and_unpad(ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
@@ -167,14 +141,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
             )
 
             if self.use_fused_kernels:
-                hidden_states = ref_output.last_hidden_state
-                vocab_weights = self.ref_module.lm_head.weight
-
-                ref_log_labels, _ = self.fused_linear_for_ppo.forward(
-                    hidden_states=hidden_states[:, :-1, :],
-                    vocab_weights=vocab_weights,
-                    input_ids=micro_batch["input_ids"][:, 1:],
-                )
+                ref_log_labels = ref_output.log_probs[:, :-1]  # (batch_size, seq_length)
                 ref_log_labels = ref_log_labels.to(torch.float32)
 
             else:
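
With this refactor the worker no longer instantiates `FusedLinearForPPO` or reaches into `lm_head.weight`; when `use_fused_kernels` is set, the monkey-patched model forward already returns log-probabilities and the worker only reshapes them. A self-contained sketch of the shape contract the two branches above rely on (toy tensors, not the real model output):

```python
import torch

class FusedForwardOutput:
    """Toy stand-in for the patched forward's return value when
    use_fused_kernels is enabled: log_probs instead of logits."""
    def __init__(self, log_probs: torch.Tensor):
        self.log_probs = log_probs

# rmpad path: all tokens in the micro-batch are packed into (1, total_nnz),
# so the dummy batch dimension is squeezed away.
total_nnz = 10
out = FusedForwardOutput(torch.randn(1, total_nnz))
rm_log_labels = out.log_probs.squeeze(0).to(torch.float32)
assert rm_log_labels.shape == (total_nnz,)

# padded path: output is (bsz, seqlen); the last position has no next token
# to score, so it is dropped to align with next-token labels.
bsz, seqlen = 2, 8
out = FusedForwardOutput(torch.randn(bsz, seqlen))
rm_log_labels = out.log_probs[:, :-1].to(torch.float32)
assert rm_log_labels.shape == (bsz, seqlen - 1)
```
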
1 change: 1 addition & 0 deletions tests/e2e/run_dapo.sh
@@ -66,6 +66,7 @@ python3 -m recipe.dapo.main_dapo \
     actor_rollout_ref.model.path="${MODEL_PATH}" \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
     actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
2 changes: 1 addition & 1 deletion tests/e2e/run_prime.sh
@@ -34,7 +34,7 @@ python3 -m recipe.prime.main_prime \
     actor_rollout_ref.model.path="${MODEL_PATH}" \
     actor_rollout_ref.actor.optim.lr=5e-7 \
     actor_rollout_ref.model.use_remove_padding=True \
-    actor_rollout_ref.model.use_fused_kernels=False \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     actor_rollout_ref.model.enable_gradient_checkpointing=False \
1 change: 1 addition & 0 deletions tests/e2e/run_ray_trainer.sh
@@ -17,6 +17,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     data.return_raw_input_ids=True \
     actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
     actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=128 \
     actor_rollout_ref.actor.entropy_coeff=0 \
     actor_rollout_ref.actor.optim.lr=1e-4 \
3 changes: 2 additions & 1 deletion tests/e2e/run_ray_trainer_rmpad.sh
@@ -8,6 +8,7 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     algorithm.adv_estimator=gae \
     data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
     data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
     actor_rollout_ref.rollout.name=vllm \
@@ -16,4 +17,4 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     critic.model.path=Qwen/Qwen2.5-0.5B \
     critic.model.use_remove_padding=True \
     algorithm.use_kl_in_reward=False \
-    trainer.total_epochs=1
\ No newline at end of file
+    trainer.total_epochs=1
3 changes: 2 additions & 1 deletion tests/e2e/run_sppo.sh
@@ -24,6 +24,7 @@ python3 -m recipe.sppo.main_sppo \
     actor_rollout_ref.model.path="./models/Qwen2.5-0.5B-Instruct" \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
@@ -42,4 +43,4 @@ python3 -m recipe.sppo.main_sppo \
     trainer.n_gpus_per_node=8 \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
-    trainer.total_epochs=2 $@
\ No newline at end of file
+    trainer.total_epochs=2 $@
37 changes: 28 additions & 9 deletions verl/models/transformers/llama.py
@@ -233,11 +233,12 @@ def llama_attn_forward(
 
 
 @dataclass
-class CausalLMOutputWithoutLogits(CausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class CausalLMOutputForPPO(CausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -251,14 +252,17 @@ def forward_without_logits(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithoutLogits]:
+) -> Union[Tuple, CausalLMOutputForPPO]:
     r"""
     Copy paste LLaMa's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py
 
     This function should be generic enough for all pure text models.
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -281,13 +285,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return CausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return CausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
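
The label handling uses `torch.roll(..., shifts=-1, dims=-1)` rather than slicing: position `t` of the rolled tensor holds token `t + 1`, so `hidden_states[t]` is scored against the token it actually predicts while keeping the original sequence length. Only the final position is meaningless (it wraps around to the first token), which is why callers such as the PRIME worker above drop or mask it. A small self-contained check of the equivalence:

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 14]])

# Rolling left by one places each position's next token at that position;
# the final slot wraps around to the sequence start and must be ignored.
rolled = torch.roll(input_ids, shifts=-1, dims=-1)
assert rolled.tolist() == [[12, 13, 14, 11]]

# Identical to the conventional shift, but without changing the length:
assert torch.equal(rolled[:, :-1], input_ids[:, 1:])
```
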
13 changes: 7 additions & 6 deletions verl/models/transformers/monkey_patch.py
@@ -25,6 +25,7 @@
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_utils import PreTrainedModel
 
+from verl.models.transformers.llama import forward_for_ppo
 from verl.utils.ulysses import (
     gather_heads_scatter_seq,
     gather_seq_scatter_heads,
@@ -134,9 +135,9 @@ def apply_monkey_patch(
         print("Monkey patch FlashAttention2.forward in Qwen2.5VL")
 
         if use_fused_kernels:
-            from verl.models.transformers.qwen2_5_vl import forward_without_logits
+            from verl.models.transformers.qwen2_5_vl import forward_for_ppo
 
-            Qwen2_5_VLForConditionalGeneration.forward = forward_without_logits
+            Qwen2_5_VLForConditionalGeneration.forward = forward_for_ppo
 
         return
 
@@ -153,9 +154,9 @@ def apply_monkey_patch(
         print("Monkey patch FlashAttention2.forward in Qwen2VL")
 
         if use_fused_kernels:
-            from verl.models.transformers.qwen2_vl import forward_without_logits
+            from verl.models.transformers.qwen2_vl import forward_for_ppo
 
-            Qwen2VLForConditionalGeneration.forward = forward_without_logits
+            Qwen2VLForConditionalGeneration.forward = forward_for_ppo
 
         return
 
@@ -172,9 +173,9 @@ def apply_monkey_patch(
         print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
 
         if use_fused_kernels:
-            from verl.models.transformers.llama import forward_without_logits
+            from verl.models.transformers.llama import forward_for_ppo
 
-            model.__class__.forward = forward_without_logits
+            model.__class__.forward = forward_for_ppo
 
 
 @lru_cache
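
The patch works because `nn.Module.__call__` looks `forward` up through the class, so assigning to `model.__class__.forward` (or to the concrete `*ForConditionalGeneration` class) reroutes every instance in the process. A toy illustration of the mechanism, not the real HF classes:

```python
import torch
from torch import nn

class TinyLM(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # stand-in for "return logits"

def forward_for_ppo(self, x: torch.Tensor) -> torch.Tensor:
    return x - 1  # stand-in for "return log_probs/entropy instead"

model = TinyLM()
assert model(torch.zeros(1)).item() == 1.0

# Same effect as `model.__class__.forward = forward_for_ppo` above:
TinyLM.forward = forward_for_ppo
assert model(torch.zeros(1)).item() == -1.0
```
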
37 changes: 28 additions & 9 deletions verl/models/transformers/qwen2_5_vl.py
@@ -23,11 +23,12 @@
 
 
 @dataclass
-class Qwen2_5_VLCausalLMOutputWithoutLogits(Qwen2_5_VLCausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self: Qwen2_5_VLForConditionalGeneration,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -46,12 +47,15 @@ def forward_without_logits(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithoutLogits]:
+) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
     r"""
     Copy paste Qwen2_5_VL's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -137,13 +141,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return Qwen2_5_VLCausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return Qwen2_5_VLCausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
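
For reference, this is roughly what `FusedLinearForPPO` is expected to compute, written out unfused (an assumption inferred from how its outputs are consumed in this PR, not the actual kernel): log-probs of the rolled labels plus per-position entropy, with temperature applied to the logits. The fused version processes the sequence in chunks so the full `(..., vocab_size)` logits tensor never has to be materialized.

```python
import torch
import torch.nn.functional as F

def linear_for_ppo_reference(hidden_states: torch.Tensor,
                             vocab_weights: torch.Tensor,
                             input_ids: torch.Tensor,
                             temperature: float = 1.0):
    # Project to vocabulary and apply temperature, as sampling did.
    logits = hidden_states @ vocab_weights.T / temperature  # (..., vocab)
    logp = F.log_softmax(logits, dim=-1)
    # Log-prob of each (already rolled) label token...
    log_probs = torch.gather(logp, -1, input_ids.unsqueeze(-1)).squeeze(-1)
    # ...and the entropy of each position's full distribution.
    entropy = -(logp.exp() * logp).sum(dim=-1)
    return log_probs, entropy

hidden = torch.randn(2, 5, 16)          # (bsz, seqlen, hidden)
weights = torch.randn(100, 16)          # (vocab_size, hidden) = lm_head.weight
labels = torch.randint(0, 100, (2, 5))  # rolled next-token ids
log_probs, entropy = linear_for_ppo_reference(hidden, weights, labels)
assert log_probs.shape == (2, 5) and entropy.shape == (2, 5)
```
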
37 changes: 28 additions & 9 deletions verl/models/transformers/qwen2_vl.py
@@ -293,11 +293,12 @@ def ulysses_flash_attn_forward(
 
 
 @dataclass
-class Qwen2VLCausalLMOutputWithoutLogits(Qwen2VLCausalLMOutputWithPast):
-    last_hidden_state: Optional[torch.FloatTensor] = None
+class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):
+    log_probs: Optional[torch.FloatTensor] = None
+    entropy: Optional[torch.FloatTensor] = None
 
 
-def forward_without_logits(
+def forward_for_ppo(
     self: Qwen2VLForConditionalGeneration,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -315,12 +316,15 @@ def forward_without_logits(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    temperature: float = 1.0,
     **loss_kwargs,
-) -> Union[Tuple, Qwen2VLCausalLMOutputWithoutLogits]:
+) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
     r"""
     Copy paste Qwen2VL's forward
     https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
     ```"""
+    from verl.utils.experimental.torch_functional import FusedLinearForPPO
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -399,13 +403,28 @@ def forward_without_logits(
 
     hidden_states = outputs[0]
 
-    if labels is not None:
-        raise NotImplementedError("forward_without_logits does not support labels")
     if not return_dict:
-        raise NotImplementedError("forward_without_logits has to return_dict")
+        raise NotImplementedError("forward_for_ppo has to return_dict")
+
+    # Loss calculations
+    if labels is not None:
+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+    elif input_ids is not None:
+        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
+    else:
+        raise RuntimeError("To use forward_for_ppo, either labels or input_ids must be provided.")
+
+    fused_linear_for_ppo = FusedLinearForPPO()
+    log_probs, entropy = fused_linear_for_ppo.forward(
+        hidden_states=hidden_states,
+        vocab_weights=self.lm_head.weight,
+        input_ids=rolled_labels,
+        temperature=temperature,
+    )
 
-    return Qwen2VLCausalLMOutputWithoutLogits(
-        last_hidden_state=hidden_states,
+    return Qwen2VLCausalLMOutputForPPO(
+        log_probs=log_probs,
+        entropy=entropy,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
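
The payoff of returning `log_probs`/`entropy` instead of `last_hidden_state` (and then projecting to logits outside the model) is memory. Rough, illustrative arithmetic, not a measurement:

```python
# Back-of-the-envelope memory comparison (illustrative numbers only).
# Materializing full logits costs bsz * seqlen * vocab_size elements, while
# the fused path hands back just log_probs and entropy: bsz * seqlen each.
bsz, seqlen, vocab = 4, 4096, 152064   # ~152k vocab, Qwen2-scale (assumed)
bytes_bf16 = 2

logits_gib = bsz * seqlen * vocab * bytes_bf16 / 2**30
fused_gib = 2 * bsz * seqlen * bytes_bf16 / 2**30

print(f"full logits: {logits_gib:.2f} GiB per micro-batch")  # ~4.64 GiB
print(f"fused outputs: {fused_gib * 1024:.3f} MiB")          # ~0.06 MiB
```
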
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -188,6 +188,7 @@ reward_model:
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
     external_lib: ${actor_rollout_ref.model.external_lib}
     use_remove_padding: False
+    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
     trust_remote_code: False
     fsdp_config:
       wrap_policy: