
Commit f721db2

start wiring up dense rewarding with implicit prm
1 parent f3e20cf commit f721db2

3 files changed (+46, -29 lines)


palm_rlhf_pytorch/implicit_process_reward.py (+3, -9)
@@ -1,3 +1,6 @@
+# Free Process Rewards without Process Labels
+# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime
+
 from __future__ import annotations
 from copy import deepcopy

@@ -18,9 +21,6 @@ def get_logprob_at(logits, seq):
     log_prob = log_probs.gather(-1, seq)
     return rearrange(log_prob, '... 1 -> ...')

-# Free Process Rewards without Process Labels
-# Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime
-
 class ImplicitPRM(Module):
     """ PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """

@@ -51,12 +51,6 @@ def forward(
         seq,
         labels = None
     ):
-        """
-        b - batch
-        n - sequence
-        l - logit dimension (num tokens)
-        """
-
         source_seq, target_seq = seq[:, :-1], seq[:, 1:]

         mask = target_seq >= 0 # assume any token ids < 0 to be padding
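
For reference, the Yuan et al. paper cited in the new header comment derives dense per-token rewards from a model trained only on outcome labels: the implicit process reward at each generated token is the beta-scaled log-probability ratio between the trained model and a frozen reference. The sketch below only illustrates that quantity; it is not the repo's ImplicitPRM.forward, and the function name, beta default, and tensor shapes are illustrative assumptions.

import torch

def implicit_process_rewards(model_logits, ref_logits, target_seq, beta = 0.1):
    # model_logits, ref_logits: (batch, seq, num_tokens) logits over the same tokens,
    # from the implicitly trained PRM and a frozen reference model
    # target_seq: (batch, seq) token ids actually generated
    # returns (batch, seq) dense rewards: beta * (log pi_theta - log pi_ref) per token

    log_probs = model_logits.log_softmax(dim = -1)
    ref_log_probs = ref_logits.log_softmax(dim = -1)

    index = target_seq.unsqueeze(-1)
    logprob = log_probs.gather(-1, index).squeeze(-1)
    ref_logprob = ref_log_probs.gather(-1, index).squeeze(-1)

    return beta * (logprob - ref_logprob)

# toy shapes, purely illustrative
rewards = implicit_process_rewards(
    torch.randn(1, 8, 256),
    torch.randn(1, 8, 256),
    torch.randint(0, 256, (1, 8))
)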

palm_rlhf_pytorch/ppo.py (+42, -19)
@@ -27,8 +27,8 @@

 from palm_rlhf_pytorch.palm import PaLM
 from palm_rlhf_pytorch.reward import RewardModel
+from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
 from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
-
 from accelerate import Accelerator

 # actor critic - PaLM with lora

@@ -47,7 +47,7 @@ class ActorCritic(Module):
     def __init__(
         self,
         palm: PaLM,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,

@@ -61,13 +61,26 @@ def __init__(
         super().__init__()
         self.actor_palm = palm

-        self.critic_palm = critic_palm
+        # detect implicit prm and auto-set some hyperparameters
+
+        critic_is_prm = isinstance(critic, ImplicitPRM)
+
+        critic_lora &= not critic_is_prm
+        pooled_values |= critic_is_prm
+
+        self.critic_is_prm = critic_is_prm
+
+        # critic
+
+        self.critic = critic

-        if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
+        if not exists(self.critic):
+            self.critic = copy.deepcopy(palm)

         self.actor_palm.set_dropout(actor_dropout)
-        self.critic_palm.set_dropout(critic_dropout)
+
+        if not critic_is_prm:
+            self.critic.set_dropout(critic_dropout)

         self.actor_lora = actor_lora
         self.critic_lora = critic_lora

@@ -79,16 +92,19 @@ def __init__(
             self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

         if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
+            self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

         self.pooled_values = pooled_values
-        self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
-            Rearrange('... 1 -> ...')
-        )
+        self.value_head = nn.Identity()
+
+        if not critic_is_prm:
+            self.value_head = nn.Sequential(
+                nn.Linear(palm.dim, 1),
+                Rearrange('... 1 -> ...')
+            )

-        nn.init.zeros_(self.value_head[0].bias)
-        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
+            nn.init.zeros_(self.value_head[0].bias)
+            nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

     def actor_parameters(self):
         if not self.actor_lora:

@@ -99,11 +115,14 @@ def actor_parameters(self):
         ]

     def critic_parameters(self):
+        if self.critic_is_prm:
+            return self.critic.parameters()
+
         if not self.actor_lora:
-            return [*self.critic_palm.parameters(), *self.value_head.parameters()]
+            return [*self.critic.parameters(), *self.value_head.parameters()]

         return [
-            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
+            *self.critic.finetune_parameters(self.critic_lora_scope),
             *self.value_head.parameters()
         ]

@@ -170,7 +189,11 @@ def forward(
         if not return_values:
             return action_logits, None

-        critic_embeds = self.critic_palm(
+        if self.critic_is_prm:
+            values = self.critic(x)
+            return action_logits, values
+
+        critic_embeds = self.critic(
             x,
             return_only_embedding = True,
             finetune_scope = self.critic_lora_scope

@@ -287,8 +310,8 @@ def clipped_value_loss(values, rewards, old_values, clip):

 # rlhf trainer

-@beartype
 class RLHFTrainer(Module):
+    @beartype
     def __init__(
         self,
         *,

@@ -298,7 +321,7 @@ def __init__(
         tokenizer: Callable | None = None,
         palm: PaLM,
         reward_model: RewardModel,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
         actor_critic: ActorCritic | None = None,
         actor_lr = 1e-4,
         critic_lr = 1e-4,

@@ -351,7 +374,7 @@ def __init__(
         if not exists(actor_critic):
             actor_critic = ActorCritic(
                 palm = palm,
-                critic_palm = critic_palm,
+                critic = critic,
                 actor_lora = actor_lora,
                 critic_lora = critic_lora,
                 actor_lora_r = actor_lora_r,
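
Taken together, the ppo.py changes let RLHFTrainer and ActorCritic accept either a PaLM copy or an ImplicitPRM as the critic; when an ImplicitPRM is passed, critic LoRA and the linear value head are disabled and the PRM's per-token outputs are used directly as values. A rough usage sketch, under assumed constructor arguments: only the critic keyword and the ImplicitPRM import path come from this commit, while the PaLM, RewardModel, and ImplicitPRM argument lists are assumptions, and the remaining trainer arguments (prompts, tokenizer, etc.) are omitted.

from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM

# assumed hyperparameters
palm = PaLM(num_tokens = 20000, dim = 512, depth = 12)

# outcome-level reward model, unchanged by this commit
reward_model = RewardModel(palm)

# assumed constructor call - the exact ImplicitPRM arguments are not part of this diff
implicit_prm = ImplicitPRM(palm)

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    critic = implicit_prm    # new in this commit: critic: PaLM | ImplicitPRM | None
    # prompts / tokenizer and other required arguments omitted for brevity
)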

setup.py (+1, -1)
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.7',
+  version = '0.3.9',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
