
Commit 6d72619

krammnic, felipemello1, Mark Obozov, and ebsmothers authored
[RFC] truncation and skipping (#2419)
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Mark Obozov <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
1 parent b4d7fbb commit 6d72619

25 files changed: +168 -36 lines

recipes/knowledge_distillation_single_device.py (-1)

@@ -31,7 +31,6 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
-
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")

recipes/lora_dpo_single_device.py (-1)

@@ -533,7 +533,6 @@ def train(self) -> None:
                     == self.max_steps_per_epoch
                 ):
                     break
-
             # batch is input_ids, labels
             num_tokens += batch[0].numel()
             policy_chosen_rejected_outputs = self.concatenated_forward(

recipes/lora_finetune_distributed.py (-1)

@@ -36,7 +36,6 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
-
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")

recipes/ppo_full_finetune_single_device.py (-1)

@@ -922,7 +922,6 @@ def train(self) -> None:
             self._sampler.set_epoch(curr_epoch)

             for idx, batch in enumerate(self._dataloader):
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     curr_epoch == 0

tests/torchtune/data/test_data_utils.py (+12, -6)

@@ -18,18 +18,24 @@ def test_truncate():
     tokens = [1, 2, 3, 4, -1]

     # Test no truncation
-    truncated_tokens = truncate(
-        tokens=tokens,
-        max_seq_len=5,
-        eos_id=-1,
-    )
+    truncated_tokens = truncate(tokens=tokens, max_seq_len=5, eos_id=-1)
     assert truncated_tokens == tokens

     masks = [True, True, False, True, False]
     # Test truncated mask
-    truncated_masks = truncate(tokens=masks, max_seq_len=4, eos_id=False)
+    truncated_masks = truncate(
+        tokens=masks, max_seq_len=4, eos_id=False, truncation_type="right"
+    )
+
     assert truncated_masks == [True, True, False, False]

+    # Test left truncation
+    truncated_masks = truncate(
+        tokens=masks, max_seq_len=4, eos_id=False, truncation_type="left"
+    )
+
+    assert truncated_masks == [True, False, True, False]
+

 def test_format_content_with_images():
     test_image_1 = Image.new(mode="RGB", size=(4, 4))

tests/torchtune/models/llama2/test_llama2_tokenizer.py (+2)

@@ -61,6 +61,8 @@ def test_tokenize_messages(self, messages, expected_tokens):
         tokens, mask = tokenizer.tokenize_messages(messages)
         # Mask user, unmask assistant, add EOS token
         expected_mask = [True] * 75 + [False] * 125
+
+        assert len(tokens) == len(mask)
         assert expected_tokens == tokens
         assert expected_mask == mask

torchtune/data/_utils.py (+19, -2)

@@ -26,6 +26,7 @@ def truncate(
     tokens: List[Any],
     max_seq_len: int,
     eos_id: Optional[Any] = None,
+    truncation_type: str = "right",
 ) -> List[Any]:
     """
     Truncate a list of tokens to a maximum length. If eos_id is provided, the last
@@ -36,13 +37,29 @@ def truncate(
         max_seq_len (int): maximum length of the list
         eos_id (Optional[Any]): token to replace the last token with. If None, the
             last token will not be replaced. Default is None.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Returns:
         List[Any]: truncated list of tokens
+
+    Raises:
+        ValueError: if truncation_type is not "left" or "right"
     """
-    tokens_truncated = tokens[:max_seq_len]
-    if eos_id is not None and tokens_truncated[-1] != eos_id:
+
+    if truncation_type == "left":
+        tokens_truncated = tokens[-max_seq_len:]  # Take the last max_seq_len tokens
+    elif truncation_type == "right":
+        tokens_truncated = tokens[:max_seq_len]  # Take the first max_seq_len tokens
+    else:
+        raise ValueError(
+            f"truncation_type must be 'left' or 'right', got {truncation_type}"
+        )
+
+    # Replace the last token with eos_id if necessary
+    if eos_id is not None and tokens_truncated and tokens_truncated[-1] != eos_id:
         tokens_truncated[-1] = eos_id
+
     return tokens_truncated
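To make the new behavior concrete, here is a small usage sketch of the updated truncate (the token ids are invented; assumes the helper is importable from torchtune.data, as its tests exercise it):

# Illustrative sketch of the updated truncate(); the token ids are made up.
from torchtune.data import truncate

tokens = [1, 2, 3, 4, 5]

# "right" (the default) keeps the first max_seq_len tokens.
assert truncate(tokens=tokens, max_seq_len=3, eos_id=None) == [1, 2, 3]

# "left" keeps the last max_seq_len tokens instead.
assert truncate(
    tokens=tokens, max_seq_len=3, eos_id=None, truncation_type="left"
) == [3, 4, 5]

# With eos_id set, the final kept token is replaced when it is not already eos.
assert truncate(
    tokens=tokens, max_seq_len=3, eos_id=99, truncation_type="left"
) == [3, 4, 99]

# The new `tokens_truncated and ...` guard means an empty input now returns
# empty instead of raising an IndexError when eos_id is provided.
assert truncate(tokens=[], max_seq_len=3, eos_id=99) == []

Left truncation is the natural choice for chat-style data, where the most recent turns matter more than the oldest ones once a conversation exceeds max_seq_len.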
torchtune/datasets/_preference.py (+1)

@@ -136,6 +136,7 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
         chosen_input_ids, chosen_masks = self._tokenizer.tokenize_messages(
             transformed_sample["chosen"],
         )
+
         chosen_labels = list(
             np.where(chosen_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids)
         )

torchtune/models/gemma/_model_builders.py (+4, -3)

@@ -43,7 +43,7 @@ def gemma_2b() -> TransformerDecoder:
     )


-def gemma_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None) -> GemmaTokenizer:
+def gemma_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None, truncation_type: str = "right") -> GemmaTokenizer:
     """
     Tokenizer for Gemma.

@@ -55,12 +55,13 @@ def gemma_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_templat
         If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
         class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
         prepend/append tags.
-
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Returns:
         GemmaTokenizer: Instantiation of the Gemma tokenizer
     """
-    return GemmaTokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None)
+    return GemmaTokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None, truncation_type=truncation_type)


 def lora_gemma_2b(
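For the model builders the change is pure plumbing: the new keyword is threaded through to the tokenizer class. A minimal sketch of how a user would opt in (the tokenizer path is a placeholder, assuming a local SentencePiece model file):

# Sketch of the new builder argument; the path is hypothetical.
from torchtune.models.gemma import gemma_tokenizer

tokenizer = gemma_tokenizer(
    path="/path/to/tokenizer.model",  # hypothetical path
    max_seq_len=4096,
    truncation_type="left",  # keep the newest tokens when a sample is too long
)

The same keyword is added to the llama2, llama3, mistral, and phi3 builders below.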

torchtune/models/gemma/_tokenizer.py (+5)

@@ -34,6 +34,8 @@ class GemmaTokenizer(ModelTokenizer, Transform):
             - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`

             The extra text will still get tokenized as normal text, not as special tokens. Default is None.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Examples:
         >>> tokenizer = GemmaTokenizer("/path/to/spm_model")
@@ -47,6 +49,7 @@ def __init__(
         path: str,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplate] = None,
+        truncation_type: str = "right",
     ):
         self._spm_model = SentencePieceBaseTokenizer(path)

@@ -59,6 +62,7 @@ def __init__(
         self.max_seq_len = max_seq_len

         self.prompt_template = prompt_template
+        self.truncation_type = truncation_type

     @property
     def eos_id(self):
@@ -142,6 +146,7 @@ def tokenize_messages(
             messages=templated_messages,
             bos_id=self.bos_id,
             eos_id=self.eos_id if add_eos else None,
+            truncation_type=self.truncation_type,
         )

     def __call__(

torchtune/models/llama2/_model_builders.py (+4, -3)

@@ -42,7 +42,7 @@ def llama2_7b() -> TransformerDecoder:
     )


-def llama2_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = "torchtune.models.llama2.Llama2ChatTemplate") -> Llama2Tokenizer:
+def llama2_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = "torchtune.models.llama2.Llama2ChatTemplate", truncation_type: str = "right") -> Llama2Tokenizer:
     """
     Tokenizer for Llama2.

@@ -54,11 +54,12 @@ def llama2_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_templa
         If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
         class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
         prepend/append tags. Default is :class:`~torchtune.models.llama2.Llama2ChatTemplate`.
-
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".
     Returns:
         Llama2Tokenizer: Instantiation of the Llama2 tokenizer
     """
-    return Llama2Tokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None)
+    return Llama2Tokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None, truncation_type=truncation_type)


 def lora_llama2_7b(

torchtune/models/llama2/_tokenizer.py (+5)

@@ -44,6 +44,8 @@ class Llama2Tokenizer(ModelTokenizer, Transform):

             The extra text will still get tokenized as normal text, not as special tokens.
             Default is :class:`~torchtune.models.llama2.Llama2ChatTemplate`.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Examples:
         >>> tokenizer = Llama2Tokenizer("/path/to/spm_model")
@@ -57,6 +59,7 @@ def __init__(
         path: str,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplate] = Llama2ChatTemplate(),
+        truncation_type: str = "right",
     ):
         self._spm_model = SentencePieceBaseTokenizer(path)

@@ -69,6 +72,7 @@ def __init__(
         self.max_seq_len = max_seq_len

         self.prompt_template = prompt_template
+        self.truncation_type = truncation_type

     @property
     def eos_id(self):
@@ -159,6 +163,7 @@ def tokenize_messages(
             messages=templated_messages,
             bos_id=self.bos_id if add_start_tokens else None,
             eos_id=self.eos_id if add_end_tokens else None,
+            truncation_type=self.truncation_type,
         )

     def __call__(

torchtune/models/llama3/_model_builders.py (+4)

@@ -70,6 +70,7 @@ def llama3_tokenizer(
     special_tokens_path: Optional[str] = None,
     max_seq_len: Optional[int] = None,
     prompt_template: Optional[_TemplateType] = None,
+    truncation_type: str = "right",
 ) -> Llama3Tokenizer:
     """
     Tokenizer for Llama3.
@@ -85,6 +86,8 @@ def llama3_tokenizer(
         If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
         class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
         prepend/append tags.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Returns:
         Llama3Tokenizer: Instantiation of the Llama3 tokenizer
@@ -102,6 +105,7 @@ def llama3_tokenizer(
         special_tokens=special_tokens,
         max_seq_len=max_seq_len,
         prompt_template=template,
+        truncation_type=truncation_type,
     )
torchtune/models/llama3/_tokenizer.py (+15, -2)

@@ -65,6 +65,8 @@ class Llama3Tokenizer(ModelTokenizer, Transform):
             - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`

             The extra text will still get tokenized as normal text, not as special tokens. Default is None.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Examples:
         >>> tokenizer = Llama3Tokenizer("/path/to/tt_model")
@@ -79,6 +81,7 @@ def __init__(
         special_tokens: Optional[Dict[str, int]] = None,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplate] = None,
+        truncation_type: str = "right",
     ):
         self.special_tokens = (
             special_tokens if special_tokens is not None else LLAMA3_SPECIAL_TOKENS
@@ -124,6 +127,8 @@ def __init__(
             r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
         )

+        self.truncation_type = truncation_type
+
     def _validate_special_tokens(
         self,
     ):
@@ -324,9 +329,17 @@ def tokenize_messages(

         if self.max_seq_len:
             tokens = truncate(
-                tokens, self.max_seq_len, self.eos_id if add_end_tokens else None
+                tokens=tokens,
+                max_seq_len=self.max_seq_len,
+                eos_id=self.eos_id if add_end_tokens else None,
+                truncation_type=self.truncation_type,
+            )
+            mask = truncate(
+                tokens=mask,
+                max_seq_len=self.max_seq_len,
+                eos_id=True if add_end_tokens else None,
+                truncation_type=self.truncation_type,
             )
-            mask = truncate(mask, self.max_seq_len, True if add_end_tokens else None)

         return tokens, mask
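Note that tokens and mask are now truncated with the same max_seq_len and the same truncation_type, which is what keeps them aligned (and what the new len(tokens) == len(mask) assertion in the llama2 tokenizer test checks). An illustrative sketch with invented values:

# Why the two truncate() calls must agree: cutting tokens and mask from the
# same side with the same length preserves their one-to-one alignment.
from torchtune.data import truncate

tokens = [7, 8, 9, 10, 11]  # hypothetical token ids
mask = [True, True, False, False, False]

kwargs = dict(max_seq_len=3, eos_id=None, truncation_type="left")
tokens = truncate(tokens=tokens, **kwargs)
mask = truncate(tokens=mask, **kwargs)

assert len(tokens) == len(mask)  # still aligned after truncation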

torchtune/models/mistral/_model_builders.py (+4, -2)

@@ -48,7 +48,7 @@ def mistral_7b() -> TransformerDecoder:
     )


-def mistral_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = "torchtune.models.mistral.MistralChatTemplate") -> MistralTokenizer:
+def mistral_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = "torchtune.models.mistral.MistralChatTemplate", truncation_type: str = "right",) -> MistralTokenizer:
     """
     Tokenizer for Mistral models.

@@ -60,11 +60,13 @@ def mistral_tokenizer(path: str, max_seq_len: Optional[int] = None, prompt_templ
         If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
         class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
         prepend/append tags. Default is :class:`~torchtune.models.mistral.MistralChatTemplate`.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Returns:
         MistralTokenizer: Instantiation of the Mistral tokenizer
     """
-    return MistralTokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None)
+    return MistralTokenizer(path=path, max_seq_len=max_seq_len, prompt_template=_get_prompt_template(prompt_template) if prompt_template is not None else None, truncation_type=truncation_type)


 def lora_mistral_7b(

torchtune/models/mistral/_tokenizer.py (+5)

@@ -36,6 +36,8 @@ class MistralTokenizer(ModelTokenizer, Transform):

             The extra text will still get tokenized as normal text, not as special tokens.
             Default is :class:`~torchtune.models.mistral.MistralChatTemplate`.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Examples:
         >>> tokenizer = MistralTokenizer("/path/to/spm_model")
@@ -49,6 +51,7 @@ def __init__(
         path: str,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplate] = MistralChatTemplate(),
+        truncation_type: str = "right",
     ):
         self._spm_model = SentencePieceBaseTokenizer(path)

@@ -61,6 +64,7 @@ def __init__(
         self.max_seq_len = max_seq_len

         self.prompt_template = prompt_template
+        self.truncation_type = truncation_type

     @property
     def eos_id(self):
@@ -172,6 +176,7 @@ def tokenize_messages(
             messages=templated_messages,
             bos_id=self.bos_id,
             eos_id=self.eos_id if add_eos else None,
+            truncation_type=self.truncation_type,
         )

     def __call__(

torchtune/models/phi3/_model_builders.py (+4, -2)

@@ -41,7 +41,7 @@ def phi3_mini() -> TransformerDecoder:
         norm_eps=1e-5,
     )

-def phi3_mini_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None) -> Phi3MiniTokenizer:
+def phi3_mini_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None, truncation_type: str = "right",) -> Phi3MiniTokenizer:
     """Phi-3 Mini tokenizer.
     Ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json

@@ -56,6 +56,8 @@ def phi3_mini_tokenizer(path: str, special_tokens_path: Optional[str] = None, ma
         If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
         class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
         prepend/append tags.
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

     Note:
         This tokenizer includes typical LM EOS and BOS tokens like
@@ -68,7 +70,7 @@ def phi3_mini_tokenizer(path: str, special_tokens_path: Optional[str] = None, ma
     """
     special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None
     template = _get_prompt_template(prompt_template) if prompt_template is not None else None
-    return Phi3MiniTokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len, prompt_template=template)
+    return Phi3MiniTokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len, prompt_template=template, truncation_type=truncation_type)


 def lora_phi3_mini(
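Since the builders expose truncation_type as a plain keyword, it can also be set from a recipe config. A hedged sketch using torchtune's usual config-instantiation pattern (the tokenizer path is a placeholder):

# Sketch: enabling left truncation from a recipe-style config.
from omegaconf import OmegaConf

from torchtune import config

cfg = OmegaConf.create(
    {
        "tokenizer": {
            "_component_": "torchtune.models.phi3.phi3_mini_tokenizer",
            "path": "/path/to/tokenizer.model",  # hypothetical path
            "max_seq_len": 4096,
            "truncation_type": "left",
        }
    }
)
tokenizer = config.instantiate(cfg.tokenizer)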
