@@ -27,8 +27,8 @@

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
+from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
-
from accelerate import Accelerator

# actor critic - PaLM with lora
@@ -47,7 +47,7 @@ class ActorCritic(Module):
    def __init__(
        self,
        palm: PaLM,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
@@ -61,13 +61,26 @@ def __init__(
        super().__init__()
        self.actor_palm = palm

-        self.critic_palm = critic_palm
+        # detect implicit prm and auto-set some hyperparameters
+
+        critic_is_prm = isinstance(critic, ImplicitPRM)
+
+        critic_lora &= not critic_is_prm
+        pooled_values |= critic_is_prm
+
+        self.critic_is_prm = critic_is_prm
+
+        # critic
+
+        self.critic = critic

-        if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
+        if not exists(self.critic):
+            self.critic = copy.deepcopy(palm)

        self.actor_palm.set_dropout(actor_dropout)
-        self.critic_palm.set_dropout(critic_dropout)
+
+        if not critic_is_prm:
+            self.critic.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora
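The hunk above auto-detects an ImplicitPRM critic and adjusts two hyperparameters: critic LoRA is forced off (the PRM arrives with its own trained parameters, so no adapter is grafted on) and pooled values are forced on. A minimal construction sketch follows; only ActorCritic(palm = ..., critic = ...) and the ImplicitPRM import are established by this diff, so the ActorCritic module path and the ImplicitPRM(...) arguments are illustrative assumptions.

import copy

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.ppo import ActorCritic  # assumed module path for the file being patched

palm = PaLM(num_tokens = 256, dim = 512, depth = 8)

# assumption: the implicit PRM wraps its own (separately trained) language model
prm = ImplicitPRM(copy.deepcopy(palm))

actor_critic = ActorCritic(
    palm = palm,
    critic = prm,
    critic_lora = True   # silently overridden to False by critic_lora &= not critic_is_prm
)

assert actor_critic.critic_is_prm
assert actor_critic.pooled_values  # forced on by pooled_values |= critic_is_prm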
@@ -79,16 +92,19 @@ def __init__(
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
+            self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
-        self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
-            Rearrange('... 1 -> ...')
-        )
+        self.value_head = nn.Identity()
+
+        if not critic_is_prm:
+            self.value_head = nn.Sequential(
+                nn.Linear(palm.dim, 1),
+                Rearrange('... 1 -> ...')
+            )

-        nn.init.zeros_(self.value_head[0].bias)
-        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
+            nn.init.zeros_(self.value_head[0].bias)
+            nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    def actor_parameters(self):
        if not self.actor_lora:
@@ -99,11 +115,14 @@ def actor_parameters(self):
        ]

    def critic_parameters(self):
+        if self.critic_is_prm:
+            return self.critic.parameters()
+
        if not self.critic_lora:
-            return [*self.critic_palm.parameters(), *self.value_head.parameters()]
+            return [*self.critic.parameters(), *self.value_head.parameters()]

        return [
-            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
+            *self.critic.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]
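Parameter routing follows the same split: with an ImplicitPRM critic the optimizer sees the whole PRM, otherwise either the full critic plus value head or just the LoRA-scoped subset. A short sketch of how these groups would feed a standard torch optimizer (the learning rate is arbitrary):

from torch.optim import Adam

# full ImplicitPRM parameters when critic_is_prm, otherwise the
# LoRA / value-head subset selected by critic_parameters() above
critic_optim = Adam(actor_critic.critic_parameters(), lr = 1e-4)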
@@ -170,7 +189,11 @@ def forward(
        if not return_values:
            return action_logits, None

-        critic_embeds = self.critic_palm(
+        if self.critic_is_prm:
+            values = self.critic(x)
+            return action_logits, values
+
+        critic_embeds = self.critic(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
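With an ImplicitPRM critic, forward now short-circuits: the PRM consumes the token ids directly and its output is returned as the values, skipping the critic embedding pass and the value head (an nn.Identity in this configuration). A hedged call sketch; that return_values is accepted as a keyword argument is inferred from the hunk above, not shown in this diff:

import torch

x = torch.randint(0, 256, (1, 1024))  # sampled token ids

action_logits, values = actor_critic(x, return_values = True)

# for a PRM critic, `values` is simply self.critic(x); no value head is applied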
@@ -287,8 +310,8 @@ def clipped_value_loss(values, rewards, old_values, clip):

# rlhf trainer

-@beartype
class RLHFTrainer(Module):
+    @beartype
    def __init__(
        self,
        *,
@@ -298,7 +321,7 @@ def __init__(
        tokenizer: Callable | None = None,
        palm: PaLM,
        reward_model: RewardModel,
-        critic_palm: PaLM | None = None,
+        critic: PaLM | ImplicitPRM | None = None,
        actor_critic: ActorCritic | None = None,
        actor_lr = 1e-4,
        critic_lr = 1e-4,
@@ -351,7 +374,7 @@ def __init__(
        if not exists(actor_critic):
            actor_critic = ActorCritic(
                palm = palm,
-                critic_palm = critic_palm,
+                critic = critic,
                actor_lora = actor_lora,
                critic_lora = critic_lora,
                actor_lora_r = actor_lora_r,
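At the trainer level the renamed critic keyword is simply threaded through to ActorCritic. An end-to-end sketch of the new wiring; only the palm, reward_model, and critic keywords are shown in this diff, so the RewardModel construction, the ImplicitPRM signature, the module path, and the train entry point are assumptions for illustration:

from palm_rlhf_pytorch.palm import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
from palm_rlhf_pytorch.ppo import RLHFTrainer  # assumed module path

palm = PaLM(num_tokens = 256, dim = 512, depth = 8)

reward_model = RewardModel(palm)  # assumption: the reward model wraps a (pretrained) PaLM

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    critic = ImplicitPRM(palm)  # assumption on the ImplicitPRM signature
    # prompt-related arguments omitted; see the full __init__ signature
)

# trainer.train(...)  # assumed entry point, not shown in this diff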