Llama 4 Scout and Maverick archs #1106

Open · wants to merge 5 commits into base: main

223 changes: 223 additions & 0 deletions src/fairseq2/data/text/tokenizers/llama4.py
@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from pathlib import Path
from typing import Final, final

from typing_extensions import override

from fairseq2.assets import AssetCard
from fairseq2.data import VocabularyInfo
from fairseq2.data.text.tokenizers import (
TextTokenizer,
TextTokenizerLoadError,
)
from fairseq2.data.text.tokenizers.tiktoken import (
TiktokenDecoder,
TiktokenEncoder,
TiktokenModel,
)
from fairseq2.typing import Device


def get_reserved_special_tokens(
name: str, count: int, start_index: int = 0
) -> list[str]:
return [
f"<|{name}_reserved_special_token_{i}|>"
for i in range(start_index, start_index + count)
]


# 200005, ..., 200079
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|header_start|>",
"<|header_end|>",
"<|eom|>",
"<|eot|>",
"<|step|>",
"<|text_post_train_reserved_special_token_0|>",
"<|text_post_train_reserved_special_token_1|>",
"<|text_post_train_reserved_special_token_2|>",
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|text_post_train_reserved_special_token_6|>",
"<|text_post_train_reserved_special_token_7|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 8
) # <|text_post_train_reserved_special_token_8|>, ..., <|text_post_train_reserved_special_token_68|>

# 200080, ..., 201133
LLAMA4_VISION_SPECIAL_TOKENS = [
"<|image_start|>",
"<|image_end|>",
"<|vision_reserved_special_token_0|>",
"<|vision_reserved_special_token_1|>",
"<|tile_x_separator|>",
"<|tile_y_separator|>",
"<|vision_reserved_special_token_2|>",
"<|vision_reserved_special_token_3|>",
"<|vision_reserved_special_token_4|>",
"<|vision_reserved_special_token_5|>",
"<|image|>",
"<|vision_reserved_special_token_6|>",
"<|patch|>",
] + get_reserved_special_tokens(
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>

# 201134, ..., 201143
LLAMA4_REASONING_SPECIAL_TOKENS = [
"<|reasoning_reserved_special_token_0|>",
"<|reasoning_reserved_special_token_1|>",
"<|reasoning_reserved_special_token_2|>",
"<|reasoning_reserved_special_token_3|>",
"<|reasoning_reserved_special_token_4|>",
"<|reasoning_reserved_special_token_5|>",
"<|reasoning_reserved_special_token_6|>",
"<|reasoning_reserved_special_token_7|>",
"<|reasoning_thinking_start|>",
"<|reasoning_thinking_end|>",
]

LLAMA4_SPECIAL_TOKENS = (
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS
+ LLAMA4_VISION_SPECIAL_TOKENS
+ LLAMA4_REASONING_SPECIAL_TOKENS
)

BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|fim_prefix|>",
"<|fim_middle|>",
"<|fim_suffix|>",
]


@final
class LLaMA4Tokenizer(TextTokenizer):
"""Represents a LLaMA 4 tokenizer."""

O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # fmt: skip

num_reserved_special_tokens = 2048

_SPLIT_REGEX: Final = O200K_PATTERN

_model: TiktokenModel
_eos_token: str
special_tokens: dict[str, int]

def __init__(self, path: Path, custom_eos: str | None = "<|eot|>") -> None:
"""
:param path:
The path to the tiktoken BPE file.
:param custom_eos:
If not ``None``, replaces the original EOS token.
"""
self._eos_token = custom_eos or "<|end_of_text|>"

special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
assert len(set(special_tokens)) == len(special_tokens)
assert len(special_tokens) <= self.num_reserved_special_tokens

reserved_tokens = [
f"<|reserved_special_token_{i}|>"
for i in range(self.num_reserved_special_tokens - len(special_tokens))
]
special_tokens = special_tokens + reserved_tokens

self._model = TiktokenModel(
path,
split_regex=self._SPLIT_REGEX,
unk_token=None,
bos_token="<|begin_of_text|>",
eos_token=self._eos_token,
pad_token="<|finetune_right_pad|>",
boh_token="<|header_start|>",
eoh_token="<|header_end|>",
special_tokens=special_tokens,
)
self.special_tokens = self._model.special_tokens

@override
def create_encoder(
self,
*,
task: str | None = None,
lang: str | None = None,
mode: str | None = None,
device: Device | None = None,
pin_memory: bool = False,
) -> TiktokenEncoder:
if task is not None:
raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

if lang is not None:
raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

match mode:
case None | "default":
prefix_tokens = ["<|begin_of_text|>"]
suffix_tokens = [self._eos_token]
case "prompt":
prefix_tokens = ["<|begin_of_text|>"]
# In prompt mode, we expect the generator to finish the sequence.
suffix_tokens = []
case "prompt_response":
prefix_tokens = []
suffix_tokens = [self._eos_token]
case "as_is":
prefix_tokens = []
suffix_tokens = []
case _:
raise ValueError(
f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is"
)

return TiktokenEncoder(
self._model,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
device=device,
pin_memory=pin_memory,
)

@override
def create_raw_encoder(
self, *, device: Device | None = None, pin_memory: bool = False
) -> TiktokenEncoder:
return TiktokenEncoder(self._model, device=device, pin_memory=pin_memory)

@override
def create_decoder(self) -> TiktokenDecoder:
return TiktokenDecoder(self._model)

@property
@override
def vocab_info(self) -> VocabularyInfo:
return self._model.vocab_info


LLAMA4_TOKENIZER_FAMILY: Final = "llama4"


def load_llama4_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
try:
return LLaMA4Tokenizer(path)
except ValueError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' asset card does not contain a valid text tokenizer configuration of the '{LLAMA4_TOKENIZER_FAMILY}' family. See the nested exception for details." # fmt: skip
) from ex
except RuntimeError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' text tokenizer cannot be loaded. See the nested exception for details." # fmt: skip
) from ex
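
A minimal usage sketch of the new tokenizer (illustrative only; the tokenizer path below is hypothetical, and the encode call assumes the standard fairseq2 text tokenizer interface rather than anything specific to this PR):

    from pathlib import Path

    from fairseq2.data.text.tokenizers.llama4 import LLaMA4Tokenizer

    # Hypothetical local path to a Llama 4 tiktoken BPE file.
    tokenizer = LLaMA4Tokenizer(Path("/path/to/llama4/tokenizer.model"))

    # "prompt" mode prepends <|begin_of_text|> and leaves the sequence open
    # for the generator to finish; "default" would also append the EOS token.
    encoder = tokenizer.create_encoder(mode="prompt")
    token_indices = encoder("Hello, world!")

    # The special-token map exposed by this PR can be used to look up ids,
    # e.g. for the custom EOS token <|eot|>.
    eot_idx = tokenizer.special_tokens["<|eot|>"]
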
2 changes: 2 additions & 0 deletions src/fairseq2/data/text/tokenizers/tiktoken.py
@@ -29,6 +29,7 @@ class TiktokenModel:
_encoding: Encoding
_num_bpe_tokens: int
_vocab_info: VocabularyInfo
special_tokens: dict[str, int]

def __init__(
self,
@@ -68,6 +69,7 @@ def __init__(
}
else:
special_token_map = {}
self.special_tokens = special_token_map

self._encoding = Encoding(
name=path.stem,
45 changes: 32 additions & 13 deletions src/fairseq2/generation/_sampling/_generator.py
@@ -445,7 +445,7 @@ class _AbstractSamplingSequenceGeneratorOp(ABC):
_num_gens: int
_min_prompt_len: int
_max_prompt_len: int
_min_seq_len: int
_min_seq_lens: Tensor
_max_seq_len: int
_echo_prompt: bool
_compute_scores: bool
@@ -458,7 +458,7 @@ class _AbstractSamplingSequenceGeneratorOp(ABC):
_step_hooks: dict[int, StepHook]
_step_nr: int
_state_bag: IncrementalStateBag
_prompt_lens: Tensor | None
_prompt_lens: Tensor
_prompt_mask: Tensor | None
_prompt_indices: Tensor
_seqs: Tensor
@@ -505,6 +505,12 @@ def __init__(
if prompt_padding_mask is None:
self._min_prompt_len, min_prompt_idx = prompt_seqs.size(1), 0
self._max_prompt_len, max_prompt_idx = prompt_seqs.size(1), 0
prompt_seq_lens = torch.full(
(prompt_seqs.size(0),),
prompt_seqs.size(1),
dtype=prompt_seqs.dtype,
device=prompt_seqs.device,
)
else:
prompt_seq_lens = prompt_padding_mask.seq_lens

@@ -525,7 +531,7 @@
f"The length of `prompt_seqs[{int(max_prompt_idx)}]` must be less than `max_seq_len` ({max_seq_len}), but is {self._max_prompt_len} instead."
)

self._min_seq_len = min(max_seq_len, self._max_prompt_len + min_gen_len)
self._min_seq_lens = (prompt_seq_lens + min_gen_len).clamp(max=max_seq_len)
self._max_seq_len = min(max_seq_len, self._max_prompt_len + max_gen_len)

self._echo_prompt = echo_prompt
@@ -544,13 +550,12 @@ def __init__(
self._max_seq_len, capacity_increment=decode_capacity_increment
)

# (P)
self._prompt_lens = prompt_seq_lens

if prompt_padding_mask is None:
self._prompt_lens = None
self._prompt_mask = None
else:
# (P)
self._prompt_lens = prompt_padding_mask.seq_lens

# (P, S_prm)
self._prompt_mask = prompt_padding_mask.materialize()

@@ -649,8 +654,7 @@ def _prefill(self) -> None:

logits = model_output.logits

if self._temperature != 1.0:
logits /= self._temperature
self._apply_temperature(logits)

# (P, S_prm - 1, V)
probs = softmax(logits, dim=-1, dtype=torch.float32)
@@ -689,6 +693,7 @@ def _prefill(self) -> None:
self._counters.prefill_size += prefill_len * self._seqs.size(0)

def _step(self) -> bool:

# Generate the next step output.
model_output = self._decode(self._seqs[:, self._step_nr - 1 : self._step_nr])

@@ -698,8 +703,7 @@ def _step(self) -> bool:

logits = model_output.logits

if self._temperature != 1.0:
logits /= self._temperature
self._apply_temperature(logits)

# (N, 1, V)
probs = softmax(logits, dim=-1, dtype=torch.float32)
@@ -733,8 +737,10 @@ def _step(self) -> bool:
probs[:, self._pad_idx] = 0

# Do not allow EOS till we reach the minimum sequence length.
if self._step_nr < self._min_seq_len - 1:
probs[:, self._eos_idx] = 0
mask_min_len = self._step_nr < self._min_seq_lens - 1
mask_after_prompt = self._step_nr >= self._prompt_lens
do_not_eos_mask = mask_min_len & mask_after_prompt
probs[do_not_eos_mask, self._eos_idx] = 0

# (N)
vocab_indices = self._sampler.sample(probs)
@@ -853,6 +859,16 @@ def _finish_sequence(self, seq_idx: int) -> None:

self._output[prompt_idx].append(Hypothesis(seq, score, step_scores))

def _apply_temperature(self, logits: Tensor) -> None:
if self._temperature == 0.0:
max_indices = torch.argmax(logits, dim=-1) # (N, 1)
logits.fill_(float("-inf"))
# (N, 1, V)
logits.scatter_(-1, max_indices.unsqueeze(-1), 0.0)

elif self._temperature != 1.0:
logits.div_(self._temperature)

def _reorder_state(self, new_order: Tensor) -> None:
self._state_bag.reorder(new_order)

@@ -867,6 +883,9 @@ def _reorder_state(self, new_order: Tensor) -> None:
# (N) -> (N - F)
self._prompt_indices = self._prompt_indices.index_select(dim=0, index=new_order)

# (N) -> (N - F)
self._min_seq_lens = self._min_seq_lens.index_select(dim=0, index=new_order)

# (N, S) -> (N - F, S)
self._seqs = self._seqs.index_select(dim=0, index=new_order)

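
As a toy illustration of the per-prompt minimum sequence lengths introduced above (the values are hypothetical, not taken from the PR):

    import torch

    prompt_seq_lens = torch.tensor([5, 9, 3])  # hypothetical prompt lengths
    min_gen_len, max_seq_len = 4, 10

    # Each prompt must generate at least `min_gen_len` tokens, capped at `max_seq_len`.
    min_seq_lens = (prompt_seq_lens + min_gen_len).clamp(max=max_seq_len)
    print(min_seq_lens)  # tensor([ 9, 10,  7])
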
9 changes: 9 additions & 0 deletions src/fairseq2/models/llama/_config.py
@@ -96,6 +96,15 @@ class LLaMAConfig:
dropout_p: float = 0.0
"""The dropout probability on outputs of Transformer layers."""

use_qk_norm: bool = False
"""If ``True``, applies layer normalization to the projected query and key."""

nope_layer_interval: int | None = None
"""
If not ``None``, a NoPE layer (no positional embedding) is used instead
of a RoPE layer every ``nope_layer_interval`` layers.
"""


@dataclass
class LLaMARopeScalingConfig:
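
A minimal sketch of how ``nope_layer_interval`` could map to per-layer RoPE usage (the helper name and the 1-based interval convention are illustrative assumptions, not necessarily this PR's implementation):

    def uses_rope(layer_idx: int, nope_layer_interval: int | None) -> bool:
        """Illustrative helper: returns False for layers that skip positional embeddings."""
        if nope_layer_interval is None:
            return True
        # Assumed convention: every `nope_layer_interval`-th layer is a NoPE layer.
        return (layer_idx + 1) % nope_layer_interval != 0
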