Commit aac6651
[megatron] feat: Share actor and ref in LoRA
For `compute_ref_log_prob`, we can share the actor and the reference policy by temporarily disabling the LoRA layers during the forward pass, since the base weights are frozen and only the LoRA layers are trained. This is already supported in FSDP LoRA.

Signed-off-by: Hollow Man <[email protected]>
1 parent 706a807 commit aac6651
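
The idea in one sketch, using Hugging Face PEFT's `PeftModel.disable_adapter()` context manager as a standalone illustration (this is not the Megatron code path touched by this commit, and the function name below is illustrative only): the reference log-probs are simply the actor's log-probs computed with the LoRA layers switched off, so no separate reference model has to be kept in memory.

import torch

def ref_log_prob_from_actor(peft_model, input_ids, attention_mask):
    """Reference-policy log-probs from the actor itself: with the LoRA
    adapters disabled, only the frozen base weights are used, which is
    exactly the reference policy."""
    with torch.no_grad(), peft_model.disable_adapter():
        logits = peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
    # log-probability of each observed next token
    logp = torch.log_softmax(logits[:, :-1], dim=-1)
    return torch.gather(logp, -1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)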

5 files changed: +36 -42 lines

recipe/fully_async_policy/fully_async_trainer.py

Lines changed: 4 additions & 1 deletion
@@ -80,7 +80,10 @@ def __init__(
         self.device_name = device_name if device_name else self.config.trainer.device

         # if ref_in_actor is True, the reference policy will be actor without lora applied
-        self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
+        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+        self.ref_in_actor = lora_rank > 0

         # define in-reward KL control
         # kl loss control currently not suppoorted
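
The same rank-resolution pattern is repeated in each trainer in this commit: the nested `lora.rank` key is checked first, then the legacy flat `lora_rank` key as a fallback. A minimal standalone sketch of that lookup, using a plain dict in place of the trainer's model config (the `resolve_lora_rank` helper is illustrative, not part of the diff):

def resolve_lora_rank(model_cfg: dict) -> int:
    """Return the configured LoRA rank, preferring the nested `lora.rank`
    key and falling back to the legacy flat `lora_rank` key."""
    rank = model_cfg.get("lora", {}).get("rank", 0)
    if rank <= 0:
        rank = model_cfg.get("lora_rank", 0)
    return rank

# ref_in_actor is enabled whenever a positive rank is resolved
print(resolve_lora_rank({"lora": {"rank": 16}}))  # 16
print(resolve_lora_rank({"lora_rank": 8}))        # 8 (legacy key)
print(resolve_lora_rank({}))                      # 0 -> keep a separate ref worker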

recipe/one_step_off_policy/ray_trainer.py

Lines changed: 6 additions & 9 deletions
@@ -37,11 +37,7 @@
 from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo import core_algos
 from verl.trainer.ppo.core_algos import agg_loss
-from verl.trainer.ppo.metric_utils import (
-    compute_data_metrics,
-    compute_throughout_metrics,
-    compute_timing_metrics,
-)
+from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics
 from verl.trainer.ppo.ray_trainer import (
     RayPPOTrainer,
     ResourcePoolManager,
@@ -54,9 +50,7 @@
 from verl.utils import omega_conf_to_dataclass
 from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
 from verl.utils.debug import marked_timer
-from verl.utils.metric import (
-    reduce_metrics,
-)
+from verl.utils.metric import reduce_metrics
 from verl.utils.tracking import ValidationGenerationsLogger


@@ -120,7 +114,10 @@ def __init__(
         self.validation_generations_logger = ValidationGenerationsLogger()

         # if ref_in_actor is True, the reference policy will be actor without lora applied
-        self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
+        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+        self.ref_in_actor = lora_rank > 0

         # define in-reward KL control
         # kl loss control currently not suppoorted

recipe/transfer_queue/ray_trainer.py

Lines changed: 9 additions & 27 deletions
@@ -48,11 +48,7 @@

 from verl import DataProto
 from verl.experimental.dataset.sampler import AbstractCurriculumSampler
-from verl.single_controller.ray import (
-    RayClassWithInitArgs,
-    RayResourcePool,
-    RayWorkerGroup,
-)
+from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
 from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.config import AlgoConfig
 from verl.trainer.ppo import core_algos
@@ -64,33 +60,16 @@
     process_validation_metrics,
 )
 from verl.trainer.ppo.reward import compute_reward, compute_reward_async
-from verl.trainer.ppo.utils import (
-    Role,
-    WorkerType,
-    need_critic,
-    need_reference_policy,
-    need_reward_model,
-)
-from verl.utils.checkpoint.checkpoint_manager import (
-    find_latest_ckpt_path,
-    should_save_ckpt_esi,
-)
+from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
+from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import (
-    calculate_workload,
-    get_seqlen_balanced_partitions,
-    log_seqlen_unbalance,
-)
+from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger
-from verl.utils.transferqueue_utils import (
-    create_transferqueue_client,
-    get_transferqueue_client,
-    tqbridge,
-)
+from verl.utils.transferqueue_utils import create_transferqueue_client, get_transferqueue_client, tqbridge


 @dataclass
@@ -401,7 +380,10 @@ def __init__(
         )

         # if ref_in_actor is True, the reference policy will be actor without lora applied
-        self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
+        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+        self.ref_in_actor = lora_rank > 0

         # define in-reward KL control
         # kl loss control currently not suppoorted

verl/trainer/ppo/ray_trainer.py

Lines changed: 4 additions & 4 deletions
@@ -341,10 +341,10 @@ def __init__(
         )

         # if ref_in_actor is True, the reference policy will be actor without lora applied
-        self.ref_in_actor = (
-            config.actor_rollout_ref.model.get("lora_rank", 0) > 0
-            or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
-        )
+        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+        self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None

         # define in-reward KL control
         # kl loss control currently not suppoorted

verl/workers/megatron_workers.py

Lines changed: 13 additions & 1 deletion
@@ -32,6 +32,8 @@
 except ImportError:
     repatch = None

+from contextlib import nullcontext
+
 from megatron.core import parallel_state as mpu

 from verl import DataProto
@@ -819,6 +821,13 @@ def generate_sequences(self, prompts: DataProto):
     @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
     @DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
     def compute_ref_log_prob(self, data: DataProto):
+        if self.peft_cls is not None:
+            # if is lora, actor without lora applied is the ref
+            data.meta_info["is_lora"] = True
+            data = self.compute_log_prob(data)
+            # this old_log_probs is in fact ref_log_prob
+            data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]})
+            return data
         assert self._is_ref
         if self._ref_is_offload_param:
             load_megatron_model_to_gpu(self.ref_module, load_grad=False)
@@ -845,6 +854,8 @@ def compute_log_prob(self, data: DataProto):
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.actor_module, load_grad=False)
             log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger)
+        is_lora = data.meta_info.pop("is_lora", False)
+        adapter_ctx = self.peft_cls.disable_adapter(self.actor_module) if is_lora else nullcontext()
         # we should always recompute old_log_probs when it is HybridEngine
         data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
         data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
@@ -857,7 +868,8 @@
         if self.enable_routing_replay and self.config.actor.router_replay.mode == "R3":
             RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

-        output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=True)
+        with adapter_ctx:
+            output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
         output = DataProto.from_dict(
             tensors={"old_log_probs": output, "entropys": entropys},
             meta_info={"temperature": self.config.rollout.temperature},
