diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..8da09efac2d 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -185,6 +185,30 @@ class AsyncGRPOConfig(_BaseConfig): metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + # Parameters that control LoRA training and weight sync + use_lora: bool = field( + default=False, + metadata={ + "help": "Enable LoRA mode. When True, the model is loaded as a PEFT adapter (base model auto-resolved " + "from adapter_config.json), only LoRA weights are trained, and weight sync saves the adapter to disk " + "then tells vLLM to hot-reload via /v1/load_lora_adapter instead of streaming all weights over NCCL." + }, + ) + lora_adapter_path: str | None = field( + default=None, + metadata={ + "help": "Path to the PEFT LoRA adapter directory. Required when use_lora=True. This is where the " + "adapter is saved during weight sync and where vLLM reads it from." + }, + ) + lora_name: str = field( + default="sft", + metadata={ + "help": "The LoRA adapter name registered in vLLM (via --lora-modules name=path). Used both as the " + "'model' field in generation requests and as the adapter name in /v1/load_lora_adapter calls." + }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -201,6 +225,9 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + if self.use_lora and not self.lora_adapter_path: + raise ValueError("lora_adapter_path is required when use_lora=True") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index a81dad5639f..9a344fba457 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -14,6 +14,7 @@ import math +import os import queue import textwrap import time @@ -60,6 +61,7 @@ def stop(self) -> None: ... def pause(self) -> None: ... def resume(self) -> None: ... def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -> None: ... + def reload_lora(self, adapter_path: str, lora_name: str) -> None: ... def update_model_version(self, version: int) -> None: ... @@ -270,7 +272,7 @@ class AsyncGRPOTrainer(_BaseTrainer): def __init__( self, model: str, - reward_funcs: RewardFunc | list[RewardFunc], + reward_funcs: RewardFunc | list[RewardFunc] | None = None, args: AsyncGRPOConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, processing_class: PreTrainedTokenizerBase | None = None, @@ -291,6 +293,19 @@ def __init__( model_name = model model = AutoModelForCausalLM.from_pretrained(model, device_map=None, dtype=torch.float32) + if self.args.use_lora: + lora_count = 0 + for name, param in model.named_parameters(): + param.requires_grad = "lora_" in name + if param.requires_grad: + lora_count += 1 + if lora_count == 0: + raise ValueError( + "use_lora=True but no LoRA parameters found in model. " + "Ensure the model path contains adapter_config.json and adapter weights." + ) + logger.info(f"Enabled gradients on {lora_count} LoRA parameter tensors") + if self.args.use_liger_kernel: raise NotImplementedError("`use_liger_kernel` is not supported yet.") @@ -303,7 +318,9 @@ def __init__( processing_class.pad_token = processing_class.eos_token # Reward functions - if not isinstance(reward_funcs, list): + if rollout_worker is None and reward_funcs is None: + raise ValueError("reward_funcs is required when no custom rollout_worker is provided") + if reward_funcs is not None and not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] # Initialize the Trainer @@ -354,15 +371,15 @@ def __init__( # Use the injected worker (e.g. a stub in tests). The queue is owned by the worker. self.rollout_worker = rollout_worker else: - # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training. - # DTensor.shape returns the global shape without triggering any all-gather. + # NCCL weight transfer needs full metadata; LoRA mode skips this entirely. weight_names, weight_dtype_names, weight_shapes = [], [], [] - for name, param in model.named_parameters(): - # DDP/FSDP1 wrapping, avoids vllm module not exist error - name = name.removeprefix("module.") - weight_names.append(name) - weight_dtype_names.append(str(param.dtype).split(".")[-1]) - weight_shapes.append(list(param.shape)) + if not self.args.use_lora: + for name, param in model.named_parameters(): + name = name.removeprefix("module.") # DDP/FSDP1 wrapping + weight_names.append(name) + weight_dtype_names.append(str(param.dtype).split(".")[-1]) + weight_shapes.append(list(param.shape)) + self.rollout_worker = AsyncRolloutWorker( model_name=model_name, dataset=train_dataset, @@ -384,6 +401,8 @@ def __init__( weight_names=weight_names, weight_dtype_names=weight_dtype_names, weight_shapes=weight_shapes, + use_lora=self.args.use_lora, + lora_name=self.args.lora_name, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -579,6 +598,48 @@ def _streaming_iter(self): def _sync_weight(self): t0 = time.time() + + if self.args.use_lora: + self._sync_weight_lora(t0) + else: + self._sync_weight_nccl(t0) + + weight_sync_time_s = time.time() - t0 + self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) + logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") + + def _sync_weight_lora(self, t0: float): + """LoRA sync: save adapter to disk, then tell vLLM to hot-reload it.""" + adapter_path = self.args.lora_adapter_path + lora_name = self.args.lora_name + + # Pause vLLM FIRST so no requests trigger lazy LoRA loading mid-write + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.pause() + + self.accelerator.wait_for_everyone() + + # All ranks must call save_pretrained so that FSDP2 DTensor full_tensor() collectives + # (which are all-gathers) don't deadlock. Only rank 0 actually writes files to disk. + unwrapped = self.accelerator.unwrap_model(self.model) + if self.accelerator.is_main_process: + logger.info(f"Weight sync (LoRA): saving adapter to {adapter_path}...") + unwrapped.save_pretrained(adapter_path, is_main_process=self.accelerator.is_main_process) + if self.accelerator.is_main_process: + os.sync() + t_save = time.time() + logger.info(f"Weight sync (LoRA): save took {t_save - t0:.1f}s") + + self.accelerator.wait_for_everyone() + + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.reload_lora(adapter_path, lora_name) + self.rollout_worker.resume() + self.model_version += 1 + self.rollout_worker.update_model_version(self.model_version) + + def _sync_weight_nccl(self, t0: float): + """Original NCCL path: stream all weights to vLLM.""" logger.info("Weight sync: pausing vLLM...") if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.pause() @@ -604,9 +665,6 @@ def _sync_weight(self): self.rollout_worker.resume() self.model_version += 1 self.rollout_worker.update_model_version(self.model_version) - weight_sync_time_s = time.time() - t0 - self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") def _inner_training_loop(self, *args, **kwargs): # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train() diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 4fd11312fd2..e49cd29adbb 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -107,12 +107,14 @@ def __init__( weight_names: list[str] | None = None, weight_dtype_names: list[str] | None = None, weight_shapes: list[list[int]] | None = None, + use_lora: bool = False, + lora_name: str | None = None, ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'" ) - self.model_name = model_name + self.lora_sync = use_lora self.max_tool_calling_iterations = max_tool_calling_iterations self.dataset = dataset self._dataset_iter = iter(dataset) @@ -127,6 +129,10 @@ def __init__( "is_checkpoint_format": True, } + # When LoRA sync is active, generation requests use the LoRA adapter name + # (e.g. "sft") while the tokenizer still loads from model_name (adapter dir). + self.model_name = lora_name if self.lora_sync else model_name + self.reward_funcs = reward_funcs self.reward_func_names = [f.__name__ for f in reward_funcs] self.num_generations = num_generations @@ -165,7 +171,7 @@ def __init__( self.chat_template_kwargs = chat_template_kwargs or {} self.log_completions = log_completions self.num_completions_to_print = num_completions_to_print - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) # Always use original path for tokenizer self.tokenizer = add_response_schema(self.tokenizer) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -181,9 +187,12 @@ def __init__( self.model_version = 0 self.session = None - # Wait for the vLLM server and initialize NCCL weight transfer. self._wait_for_server_ready_sync(timeout_s=self.server_timeout) - self._init_weight_transfer() + if self.lora_sync: + logger.info("LoRA sync mode: skipping NCCL weight transfer init (will use save-to-disk + HTTP reload)") + self.model_update_group = None + else: + self._init_weight_transfer() def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: float = 2.0) -> None: """Block until the vLLM server is healthy.""" @@ -296,6 +305,18 @@ def resume(self) -> None: requests.post(f"{self.vllm_server_url}/resume") logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") + def reload_lora(self, adapter_path: str, lora_name: str) -> None: + """Tell vLLM to hot-reload a LoRA adapter from disk.""" + t0 = time.time() + payload = { + "lora_name": lora_name, + "lora_path": adapter_path, + "load_inplace": True, + } + resp = requests.post(f"{self.vllm_server_url}/v1/load_lora_adapter", json=payload, timeout=120) + resp.raise_for_status() + logger.info(f"[weight_sync] LoRA reload ({lora_name} from {adapter_path}) took {time.time() - t0:.1f}s") + def send_weights(self, iterator) -> None: if self.model_update_group is None: return