-
Notifications
You must be signed in to change notification settings - Fork 290
Expand file tree
/
Copy pathagentic_grpo_learner.py
More file actions
567 lines (499 loc) · 19.9 KB
/
agentic_grpo_learner.py
File metadata and controls
567 lines (499 loc) · 19.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements an RLLearner for the Agentic GRPO algorithm.
This learner orchestrates the process of generating multiple text completions
for each prompt from a dataset, computing rewards and advantages according to
the GRPO (Group Relative Policy Optimization) algorithm, and then training
the actor model.
The data flow is designed around an asynchronous producer-consumer pattern:
1. A producer generates rollouts (text generations) in parallel for each prompt.
2. These rollouts are grouped by the original prompt.
3. For each group, rewards and advantages are computed.
4. The resulting training examples are put into a queue.
5. The main training loop consumes these examples to update the model weights.
"""
from __future__ import annotations
import dataclasses
from typing import Any, Dict, List, Sequence, Type, TypeVar
from absl import logging
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl import algo_core # pylint: disable=unused-import
from tunix.perf.experimental import constants as perf_constants
from tunix.rl import common
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl import utils as rl_utils
from tunix.rl.agentic import agentic_rl_learner
from tunix.rl.agentic import utils as agentic_utils
from tunix.rl.agentic.agents import base_agent
from tunix.rl.agentic.agents import model_agent
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.environments import task_environment
from tunix.rl.algo_core import utils as ppo_helpers
from tunix.utils import trajectory_logger
TrainingInputT = agentic_rl_learner.TrainingInputT
RewardFn = agentic_rl_learner.RewardFn
MetricFn = agentic_rl_learner.MetricFn
TrainExample = agentic_rl_learner.TrainExample


@dataclasses.dataclass(slots=True, kw_only=True)
class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
  """Hyperparameters for the agentic GRPO learner.

  Attributes:
    algo_variant: Algorithm variant name.
    advantage_estimator: Name of the registered advantage estimator function.
    policy_loss_fn: Name of the registered policy loss function.
    loss_agg_mode: Method for aggregating the loss. Supported values:
      "token-mean", "sequence-mean-token-mean", "sequence-mean-token-scale",
      "sequence-mean-token-sum-norm".
    num_generations: Number of samples per prompt (G in the paper). Must be
      greater than 1.
    num_iterations: Number of GRPO iterations per batch (μ in the paper).
    beta: KL penalty coefficient.
    kl_loss_mode: Method for computing the KL loss.
    force_compute_kl: Whether to compute the KL divergence for logging even
      when it would otherwise be skipped (e.g. when `beta` is 0.0).
    epsilon: PPO-style clipping epsilon.
    epsilon_high: PPO-style clipping epsilon upper bound. Defaults to
      `epsilon` when left as None.
    loss_algo: Either "grpo" or "gspo-token".
    system_prompt: System prompt for the agent.
    max_concurrency: Maximum number of concurrent rollout engines.
    off_policy_steps: Number of off-policy steps that can be accepted before a
      policy update.
    degenerate_group_masking: Whether to mask out degenerate groups whose
      advantages are all 0.
  """

  algo_variant: str = "agentic_grpo"
  advantage_estimator: str = "agentic_grpo"
  policy_loss_fn: str = "agentic_grpo"
  loss_agg_mode: str = "sequence-mean-token-mean"
  # Either "grpo" or "gspo-token".
  # TODO(sizhi): Remove this option once gspo is refactored to a separate loss
  # fn.
  loss_algo: str = "grpo"
  num_generations: int = 2
  num_iterations: int = 1
  beta: float = 0.04
  kl_loss_mode: str = "kl"
  force_compute_kl: bool = False
  epsilon: float = 0.2
  system_prompt: str = ""
  max_concurrency: int = 16
  epsilon_high: float | None = None  # 0.28 from DAPO.
  off_policy_steps: int = 0
  # Whether to mask out degenerate groups with all-0 advantages.
  degenerate_group_masking: bool = True

  def __post_init__(self):
    """Validates the configuration and fills in derived defaults.

    Raises:
      ValueError: If `num_generations` is not greater than 1, or if
        `loss_algo` is neither "grpo" nor "gspo-token".
    """
    if self.num_generations <= 1:
      raise ValueError(
          "num_generations must be greater than 1. Received: "
          f"{self.num_generations}"
      )
    # An unset upper clipping bound falls back to `epsilon`.
    self.epsilon_high = (
        self.epsilon if self.epsilon_high is None else self.epsilon_high
    )
    supported_loss_algos = ("grpo", "gspo-token")
    if self.loss_algo not in supported_loss_algos:
      raise ValueError(
          "loss_algo should be either grpo or gspo-token. Received: "
          f"{self.loss_algo}"
      )


# Type variable for learners parameterized by a `GRPOConfig` (sub)class.
TGrpoConfig = TypeVar("TGrpoConfig", bound=GRPOConfig)
class GRPOLearner(agentic_rl_learner.AgenticRLLearner[TGrpoConfig]):
  """An RLLearner that implements the GRPO algorithm in an agentic setting.

  GRPO is a reinforcement learning algorithm designed to enhance the reasoning
  abilities of large language models, like mathematical problem-solving. It is
  a variant of Proximal Policy Optimization (PPO) that reduces memory usage by
  eliminating the need for a separate value function model. GRPO works by
  generating multiple responses for a given prompt, evaluating these responses
  using a reward model, and then calculating a relative advantage based on the
  group's performance to update the policy.

  References:
    - https://arxiv.org/abs/2402.03300
  """

  def __init__(
      self,
      rl_cluster: rl_cluster_lib.RLCluster,
      algo_config: TGrpoConfig,
      reward_fns: RewardFn | List[RewardFn] | None = None,
      chat_parser: Any | None = None,
      metric_fns: Sequence[MetricFn] | None = None,
      agent_class: Type[
          base_agent.ConversationAgentBase
      ] = model_agent.ModelAgent,
      agent_kwargs: Dict[str, Any] | None = None,
      env_class: Type[
          base_environment.BaseTaskEnv
      ] = task_environment.TaskEnvironment,
      env_kwargs: Dict[str, Any] | None = None,
  ):
    """Initializes the `GRPOLearner`.

    Args:
      rl_cluster: RL cluster containing actor, reference and reward models.
      algo_config: An instance of `GRPOConfig` containing all GRPO specific
        parameters.
      reward_fns: A single callable or a list of callables that compute a
        scalar reward for given prompts and completions. Each function should
        accept `prompts`, `completions` and optional keyword arguments, and
        return a list of float rewards.
      chat_parser: A parser to handle chat message formatting.
      metric_fns: A sequence of callables that compute metrics for the
        completions. Each callable should accept ``prompts``, ``completions``,
        ``rewards``, ``advantages`` and optional keyword arguments, and return
        a dictionary of metric names to tuples of
        ``(metric_value, aggregation_fn)``:

        >>> def metric_fn(
        ...     prompts, completions, rewards, advantages, **kargs
        ... ):
        ...     return {
        ...         # ...
        ...         "prompt_min_len": (min(len(p) for p in prompts), np.min),
        ...         # ... }

      agent_class: The class of the agent to be used.
      agent_kwargs: Keyword arguments to pass to the agent class.
      env_class: The class of the environment to be used.
      env_kwargs: Keyword arguments to pass to the environment class.
    """  # fmt: skip
    super().__init__(
        rl_cluster=rl_cluster,
        reward_fns=reward_fns,
        metric_fns=metric_fns,
        algo_config=algo_config,
        chat_parser=chat_parser,
        agent_class=agent_class,
        agent_kwargs=agent_kwargs,
        env_class=env_class,
        env_kwargs=env_kwargs,
    )

    # Trajectories are only written out when a metrics log directory is set.
    self._trajectory_logger = None
    metrics_logger_options = (
        self.rl_cluster.cluster_config.training_config.metrics_logging_options
    )
    metrics_log_dir = (
        metrics_logger_options.log_dir if metrics_logger_options else None
    )
    if metrics_log_dir:
      self._trajectory_logger = trajectory_logger.AsyncTrajectoryLogger(
          metrics_log_dir
      )
    else:
      logging.warning("Metrics log dir is None, skipping trajectory logging.")

    # Workaround to pass loss fn with algorithm flag: bind the registered
    # policy loss to this learner's config and the rollout's pad/eos ids.
    policy_loss_fn = function_registry.get_policy_loss_fn(
        self.algo_config.policy_loss_fn
    )
    # The lambda ignores its own `algo_config` argument and always uses
    # `self.algo_config`, which is also what the model-input fn below supplies.
    loss_fn = lambda model, train_example, algo_config: policy_loss_fn(
        model,
        train_example,
        algo_config=self.algo_config,
        pad_id=self.rl_cluster.rollout.pad_id(),
        eos_id=self.rl_cluster.rollout.eos_id(),
    )
    self.rl_cluster.actor_trainer.with_loss_fn(
        loss_fn,
        has_aux=True,
    )
    self.rl_cluster.actor_trainer.with_gen_model_input_fn(
        lambda x: {
            "train_example": x,
            "algo_config": self.algo_config,
        }
    )
    self.rl_cluster.actor_trainer.with_rl_metrics_to_log({
        "kl": np.mean,
        "entropy": np.mean,
        "pg_loss": np.mean,
        "pg_clipfrac": np.mean,
        "ppo_kl": np.mean,
    })
    # Show KL in the progress bar only under the same condition that triggers
    # reference-model inference in `_process_results`.
    self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([
        lambda: "kl"
        if self.algo_config.force_compute_kl
        or self.algo_config.beta != 0.0
        else None,
    ])

  def _process_results(
      self,
      trajectories: List[Any],
      mode: rl_cluster_lib.Mode = rl_cluster_lib.Mode.TRAIN,
      expected_step: int | None = None,
  ) -> List[TrainExample]:
    """Processes generation results, computes rewards and advantages.

    This is a core method that performs several steps:
    1. Extracts completions from the raw trajectory results.
    2. Pads prompt and completion tokens to a consistent length.
    3. Computes masks for prompts and completions.
    4. Gets reference and old model log probabilities if required.
    5. Computes rewards for each completion using the provided reward functions.
    6. Computes GRPO-specific advantages from the rewards.
    7. Buffers metrics for logging.
    8. Constructs and returns a list of `TrainExample` objects.

    Args:
      trajectories: A list of trajectory results for a single GRPO group. Each
        item exposes a `traj` mapping (conversation text/tokens/masks, old
        log-probs, policy version, trajectory reward, ...).
      mode: The current mode (TRAIN or EVAL).
      expected_step: The expected training step.

    Returns:
      A single-element list holding a `TrainExample` with all data needed for
      the loss function.

    Raises:
      ValueError: If `policy_version` is missing from any trajectory.
      RuntimeError: If `num_iterations > 1` but old per-token log-probs are
        not available for every trajectory in the group.
    """
    logging.debug(
        "Processing results to compute advantage for %d items.",
        len(trajectories),
    )
    # With a full group, sorting by pair_index is not necessary as they all
    # originate from the same initial prompt.
    pad_value = self.rl_cluster.rollout.pad_id()
    eos_value = self.rl_cluster.rollout.eos_id()

    # Extract completions and tokens from the group of G results.
    completion_texts: List[str] = []
    completion_tokens_list: List[np.ndarray] = []
    completion_masks_list: List[np.ndarray] = []
    old_logprobs_list: List[np.ndarray] = []
    policy_versions_list: List[int] = []
    trajectory_rewards_list: List[float] = []
    trajectories_to_log = []
    for item in trajectories:
      trajectories_to_log.append(item.traj)
      conversation = item.traj.get("conversation_text") or []
      # NOTE(review): the *first* assistant message is used as the completion
      # text for reward/metric fns; confirm this (rather than the last turn)
      # is intended for multi-turn episodes.
      assistant_text = next(
          (
              message["content"]
              for message in conversation
              if message["role"] == "assistant"
          ),
          "",
      )
      completion_texts.append(assistant_text)
      completion_tokens_list.append(item.traj.get("conversation_tokens"))
      completion_masks_list.append(item.traj.get("conversation_masks"))
      old_logprobs_list.append(item.traj.get("old_logprobs"))
      policy_version = item.traj.get("policy_version")
      if policy_version is None:
        raise ValueError("policy_version is missing from trajectory task.")
      policy_versions_list.append(policy_version)
      trajectory_rewards_list.append(item.traj.get("trajectory_reward"))

    # Log trajectory.
    if self._trajectory_logger and trajectories_to_log:
      for traj in trajectories_to_log:
        self._trajectory_logger.log_item_async(traj)

    # All results in a group share the same prompt.
    prompt_tokens = trajectories[0].traj.get("prompt_tokens")

    # Pad all prompts and completions to consistent lengths.
    rollout_config = self.rl_cluster.cluster_config.rollout_config
    if isinstance(rollout_config, dict):
      # Per-mode rollout configs are looked up by `mode`.
      rollout_config = rollout_config[mode]
    padded_prompt_ids = []
    padded_completion_ids = []
    padded_completion_masks = []
    padded_old_logprobs = []
    max_response_length = self.algo_config.max_response_length
    clipped_completion_count = 0
    for completion_tokens, completion_mask, old_logprobs in zip(
        completion_tokens_list, completion_masks_list, old_logprobs_list
    ):
      # A completion counts as clipped when it hit the response budget
      # without ending on EOS. This must inspect the last *token*: the mask is
      # an indicator (it is summed as a length below), not a token id.
      if (
          len(completion_tokens) >= max_response_length
          and completion_tokens[-1] != eos_value
      ):
        clipped_completion_count += 1
      padded_prompt, padded_completion, _ = (
          agentic_utils.pad_prompt_and_completion(
              prompt_tokens,
              completion_tokens,
              rollout_config.max_prompt_length,
              max_response_length,
              pad_value,
          )
      )
      padded_prompt_ids.append(padded_prompt)
      padded_completion_ids.append(padded_completion[:max_response_length])
      padded_completion_masks.append(
          agentic_utils.right_pad(completion_mask, max_response_length, 0)[
              :max_response_length
          ]
      )
      if old_logprobs is not None:
        padded_old_logprobs.append(
            agentic_utils.right_pad(
                old_logprobs,
                length=max_response_length,
                pad=0.0,
                dtype=old_logprobs.dtype,
            )[:max_response_length]
        )

    prompt_ids = jnp.asarray(padded_prompt_ids)
    prompt_mask = prompt_ids != pad_value
    completion_ids = jnp.asarray(padded_completion_ids)
    completion_mask = jnp.asarray(padded_completion_masks)
    logging.debug(
        "Token shapes: prompt_ids=%s, completion_ids=%s",
        prompt_ids.shape,
        completion_ids.shape,
    )

    # Old log-probs are only used when every trajectory in the group has them.
    if padded_old_logprobs and len(padded_old_logprobs) == len(
        completion_tokens_list
    ):
      old_per_token_logps = jnp.asarray(padded_old_logprobs)
    else:
      old_per_token_logps = None
    if self.algo_config.num_iterations > 1 and old_per_token_logps is None:
      raise RuntimeError(
          "old_per_token_logps is not available for off-policy RL. Enable "
          "`return_logprobs` in RolloutConfig."
      )

    # Collect perf tags
    traj = trajectories[0].traj
    group_id = traj.get("group_id")
    if group_id is None:
      original_input = traj.get("original_input", {})
      group_id = original_input.get("group_id")
    perf_tags = {
        perf_constants.STEP: expected_step,
    }
    if group_id is not None:
      perf_tags[perf_constants.GROUP_ID] = group_id

    # Reference log-probs are needed for the KL term, or when KL is forced
    # for logging.
    if self.algo_config.force_compute_kl or self.algo_config.beta != 0.0:
      with self.rl_cluster.perf_v2.span(
          perf_constants.REFERENCE_INFERENCE,
          devices=self.rl_cluster.r2m[rl_cluster_lib.Role.REFERENCE].devices,
          tags=perf_tags,
      ) as interval_v2:
        ref_per_token_logps = self.rl_cluster.get_ref_per_token_logps(
            prompt_tokens=prompt_ids,
            completion_tokens=completion_ids,
            pad_id=pad_value,
            eos_id=eos_value,
            micro_batch_size=None,
        )
        interval_v2.async_end([ref_per_token_logps])
    else:
      ref_per_token_logps = None

    # Rewards & advantages
    # Prepare arguments for reward computation by forwarding all training inputs
    # except for prompts, which is passed explicitly.
    original_input = trajectories[0].traj["original_input"]
    original_inputs = rl_utils.merge_micro_batches(
        [original_input] * self.algo_config.num_generations
    )
    prompt_token_len = len(prompt_tokens)
    self.rl_cluster.buffer_metrics_async(
        {
            "generation/prompts/mean_length": (prompt_token_len, np.mean),
            "generation/prompts/max_length": (prompt_token_len, np.max),
            "generation/prompts/min_length": (prompt_token_len, np.min),
        },
        mode=mode,
        step=expected_step,
    )
    reward_kwargs = {
        key: value for key, value in original_inputs.items() if key != "prompts"
    }
    reward_kwargs["trajectory_rewards"] = trajectory_rewards_list
    with self.rl_cluster.perf_v2.span(
        perf_constants.ADVANTAGE_COMPUTATION,
        tags=perf_tags,
    ):
      rewards = self._compute_rewards(
          prompts=original_inputs["prompts"],
          completions=completion_texts,
          mode=mode,
          **reward_kwargs,
          expected_step=expected_step,
      )
      advantage_estimator = function_registry.get_advantage_estimator(
          self.algo_config.advantage_estimator
      )
      advantages = advantage_estimator(
          rewards=rewards, num_generations=self.algo_config.num_generations
      )
    logging.debug("Advantages computed: %s", advantages)

    # When every advantage in the group is ~0, zero the group's completion
    # mask so it is masked out of training.
    if self.algo_config.degenerate_group_masking:
      if jnp.all(jnp.isclose(advantages, 0.0)):
        logging.info(
            "Filtering degenerate group %s with all-0 advantages.", group_id
        )
        completion_mask = jnp.zeros_like(completion_mask)
    policy_versions = np.array(policy_versions_list, dtype=np.int32)

    # Log completion lengths, rewards and env time.
    # NOTE(review): lengths come from the mask *after* degenerate-group
    # masking, so filtered groups report length 0; confirm this is intended.
    agg_completion_mask = completion_mask.sum(axis=-1)
    metrics_to_log = {
        "generation/completions/mean_length": (
            np.mean(agg_completion_mask),
            np.mean,
        ),
        "generation/completions/max_length": (
            np.max(agg_completion_mask),
            np.max,
        ),
        "generation/completions/min_length": (
            np.min(agg_completion_mask),
            np.min,
        ),
        "generation/completions/clip_ratio": (
            clipped_completion_count / len(trajectories),
            np.mean,
        ),
    }
    # Extract time metrics (env_time and reward_time)
    for time_key, prefix in [
        ("env_time", "generation/trajectory/env_time"),
        ("reward_time", "generation/trajectory/reward_time"),
    ]:
      # `or {}` also tolerates trajectories that store an explicit None.
      time_dicts = [item.traj.get(time_key) or {} for item in trajectories]
      # Gather every sub-key (e.g. 'reset_latency') seen across the group; a
      # trajectory missing a sub-key contributes 0.0.
      for sub_key in {k for d in time_dicts for k in d.keys()}:
        vals = [d.get(sub_key, 0.0) for d in time_dicts]
        metrics_to_log.update({
            f"{prefix}/{sub_key}/mean": (np.mean(vals), np.mean),
            f"{prefix}/{sub_key}/max": (np.max(vals), np.max),
            f"{prefix}/{sub_key}/min": (np.min(vals), np.min),
        })
    self.rl_cluster.buffer_metrics_async(
        metrics_to_log,
        mode=mode,
        step=expected_step,
    )

    for metric_fn in self.metric_fns:
      user_defined_metric = metric_fn(
          prompts=original_inputs["prompts"],
          completions=completion_texts,
          advantages=advantages,
          rewards=rewards,
          **{
              key: value
              for key, value in original_inputs.items()
              if key != "prompts"
          },
      )
      self.rl_cluster.buffer_metrics_async(
          user_defined_metric, mode=mode, step=expected_step
      )

    combined_batch = TrainExample(
        prompt_ids=prompt_ids,
        prompt_mask=prompt_mask,
        completion_ids=completion_ids,
        completion_mask=completion_mask,
        ref_per_token_logps=ref_per_token_logps,
        advantages=advantages,
        old_per_token_logps=old_per_token_logps,
        policy_version=policy_versions,
    )
    return [combined_batch]


# Backward-compatible aliases.
GrpoConfig = GRPOConfig
GrpoLearner = GRPOLearner