
Commit c08ff73

fix #175
Former-commit-id: fd557eb
1 parent e9736b2 commit c08ff73

5 files changed

Lines changed: 38 additions & 12 deletions

File tree

requirements.txt
src/llmtuner/chat/stream_chat.py
src/llmtuner/tuner/core/loader.py
src/llmtuner/tuner/ppo/trainer.py
src/llmtuner/tuner/sft/trainer.py

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.4
+trl==0.4.4
 sentencepiece
 jieba
 rouge-chinese

src/llmtuner/chat/stream_chat.py

Lines changed: 7 additions & 4 deletions
@@ -1,3 +1,4 @@
+import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
@@ -41,10 +42,10 @@ def process_args(
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
-            temperature=temperature if temperature else gen_kwargs["temperature"],
-            top_p=top_p if top_p else gen_kwargs["top_p"],
-            top_k=top_k if top_k else gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
+            temperature=temperature or gen_kwargs["temperature"],
+            top_p=top_p or gen_kwargs["top_p"],
+            top_k=top_k or gen_kwargs["top_k"],
+            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
             logits_processor=get_logits_processor()
         ))
@@ -58,6 +59,7 @@ def process_args(

         return gen_kwargs, prompt_length

+    @torch.inference_mode()
     def chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
@@ -68,6 +70,7 @@ def chat(
         response_length = len(outputs)
         return response, (prompt_length, response_length)

+    @torch.inference_mode()
     def stream_chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Generator[str, None, None]:
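The two @torch.inference_mode() decorators make chat and stream_chat run generation without recording autograd state, and the rewritten keyword handling lets falsy sampling arguments (None, 0, 0.0) fall back to the configured defaults. A minimal self-contained sketch of both idioms, with illustrative names (default_kwargs, merge_kwargs, generate_step) that are not part of the repository:

import torch

default_kwargs = {"temperature": 0.95, "top_p": 0.7}

def merge_kwargs(temperature=None, top_p=None):
    # falsy values fall back to the defaults, mirroring `x or gen_kwargs[...]` in process_args
    return {
        "temperature": temperature or default_kwargs["temperature"],
        "top_p": top_p or default_kwargs["top_p"],
    }

@torch.inference_mode()  # equivalent to wrapping the body in `with torch.inference_mode():`
def generate_step(x: torch.Tensor) -> torch.Tensor:
    # no gradients or autograd metadata are recorded inside this call
    return x * 2

print(merge_kwargs(top_p=0.9))       # {'temperature': 0.95, 'top_p': 0.9}
print(generate_step(torch.ones(2)))  # tensor([2., 2.])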

src/llmtuner/tuner/core/loader.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
+require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")


 def load_model_and_tokenizer(
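require_version comes from transformers.utils.versions; it compares the installed distribution against the requirement string when the module is imported and raises with the supplied hint when the check fails, so an incompatible trl install is reported immediately rather than at runtime. A minimal sketch of the check in isolation, reusing the same arguments as the diff:

from transformers.utils.versions import require_version

# fails at import time with the "To fix: ..." hint if the installed trl is not exactly 0.4.4
require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")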

src/llmtuner/tuner/ppo/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def ppo_train(self, max_target_length: int) -> None:
             if self.control.should_training_stop:
                 break

-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Dict[str, torch.Tensor],
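Like torch.no_grad(), torch.inference_mode() disables gradient recording; in addition it skips version-counter and view tracking, so tensors created under it are cheaper to produce but cannot later take part in computations that record gradients. A small standalone comparison (not from the repository):

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # gradient recording is paused; y.requires_grad is False

with torch.inference_mode():
    z = x * 2  # z is an inference tensor: cheaper, but it raises if later used
               # in a computation that records gradients

print(y.requires_grad, z.requires_grad)  # False False
print(torch.is_inference(z))             # True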

src/llmtuner/tuner/sft/trainer.py

Lines changed: 28 additions & 5 deletions
@@ -32,17 +32,40 @@ def prediction_step(
         Subclass and override to inject custom behavior.
         """
         prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        if self.tokenizer.padding_side == "right":  # pads the labels to the same length as the inputs
-            inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
-        else:
-            inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
+        if prompt_len > label_len:
+            inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+        if label_len > prompt_len:
+            inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
+
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
+        generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None

         return (loss, generated_tokens, labels)

+    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
+        """
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            # If PAD token is not defined at least EOS token has to be defined
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+
+        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
+        return padded_tensor
+
     def save_predictions(
         self,
         predict_results: PredictionOutput
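The new _pad_tensors_to_target_len left-pads the shorter tensor with the pad token so that input_ids and labels end up with the same sequence length before generation-based evaluation, which is also why generated_tokens is now sliced at max(prompt_len, label_len). A small illustration of the padding arithmetic with made-up values (pad_token_id = 0 is an assumption for the example; the trainer takes it from the tokenizer or the model config):

import torch

pad_token_id = 0
labels = torch.tensor([[5, 6, 7]])           # length 3
input_ids = torch.tensor([[1, 2, 3, 4, 5]])  # length 5

# same two steps as _pad_tensors_to_target_len: fill with the pad id, then copy into the right-hand side
padded = pad_token_id * torch.ones_like(input_ids)
padded[:, -labels.shape[-1]:] = labels
print(padded)  # tensor([[0, 0, 5, 6, 7]])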
