45 changes: 43 additions & 2 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -427,6 +427,7 @@ class AdvantageEstimator(StrEnum):
GRPO = "grpo"
RLOO = "rloo"
REINFORCE_PP = "reinforce++"
MAXRL = "MAXRL"


class AdvantageEstimatorRegistry(BaseFunctionRegistry):
@@ -453,6 +454,7 @@ def repopulate_registry(cls):
            "gae": [AdvantageEstimator.GAE, compute_gae_advantage_return],
            "rloo": [AdvantageEstimator.RLOO, compute_rloo_outcome_advantage],
            "reinforce++": [AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage],
            "MAXRL": [AdvantageEstimator.MAXRL, compute_maxrl_advantage],
        }

        for ae_name, (ae_type, ae_func) in ae_types.items():
@@ -609,7 +611,9 @@ def sapo_policy_loss(
    # The SAPO paper uses sequence_mean reduction; there's no reason
    # why a user couldn't use token_mean reduction, but
    # it's not clear whether it would be stable or not.
    from loguru import logger as logger_  # have to do lazy import to avoid pickling error
    from loguru import (
        logger as logger_,  # have to do lazy import to avoid pickling error
    )

    logger_.warning(f"With SAPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}")

@@ -685,7 +689,9 @@ def gspo_policy_loss(
    # The GSPO paper uses sequence_mean reduction; there's no reason
    # why a user couldn't use token_mean reduction, but
    # it's not clear whether it would be stable or not.
    from loguru import logger as logger_  # have to do lazy import to avoid pickling error
    from loguru import (
        logger as logger_,  # have to do lazy import to avoid pickling error
    )

    logger_.warning(f"With GSPO it's recommended to use 'sequence_mean' loss reduction; got {loss_reduction}")

@@ -1182,6 +1188,41 @@ def compute_grpo_outcome_advantage(
    return scores, scores


@register_advantage_estimator(AdvantageEstimator.MAXRL)
def compute_maxrl_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for MAXRL using mean-normalized group-relative rewards."""
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            if len(id2score[index[i]]) > 1:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2mean[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
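For readers comparing this estimator to GRPO: both center each sequence-level reward on its prompt-group mean, but MAXRL divides the centered score by the group mean (plus epsilon) instead of the group standard deviation, i.e. A_i = (R_i - mean_g) / (mean_g + epsilon) for groups with more than one sample. The snippet below is an illustrative, self-contained re-derivation of that arithmetic on a toy group; it is not part of this diff and does not touch the repository's API.

# Illustrative sketch only (not from this PR): contrast MAXRL's mean-normalization
# with GRPO-style std-normalization on one prompt group of two sampled responses.
import torch

epsilon = 1e-6
rewards = torch.tensor([6.0, 3.0])  # sequence-level rewards for a single prompt group

group_mean = rewards.mean()  # 4.5
group_std = rewards.std()    # ~2.1213 (torch.std is unbiased by default)

maxrl_adv = (rewards - group_mean) / (group_mean + epsilon)  # tensor([ 0.3333, -0.3333])
grpo_adv = (rewards - group_mean) / (group_std + epsilon)    # tensor([ 0.7071, -0.7071])

print(maxrl_adv, grpo_adv)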


def repopulate_all_registries():
    PolicyLossRegistry.repopulate_registry()
    AdvantageEstimatorRegistry.repopulate_registry()
27 changes: 27 additions & 0 deletions skyrl-train/tests/cpu/utils/test_ppo_utils.py
@@ -11,6 +11,7 @@
    compute_approx_kl,
    compute_gae_advantage_return,
    compute_grpo_outcome_advantage,
    compute_maxrl_advantage,
    compute_advantages_and_returns,
    AdaptiveKLController,
    FixedKLController,
@@ -172,6 +173,32 @@ def test_compute_grpo_outcome_advantage_norm_std_false():
    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_maxrl_advantage():
    # Two groups: [6.0, 3.0] mean=4.5, [9.0, 12.0] mean=10.5
    token_level_rewards = torch.tensor(
        [
            [1.0, 2.0, 3.0],  # sum = 6.0, group 0
            [1.0, 1.0, 1.0],  # sum = 3.0, group 0
            [3.0, 3.0, 3.0],  # sum = 9.0, group 1
            [4.0, 4.0, 4.0],  # sum = 12.0, group 1
        ]
    )
    response_mask = torch.ones_like(token_level_rewards)
    index = np.array([0, 0, 1, 1])

    adv, ret = compute_maxrl_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )

    expected = torch.tensor([1.5 / 4.5, -1.5 / 4.5, -1.5 / 10.5, 1.5 / 10.5]).unsqueeze(-1) * response_mask

    assert adv.shape == token_level_rewards.shape
    assert torch.allclose(adv, ret), "Advantages and returns should be equal with MAXRL"
    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"

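The guarded branch in the skyrl-train copy of compute_maxrl_advantage gives a single-sample group a mean of 0.0 and skips the division, so a lone response keeps its raw score rather than being scaled by 1/epsilon. The test below is a hypothetical follow-up sketch of that edge case, not part of this diff; it reuses the imports and call signature of test_compute_maxrl_advantage above.

def test_compute_maxrl_advantage_single_sample_group():
    # Hypothetical sketch (not in the original PR): a prompt group with a single
    # response falls back to plain mean-subtraction with mean 0.0, i.e. it keeps
    # the raw score rather than being divided by epsilon.
    token_level_rewards = torch.tensor([[2.0, 2.0, 2.0]])  # sum = 6.0, lone member of group 0
    response_mask = torch.ones_like(token_level_rewards)
    index = np.array([0])

    adv, ret = compute_maxrl_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )

    expected = torch.tensor([[6.0, 6.0, 6.0]])
    assert torch.allclose(adv, ret)
    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"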

def test_compute_gae_advantage_return(advantage_test_data):
    rewards, values, response_mask, index = advantage_test_data

34 changes: 34 additions & 0 deletions skyrl/skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -427,6 +427,7 @@ class AdvantageEstimator(StrEnum):
GRPO = "grpo"
RLOO = "rloo"
REINFORCE_PP = "reinforce++"
MAXRL = "MAXRL"


class AdvantageEstimatorRegistry(BaseFunctionRegistry):
@@ -453,6 +454,7 @@ def repopulate_registry(cls):
            "gae": [AdvantageEstimator.GAE, compute_gae_advantage_return],
            "rloo": [AdvantageEstimator.RLOO, compute_rloo_outcome_advantage],
            "reinforce++": [AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage],
            "MAXRL": [AdvantageEstimator.MAXRL, compute_maxrl_advantage],
        }

        for ae_name, (ae_type, ae_func) in ae_types.items():
@@ -1182,6 +1184,38 @@ def compute_grpo_outcome_advantage(
    return scores, scores


@register_advantage_estimator(AdvantageEstimator.MAXRL)
def compute_maxrl_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for MAXRL using mean-normalized group-relative rewards."""
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = (scores[i] - id2mean[index[i]]) / (id2mean[index[i]] + epsilon)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


def repopulate_all_registries():
    PolicyLossRegistry.repopulate_registry()
    AdvantageEstimatorRegistry.repopulate_registry()
27 changes: 27 additions & 0 deletions skyrl/tests/backends/skyrl_train/utils/test_ppo_utils.py
@@ -11,6 +11,7 @@
    compute_approx_kl,
    compute_gae_advantage_return,
    compute_grpo_outcome_advantage,
    compute_maxrl_advantage,
    compute_advantages_and_returns,
    AdaptiveKLController,
    FixedKLController,
@@ -172,6 +173,32 @@ def test_compute_grpo_outcome_advantage_norm_std_false():
    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_maxrl_advantage():
    # Two groups: [6.0, 3.0] mean=4.5, [9.0, 12.0] mean=10.5
    token_level_rewards = torch.tensor(
        [
            [1.0, 2.0, 3.0],  # sum = 6.0, group 0
            [1.0, 1.0, 1.0],  # sum = 3.0, group 0
            [3.0, 3.0, 3.0],  # sum = 9.0, group 1
            [4.0, 4.0, 4.0],  # sum = 12.0, group 1
        ]
    )
    response_mask = torch.ones_like(token_level_rewards)
    index = np.array([0, 0, 1, 1])

    adv, ret = compute_maxrl_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )

    expected = torch.tensor([1.5 / 4.5, -1.5 / 4.5, -1.5 / 10.5, 1.5 / 10.5]).unsqueeze(-1) * response_mask

    assert adv.shape == token_level_rewards.shape
    assert torch.allclose(adv, ret), "Advantages and returns should be equal with MAXRL"
    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"


def test_compute_gae_advantage_return(advantage_test_data):
    rewards, values, response_mask, index = advantage_test_data
