
Commit 20e3ede

Ilia Kulikov committed

hf tokenizer support added

1 parent 7bd4c0d commit 20e3ede

File tree

src/fairseq2/data/text/tokenizers/huggingface_tokenizer.py
src/fairseq2/data/text/tokenizers/llama.py

2 files changed: +255 -0 lines changed
src/fairseq2/data/text/tokenizers/huggingface_tokenizer.py

+144
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import final

import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.data import VocabularyInfo
from fairseq2.data.text.tokenizers import (
    TextTokenDecoder,
    TextTokenEncoder,
)
from fairseq2.typing import Device
from transformers import AutoTokenizer


@final
class HuggingfaceTokenizerEncoder(TextTokenEncoder):
    """Represents a Hugging Face tokenizer encoder."""

    _tokenizer: AutoTokenizer
    _prefix_indices: list[int]
    _suffix_indices: list[int]
    _prefix_index_tensor: Tensor | None
    _suffix_index_tensor: Tensor | None
    _device: Device | None
    _pin_memory: bool

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        *,
        prefix_tokens: Sequence[str] | None = None,
        suffix_tokens: Sequence[str] | None = None,
        device: Device | None = None,
        pin_memory: bool = False,
    ) -> None:
        """
        :param tokenizer:
            The Hugging Face :class:`AutoTokenizer` object.
        :param prefix_tokens:
            The prefix tokens to encode with input text.
        :param suffix_tokens:
            The suffix tokens to encode with input text.
        :param device:
            The device on which to construct tensors.
        :param pin_memory:
            If ``True``, uses pinned memory while constructing tensors.
        """
        self._tokenizer = tokenizer

        # Prefix
        if prefix_tokens:
            self._prefix_indices = self._tokenizer.convert_tokens_to_ids(prefix_tokens)

            self._prefix_index_tensor = torch.tensor(
                self._prefix_indices, dtype=torch.int64, device=device
            )
        else:
            self._prefix_indices = []

            self._prefix_index_tensor = None

        # Suffix
        if suffix_tokens:
            self._suffix_indices = self._tokenizer.convert_tokens_to_ids(suffix_tokens)

            self._suffix_index_tensor = torch.tensor(
                self._suffix_indices, dtype=torch.int64, device=device
            )
        else:
            self._suffix_indices = []

            self._suffix_index_tensor = None

        self._device = device
        self._pin_memory = pin_memory

    @override
    def __call__(self, text: str) -> Tensor:
        # The fairseq2 tokenizer adds special tokens on its own.
        indices = self._tokenizer.encode(text, add_special_tokens=False)

        if self._prefix_indices:
            indices = self._prefix_indices + indices

        if self._suffix_indices:
            indices.extend(self._suffix_indices)

        return torch.tensor(
            indices, dtype=torch.int64, device=self._device, pin_memory=self._pin_memory
        )

    @override
    def encode_as_tokens(self, text: str) -> list[str]:
        indices = self(text).tolist()

        tokens = self._tokenizer.convert_ids_to_tokens(indices)

        return tokens

    @property
    @override
    def prefix_indices(self) -> Tensor | None:
        return self._prefix_index_tensor

    @property
    @override
    def suffix_indices(self) -> Tensor | None:
        return self._suffix_index_tensor


@final
class HuggingfaceTokenizerDecoder(TextTokenDecoder):
    """Represents a Hugging Face tokenizer decoder."""

    _tokenizer: AutoTokenizer

    def __init__(self, tokenizer: AutoTokenizer) -> None:
        self._tokenizer = tokenizer

    @override
    def __call__(self, token_indices: Tensor) -> str:
        if token_indices.dim() != 1:
            raise ValueError(
                f"`token_indices` must be one dimensional, but has {token_indices.dim()} dimensions instead."
            )

        return self._tokenizer.decode(token_indices)

    @override
    def decode_from_tokens(self, tokens: Sequence[str]) -> str:
        indices = self._tokenizer.convert_tokens_to_ids(tokens)

        return self._tokenizer.decode(indices)
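
The two classes above wrap a Hugging Face AutoTokenizer behind fairseq2's TextTokenEncoder/TextTokenDecoder interfaces. A minimal round-trip sketch, not part of the commit; the checkpoint name is a placeholder and the special tokens shown are assumed to be the LLaMA 3 ones:

from transformers import AutoTokenizer

from fairseq2.data.text.tokenizers.huggingface_tokenizer import (
    HuggingfaceTokenizerDecoder,
    HuggingfaceTokenizerEncoder,
)

# Hypothetical checkpoint name; any Hugging Face tokenizer directory works.
hf_tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

encoder = HuggingfaceTokenizerEncoder(
    hf_tok,
    prefix_tokens=["<|begin_of_text|>"],  # assumed LLaMA 3 BOS token
    suffix_tokens=["<|end_of_text|>"],    # assumed LLaMA 3 EOS token
)
decoder = HuggingfaceTokenizerDecoder(hf_tok)

indices = encoder("Hello, world!")  # 1-D int64 tensor with prefix/suffix attached
text = decoder(indices)             # decodes back to a string, special tokens included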

src/fairseq2/data/text/tokenizers/llama.py

+111
@@ -24,7 +24,97 @@
    TiktokenEncoder,
    TiktokenModel,
)
from fairseq2.data.text.tokenizers.huggingface_tokenizer import (
    HuggingfaceTokenizerEncoder,
    HuggingfaceTokenizerDecoder,
)
from fairseq2.typing import Device
from transformers import AutoTokenizer


@final
class LLaMA3TokenizerHuggingFace(TextTokenizer):
    """Represents the Hugging Face version of the LLaMA 3 tokenizer."""

    _tokenizer: AutoTokenizer
    _bos_token: str
    _eos_token: str

    def __init__(self, path: Path) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(path)

        self._eos_token = self._tokenizer.special_tokens_map["eos_token"]
        self._bos_token = self._tokenizer.special_tokens_map["bos_token"]

    @override
    def create_encoder(
        self,
        *,
        task: str | None = None,
        lang: str | None = None,
        mode: str | None = None,
        device: Device | None = None,
        pin_memory: bool = False,
    ) -> HuggingfaceTokenizerEncoder:
        if task is not None:
            raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

        if lang is not None:
            raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

        match mode:
            case None | "default":
                prefix_tokens = [self._bos_token]
                suffix_tokens = [self._eos_token]
            case "prompt":
                prefix_tokens = [self._bos_token]
                # In prompt mode, we expect the generator to finish the sequence.
                suffix_tokens = []
            case "prompt_response":
                prefix_tokens = []
                suffix_tokens = [self._eos_token]
            case "as_is":
                prefix_tokens = []
                suffix_tokens = []
            case _:
                raise ValueError(
                    f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is"
                )

        return HuggingfaceTokenizerEncoder(
            self._tokenizer,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            device=device,
            pin_memory=pin_memory,
        )

    @override
    def create_raw_encoder(
        self, *, device: Device | None = None, pin_memory: bool = False
    ) -> HuggingfaceTokenizerEncoder:
        return HuggingfaceTokenizerEncoder(
            self._tokenizer, device=device, pin_memory=pin_memory
        )

    @override
    def create_decoder(self) -> HuggingfaceTokenizerDecoder:
        return HuggingfaceTokenizerDecoder(self._tokenizer)

    @property
    @override
    def vocab_info(self) -> VocabularyInfo:
        bos_idx = self._tokenizer.convert_tokens_to_ids(self._bos_token)
        eos_idx = self._tokenizer.convert_tokens_to_ids(self._eos_token)

        vocab_info = VocabularyInfo(
            size=len(self._tokenizer),
            bos_idx=bos_idx,
            eos_idx=eos_idx,
            unk_idx=None,
            pad_idx=None,
        )

        return vocab_info


@final
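
The create_encoder modes of the new LLaMA3TokenizerHuggingFace map directly onto the prefix/suffix tokens passed to the encoder: default wraps text in BOS and EOS, prompt adds only BOS, prompt_response adds only EOS, and as_is adds nothing. A small sketch of how that composes, not part of the commit; the checkpoint path and module import are assumptions based on the paths in this diff:

import torch

from fairseq2.data.text.tokenizers.llama import LLaMA3TokenizerHuggingFace

tok = LLaMA3TokenizerHuggingFace("/path/to/llama3-hf")  # hypothetical local checkpoint

prompt_enc = tok.create_encoder(mode="prompt")             # BOS prefix, no EOS
response_enc = tok.create_encoder(mode="prompt_response")  # no BOS, EOS suffix

prompt = prompt_enc("What is the capital of France?")
response = response_enc(" Paris.")

# Concatenating the two yields BOS ... EOS, e.g. for teacher-forced training.
full = torch.cat([prompt, response])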
@@ -139,6 +229,27 @@ def vocab_info(self) -> VocabularyInfo:


def load_llama_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
    # First, check whether this is a Hugging Face tokenizer.
    try:
        use_hf = card.field("use_hf_tokenizer").as_(bool)
    except AssetCardFieldNotFoundError:
        use_hf = False
    except AssetCardError as ex:
        raise text_tokenizer_asset_card_error(card.name) from ex

    if use_hf:
        try:
            return LLaMA3TokenizerHuggingFace(path)
        except ValueError as ex:
            raise TextTokenizerLoadError(
                card.name, f"The '{card.name}' asset card does not contain a valid text tokenizer configuration of the '{LLAMA_TOKENIZER_FAMILY}' family. See the nested exception for details."  # fmt: skip
            ) from ex
        except RuntimeError as ex:
            raise TextTokenizerLoadError(
                card.name, f"The '{card.name}' text tokenizer cannot be loaded. See the nested exception for details."  # fmt: skip
            ) from ex

    try:
        use_v2 = card.field("use_v2_tokenizer").as_(bool)
    except AssetCardFieldNotFoundError:
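
On the loading side, the new branch in load_llama_tokenizer is gated by a boolean use_hf_tokenizer field on the asset card: when the field is present and true, the Hugging Face path is taken before the existing use_v2_tokenizer logic runs. A quick consistency sketch, not part of the commit and with a placeholder checkpoint path, showing that vocab_info agrees with what the default-mode encoder produces:

from fairseq2.data.text.tokenizers.llama import LLaMA3TokenizerHuggingFace

tokenizer = LLaMA3TokenizerHuggingFace("/path/to/llama3-hf")  # hypothetical local checkpoint

info = tokenizer.vocab_info
encoder = tokenizer.create_encoder(mode="default")

indices = encoder("hello").tolist()

# Default mode attaches BOS/EOS, whose ids are the same ones vocab_info reports.
assert indices[0] == info.bos_idx
assert indices[-1] == info.eos_idx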

Comments (0)