
Commit 02ae80b

Authored Mar 10, 2023
[chatgpt] add flag of action mask in critic (#3086)
1 parent 95a36ea · commit 02ae80b

5 files changed, +21 -14 lines
 

applications/ChatGPT/chatgpt/models/base/actor.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ def generate(
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
-            return sequences, attention_mask
+            return sequences, attention_mask, None
         input_len = input_ids.size(1)
         eos_token_id = kwargs.get('eos_token_id', None)
         if eos_token_id is None:
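For context, here is a minimal self-contained sketch of the return contract this hunk establishes (not taken from the repository; generate_like is a made-up stand-in for the patched generate): when return_action_mask is False, the third element is simply None, so callers can always unpack three values.

from typing import Optional, Tuple

import torch


def generate_like(sequences: torch.Tensor,
                  pad_token_id: Optional[int],
                  return_action_mask: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors the patched early-return path: build the attention mask,
    # then return a 3-tuple with None in the action-mask slot.
    attention_mask = None
    if pad_token_id is not None:
        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
    if not return_action_mask:
        return sequences, attention_mask, None
    # The real generate() would go on to build the action mask here.
    raise NotImplementedError


seqs = torch.tensor([[5, 7, 9, 0, 0]])
sequences, attention_mask, action_mask = generate_like(seqs, pad_token_id=0, return_action_mask=False)
assert action_mask is None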

applications/ChatGPT/chatgpt/models/base/critic.py

Lines changed: 11 additions & 7 deletions

@@ -18,15 +18,19 @@ class Critic(LoRAModule):
         lora_train_bias (str): LoRA bias training mode.
     """
 
-    def __init__(self,
-                 model: nn.Module,
-                 value_head: nn.Module,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        model: nn.Module,
+        value_head: nn.Module,
+        lora_rank: int = 0,
+        lora_train_bias: str = 'none',
+        use_action_mask: bool = False,
+    ) -> None:
 
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
         self.value_head = value_head
+        self.use_action_mask = use_action_mask
         self.convert_to_lora()
 
     def forward(self,
@@ -38,13 +42,13 @@ def forward(self,
 
         values = self.value_head(last_hidden_states).squeeze(-1)
 
-        if action_mask is not None:
+        if action_mask is not None and self.use_action_mask:
             num_actions = action_mask.size(1)
             prompt_mask = attention_mask[:, :-num_actions]
             values = values[:, :-num_actions]
             value = masked_mean(values, prompt_mask, dim=1)
             return value
 
         values = values[:, :-1]
-        value = values.mean(dim=1).squeeze(1)
+        value = values.mean(dim=1)
         return value
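The effect of the new use_action_mask flag can be seen in a small, self-contained sketch (not part of the diff; masked_mean is re-implemented here for illustration, the real helper lives elsewhere in the chatgpt package): with the flag on and an action mask available, the value is a masked mean over the prompt tokens only; otherwise it falls back to a plain mean over all positions but the last. Dropping the old .squeeze(1) is also a fix, since values.mean(dim=1) is already one-dimensional.

import torch


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Illustrative re-implementation: mean over positions where mask == 1.
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)


batch_size, seq_len, num_actions = 2, 10, 4
values = torch.randn(batch_size, seq_len)                           # squeezed value_head output
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
action_mask = torch.ones(batch_size, num_actions, dtype=torch.long)

for use_action_mask in (True, False):
    if action_mask is not None and use_action_mask:
        # Masked mean over the prompt part only (everything before the actions).
        prompt_mask = attention_mask[:, :-num_actions]
        value = masked_mean(values[:, :-num_actions], prompt_mask, dim=1)
    else:
        # Fallback path: plain mean over all positions except the last one.
        value = values[:, :-1].mean(dim=1)
    print(use_action_mask, value.shape)  # both branches yield shape (batch_size,)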

applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py

Lines changed: 3 additions & 2 deletions

@@ -24,7 +24,8 @@ def __init__(self,
                  config: Optional[BloomConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ def __init__(self,
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py

Lines changed: 3 additions & 2 deletions

@@ -20,7 +20,8 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 **kwargs) -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -30,4 +31,4 @@ def __init__(self,
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, **kwargs)
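A hedged usage sketch of how the forwarded keyword reaches the base class (not from the diff; it assumes GPTCritic is exported from chatgpt.models.gpt as the file layout suggests, and that the package and transformers are installed):

from transformers import GPT2Config

from chatgpt.models.gpt import GPTCritic  # assumed import path, per the repo layout

# use_action_mask is not a named parameter of GPTCritic; it rides along in
# **kwargs and is consumed by Critic.__init__, which stores it on the instance.
critic = GPTCritic(config=GPT2Config(n_layer=2, n_head=2, n_embd=64), use_action_mask=True)
assert critic.use_action_mask is True

The same pattern applies to the BLOOM and OPT critics, which forward their extra keyword arguments in the same way.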

applications/ChatGPT/chatgpt/models/opt/opt_critic.py

Lines changed: 3 additions & 2 deletions

@@ -24,7 +24,8 @@ def __init__(self,
                  config: Optional[OPTConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = OPTModel.from_pretrained(pretrained)
         elif config is not None:
@@ -34,4 +35,4 @@ def __init__(self,
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
