
Commit f870f3c

samsja and claude authored
Feature: bring your own algorithms (#1715)
* update loss interface

* refactor: rename grpo_loss to prime_rl_loss and consolidate loss interface

  - Rename grpo_loss to prime_rl_loss (uses LossConfig directly)
  - Move LossInputs, LossOutputs, LossFn from loss_interface.py into loss.py
  - Remove loss_interface.py

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* feat: add bring-your-own-loss support for custom loss functions

  - Add CustomLossConfig with path and kwargs fields
  - Add LossConfigType union (LossConfig | CustomLossConfig)
  - Update setup_loss_fn to handle custom loss imports
  - Add _import_object helper for dynamic imports
  - Add test for custom loss configuration
  - Add docs/bring-your-own-loss.md documentation

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* feat: add bring-your-own advantage function support

  - Add AdvantageInputs/AdvantageOutputs dataclasses
  - Add CustomAdvantageConfig with path and kwargs
  - Add setup_advantage_fn for custom advantage imports
  - Refactor compute_advantages to use the new interface
  - Add tests for custom advantage configuration
  - Rename docs to bring-your-own-algorithms.md with both loss and advantage sections

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* refactor: rename path to byo_function in custom configs

  More descriptive name that ties into the "bring your own" concept.

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* refactor: address PR review feedback

  - Extract import_object to shared utils (deduplicate from loss.py and advantage.py)
  - Add pydantic discriminator types for LossConfigType and AdvantageConfigType
  - Rename byo_function to import_path (standard Python terminology)
  - Rename grpo_advantage to default_advantage
  - Fix train.py AttributeError when using CustomLossConfig
  - Fix docs: per-example terminology, loss/advantage descriptions, config examples

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: update prime_rl_loss description

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: handle custom loss metrics in training loop and restore extra=forbid on AdvantageConfig

  - Guard mismatch_kl access in micro-step and step logging (custom loss may not emit it)
  - Change AdvantageConfig base from BaseModel to BaseConfig to restore extra="forbid" validation

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: rename prime_rl_loss to default_loss_fn

  Consistent with default_advantage naming.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: simplify loss setup logging

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: rename default_advantage to default_advantage_fn

  Consistent with default_loss_fn naming.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent bd08727 commit f870f3c

File tree

9 files changed: +567 −128 lines

docs/bring-your-own-algorithms.md

Lines changed: 140 additions & 0 deletions
# Bring Your Own Algorithms

Prime-RL supports custom implementations for key algorithmic components, allowing you to experiment with different RL objectives and techniques.

## 1. Custom Loss Functions

The loss is computed **per-sequence** (per-sample). You provide a function that computes the loss for a single sequence, and the framework handles iteration and aggregation.

### Interface

```python
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs

def my_custom_loss(inputs: LossInputs, **kwargs) -> LossOutputs:
    ...
```

#### LossInputs

```python
@dataclass
class LossInputs:
    trainer_logprobs: Float[Tensor, "seq"]  # Log probs from current policy
    inference_logprobs: Float[Tensor, "seq"]  # Log probs from the inference policy
    teacher_logprobs: Float[Tensor, "seq"] | None  # Optional teacher log probs
    advantages: Float[Tensor, "seq"]  # Per-token advantages
    loss_mask: Bool[Tensor, "seq"]  # Mask for valid tokens
```

#### LossOutputs

```python
@dataclass
class LossOutputs:
    loss: Float[Tensor, ""]  # Scalar loss for this sequence
    metrics: dict[str, Tensor]  # Metrics to log
```

### Example: PPO Clipped Loss

```python
import torch
from prime_rl.trainer.rl.loss import LossInputs, LossOutputs

def ppo_clip_loss(inputs: LossInputs, clip_eps: float = 0.2) -> LossOutputs:
    ratio = torch.exp(inputs.trainer_logprobs - inputs.inference_logprobs)
    clipped_ratio = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)

    surr1 = ratio * inputs.advantages
    surr2 = clipped_ratio * inputs.advantages

    loss = -torch.min(surr1, surr2)[inputs.loss_mask].sum()

    return LossOutputs(
        loss=loss,
        metrics={"clip_frac": (ratio != clipped_ratio)[inputs.loss_mask].float().mean()},
    )
```

### Configuration

```toml
[loss]
type = "custom"
import_path = "my_module.ppo_clip_loss"
kwargs = { clip_eps = 0.2 }
```
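The clipped objective above can be sanity-checked with plain scalar arithmetic. A minimal per-token sketch in pure Python (illustrative names, not the framework API):

```python
import math

def ppo_clip_term(trainer_lp: float, inference_lp: float, adv: float, clip_eps: float = 0.2) -> float:
    # Per-token mirror of the torch example: importance ratio, clipped
    # ratio, and the negated min of the two surrogate terms.
    ratio = math.exp(trainer_lp - inference_lp)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return -min(ratio * adv, clipped * adv)
```

When trainer and inference log probs agree, the ratio is 1 and the term reduces to `-adv`; a large positive log-prob gap with positive advantage is capped at `-(1 + clip_eps) * adv`.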
---
## 2. Custom Advantage Functions

Advantages are computed **per-example** (grouped by `rollouts_per_example`). You provide a function that computes advantages for a batch of examples.

### Interface

```python
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def my_custom_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
    ...
```

#### AdvantageInputs

```python
@dataclass
class AdvantageInputs:
    rewards: Float[Tensor, "num_examples rollouts_per_example"]
    completion_lengths: Int[Tensor, "num_examples rollouts_per_example"]
```

#### AdvantageOutputs

```python
@dataclass
class AdvantageOutputs:
    advantages: Float[Tensor, "num_examples rollouts_per_example"]
```

### Example: Normalized Advantage

```python
import torch
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def normalized_advantage(inputs: AdvantageInputs, eps: float = 1e-8) -> AdvantageOutputs:
    """Normalize advantages to zero mean and unit variance per example."""
    mean = inputs.rewards.mean(dim=1, keepdim=True)
    std = inputs.rewards.std(dim=1, keepdim=True)
    advantages = (inputs.rewards - mean) / (std + eps)
    return AdvantageOutputs(advantages=advantages)
```

### Configuration

```toml
[advantage]
type = "custom"
import_path = "my_module.normalized_advantage"
kwargs = { eps = 1e-8 }
```
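To check the example's math: each example's group of rollouts should come out with (near-)zero mean. A pure-Python mirror for a single group (illustrative only; note `torch.std` is unbiased by default, hence the n - 1 divisor):

```python
def normalized_advantage_group(rewards: list[float], eps: float = 1e-8) -> list[float]:
    # One example's rollout group: subtract the mean, divide by std + eps.
    n = len(rewards)
    mean = sum(rewards) / n
    var = sum((r - mean) ** 2 for r in rewards) / (n - 1)  # unbiased, like torch.std
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```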
---
## Default Implementations

If no custom function is specified:

- **Loss**: Uses `default_loss_fn` (masked importance sampling with KL against the inference policy, and optional masking strategies)
- **Advantage**: Uses `default_advantage_fn` (reward minus per-example baseline, a.k.a. DR-GRPO without std normalization)

See `LossConfig` and `AdvantageConfig` for available parameters.

## Tips

- Your functions receive structured inputs via dataclasses with jaxtyping annotations
- Return metrics as scalars or 1D tensors; they'll be aggregated automatically
- Use the `loss_mask` / tensor shapes to handle variable-length sequences
- Test your custom functions with the provided test patterns before training
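As a sketch of how the `kwargs` table reaches your function (hypothetical helper names; the real wiring lives in the framework's setup functions): kwargs are bound once at setup, and the loop then calls the bound function with inputs only.

```python
def bind_custom_fn(custom_fn, kwargs: dict):
    # Bind config kwargs once at setup time; callers then pass only the
    # structured inputs, mirroring kwargs = { clip_eps = 0.2 } being
    # fixed for the whole run.
    def bound(inputs):
        return custom_fn(inputs, **kwargs)
    return bound

def scaled_identity(inputs, scale: float = 1.0):
    # Toy stand-in for a custom loss/advantage function.
    return [x * scale for x in inputs]

fn = bind_custom_fn(scaled_identity, {"scale": 2.0})
```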
Lines changed: 72 additions & 10 deletions

```diff
@@ -1,13 +1,72 @@
+from dataclasses import dataclass
+from typing import Callable
+
 import torch
+from jaxtyping import Float, Int
+from torch import Tensor
+
+from prime_rl.orchestrator.config import AdvantageConfigType, CustomAdvantageConfig
+from prime_rl.utils.utils import import_object
+
+
+@dataclass
+class AdvantageInputs:
+    """Inputs for advantage computation."""
+
+    rewards: Float[Tensor, "num_problems rollouts_per_example"]
+    completion_lengths: Int[Tensor, "num_problems rollouts_per_example"]
+
+
+@dataclass
+class AdvantageOutputs:
+    """Outputs from advantage computation."""
+
+    advantages: Float[Tensor, "num_problems rollouts_per_example"]
+
+
+AdvantageFn = Callable[..., AdvantageOutputs]
+"""Type for an advantage function.
+
+Expected signature:
+    def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
+        ...
+"""
 
-from prime_rl.orchestrator.config import AdvantageConfig
+def default_advantage_fn(inputs: AdvantageInputs, length_weighted_mean: bool = False) -> AdvantageOutputs:
+    """Default GRPO advantage: reward minus per-problem baseline."""
+    if length_weighted_mean:
+        baseline = (inputs.rewards * inputs.completion_lengths).sum(
+            dim=1, keepdim=True
+        ) / inputs.completion_lengths.sum(dim=1, keepdim=True)
+    else:
+        baseline = inputs.rewards.mean(dim=1, keepdim=True)
+
+    return AdvantageOutputs(advantages=inputs.rewards - baseline)
+
+
+def setup_advantage_fn(config: AdvantageConfigType) -> AdvantageFn:
+    """Setup advantage function from config."""
+    if isinstance(config, CustomAdvantageConfig):
+        custom_fn = import_object(config.import_path)
+        kwargs = config.kwargs
+
+        def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
+            return custom_fn(inputs, **kwargs)
+
+        return advantage_fn
+
+    def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:
+        return default_advantage_fn(inputs, length_weighted_mean=config.length_weighted_mean)
+
+    return advantage_fn
 
 
 def compute_advantages(
     rewards: list[float],
     completion_lengths: list[int],
     samples_per_problem: int,
-    advantage_config: AdvantageConfig | None,
+    advantage_config: AdvantageConfigType | None,
 ) -> list[float]:
     """
     Computes advantages from a flattened list of rewards, grouped by problem.
@@ -16,14 +75,17 @@ def compute_advantages(
         rewards: Flattened list of rewards where first `samples_per_problem` rewards are for the first problem
         completion_lengths: List of completion lengths for each reward
         samples_per_problem: Number of samples (and thus, rewards) per problem
-        advantage_config: Configuration for advantage computation
+        advantage_config: Configuration for advantage computation (AdvantageConfig or CustomAdvantageConfig)
     """
     if not advantage_config:
         return rewards
-    rewards = torch.tensor(rewards).view(-1, samples_per_problem)
-    lengths = torch.tensor(completion_lengths).view(-1, samples_per_problem)
-    if advantage_config.length_weighted_mean:
-        baseline = (rewards * lengths).sum(dim=1, keepdim=True) / lengths.sum(dim=1, keepdim=True)
-    else:
-        baseline = rewards.mean(dim=1, keepdim=True)
-    return (rewards - baseline).flatten().tolist()
+
+    advantage_fn = setup_advantage_fn(advantage_config)
+
+    inputs = AdvantageInputs(
+        rewards=torch.tensor(rewards).view(-1, samples_per_problem),
+        completion_lengths=torch.tensor(completion_lengths).view(-1, samples_per_problem),
+    )
+
+    result = advantage_fn(inputs)
+    return result.advantages.flatten().tolist()
```
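The semantics of `default_advantage_fn` can be illustrated for a single problem's rollout group in plain Python (a sketch mirroring the torch code above, not the shipped implementation):

```python
def default_advantage_group(rewards, lengths=None, length_weighted_mean=False):
    # Baseline is the group mean; optionally weighted by completion length.
    if length_weighted_mean:
        baseline = sum(r * l for r, l in zip(rewards, lengths)) / sum(lengths)
    else:
        baseline = sum(rewards) / len(rewards)
    return [r - baseline for r in rewards]
```

With length weighting on, longer completions pull the baseline toward their own reward.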

src/prime_rl/orchestrator/config.py

Lines changed: 29 additions & 2 deletions

```diff
@@ -1,7 +1,7 @@
 from pathlib import Path
 from typing import Annotated, Any, Literal, TypeAlias
 
-from pydantic import AliasChoices, BaseModel, Field, model_validator
+from pydantic import AliasChoices, BaseModel, Discriminator, Field, Tag, model_validator
 
 from prime_rl.transport.config import FileSystemTransportConfig, TransportConfigType
 from prime_rl.utils.config import (
@@ -612,9 +612,36 @@ def validate_skip_verification(self):
 
 
 class AdvantageConfig(BaseConfig):
+    """Config for the default advantage."""
+
+    type: Literal["default"] = "default"
     length_weighted_mean: bool = False
 
 
+class CustomAdvantageConfig(BaseModel):
+    """Config for a custom external advantage function."""
+
+    type: Literal["custom"] = "custom"
+    import_path: Annotated[
+        str, Field(description="Import path to the advantage function (e.g., 'my_module.my_advantage')")
+    ]
+    kwargs: Annotated[
+        dict[str, Any], Field(default_factory=dict, description="Kwargs to pass to the advantage function")
+    ]
+
+
+def _advantage_config_discriminator(v: Any) -> str:
+    if isinstance(v, dict):
+        return v.get("type", "default")
+    return getattr(v, "type", "default")
+
+
+AdvantageConfigType: TypeAlias = Annotated[
+    Annotated[AdvantageConfig, Tag("default")] | Annotated[CustomAdvantageConfig, Tag("custom")],
+    Discriminator(_advantage_config_discriminator),
+]
+
+
 class FileSystemWeightBroadcastConfig(BaseModel):
     """Configures the filesystem weight broadcast."""
 
@@ -683,7 +710,7 @@ class OrchestratorConfig(BaseSettings):
     buffer: BufferConfig = BufferConfig()
 
     # The advantage configuration
-    advantage: AdvantageConfig | None = AdvantageConfig()
+    advantage: AdvantageConfigType | None = AdvantageConfig()
 
     # The logging configuration
     log: LogConfig = LogConfig()
```
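The discriminator's fallback is what keeps existing configs (with no `type` key) valid. A standalone mirror of the function in the diff above:

```python
from typing import Any

def config_discriminator(v: Any) -> str:
    # Raw dicts (parsed TOML/JSON) without an explicit "type" key resolve
    # to the default branch; model instances are inspected via attribute.
    if isinstance(v, dict):
        return v.get("type", "default")
    return getattr(v, "type", "default")
```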

src/prime_rl/trainer/rl/config.py

Lines changed: 25 additions & 4 deletions

```diff
@@ -1,7 +1,7 @@
 from pathlib import Path
-from typing import Annotated, Literal, TypeAlias
+from typing import Annotated, Any, Literal, TypeAlias
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Discriminator, Field, Tag, model_validator
 
 from prime_rl.trainer.config import (
     AdamWConfig,
@@ -19,8 +19,9 @@
 
 
 class LossConfig(BaseConfig):
-    """Base config for loss."""
+    """Config for the default loss."""
 
+    type: Literal["default"] = "default"
     ratio_type: Annotated[Literal["token", "sequence"], Field(description="Type of importance ratio to use.")] = "token"
 
     token_mask_high: Annotated[
@@ -72,6 +73,26 @@ def validate_mask_bounds(self):
         return self
 
 
+class CustomLossConfig(BaseModel):
+    """Config for a custom external loss function."""
+
+    type: Literal["custom"] = "custom"
+    import_path: Annotated[str, Field(description="Import path to the loss function (e.g., 'my_module.my_loss')")]
+    kwargs: Annotated[dict[str, Any], Field(default_factory=dict, description="Kwargs to pass to the loss function")]
+
+
+def _loss_config_discriminator(v: Any) -> str:
+    if isinstance(v, dict):
+        return v.get("type", "default")
+    return getattr(v, "type", "default")
+
+
+LossConfigType: TypeAlias = Annotated[
+    Annotated[LossConfig, Tag("default")] | Annotated[CustomLossConfig, Tag("custom")],
+    Discriminator(_loss_config_discriminator),
+]
+
+
 class FakeDataLoaderConfig(BaseConfig):
     """Configures a fake data loader sampling random micro batches for debugging."""
 
@@ -130,7 +151,7 @@ class RLTrainerConfig(BaseSettings):
     data: DataLoaderConfig = DataLoaderConfig()
 
     # The loss configuration
-    loss: LossConfig = LossConfig()
+    loss: LossConfigType = LossConfig()
 
     # The optimizer configuration
     optim: Annotated[OptimizerConfigType, Field(discriminator="type")] = AdamWConfig()
```
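The `import_object` helper (pulled into `prime_rl.utils.utils` by this commit) is not shown in this view. A plausible minimal implementation of such a dotted-path importer, stated as an assumption rather than the actual code:

```python
import importlib

def import_object(import_path: str):
    # Split "pkg.mod.attr" into module path and attribute name,
    # import the module, and return the attribute.
    module_path, _, attr_name = import_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
```

Usage matches the config examples: `import_path = "my_module.ppo_clip_loss"` resolves the module `my_module` and fetches `ppo_clip_loss` from it.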
