qwen2.5 modeling support + conversion back to hf ckpt format #1107

Merged · 13 commits · May 30, 2025
168 changes: 168 additions & 0 deletions src/fairseq2/assets/cards/models/qwen.yaml
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

name: qwen25_7b
model_family: qwen
model_arch: qwen25_7b
checkpoint: "hg://qwen/qwen2.5-7b"
tokenizer: "hg://qwen/qwen2.5-7b"
tokenizer_family: qwen

---

name: qwen25_7b_instruct
model_family: qwen
model_arch: qwen25_7b
model_config:
  _set_:
    max_seq_len: 32768
checkpoint: "hg://qwen/qwen2.5-7b-instruct"
tokenizer: "hg://qwen/qwen2.5-7b-instruct"
tokenizer_family: qwen
use_im_end: true

---

name: qwen25_14b
model_family: qwen
model_arch: qwen25_14b
checkpoint: "hg://qwen/qwen2.5-14b"
tokenizer: "hg://qwen/qwen2.5-14b"
tokenizer_family: qwen

---

name: qwen25_14b_instruct
model_family: qwen
model_arch: qwen25_14b
checkpoint: "hg://qwen/qwen2.5-14b-instruct"
tokenizer: "hg://qwen/qwen2.5-14b-instruct"
tokenizer_family: qwen
use_im_end: true

---

name: qwen25_32b
model_family: qwen
model_arch: qwen25_32b
checkpoint: "hg://qwen/qwen2.5-32b"
tokenizer: "hg://qwen/qwen2.5-32b"
tokenizer_family: qwen

---

name: qwen25_32b_instruct
model_family: qwen
model_arch: qwen25_32b
checkpoint: "hg://qwen/qwen2.5-32b-instruct"
tokenizer: "hg://qwen/qwen2.5-32b-instruct"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_0.6b
model_family: qwen
model_arch: qwen3_0.6b
checkpoint: "hg://qwen/qwen3-0.6b-base"
tokenizer: "hg://qwen/qwen3-0.6b-base"
tokenizer_family: qwen

---

name: qwen3_0.6b_instruct
model_family: qwen
model_arch: qwen3_0.6b
checkpoint: "hg://qwen/qwen3-0.6b"
tokenizer: "hg://qwen/qwen3-0.6b"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_1.7b
model_family: qwen
model_arch: qwen3_1.7b
checkpoint: "hg://qwen/qwen3-1.7b-base"
tokenizer: "hg://qwen/qwen3-1.7b-base"
tokenizer_family: qwen

---

name: qwen3_1.7b_instruct
model_family: qwen
model_arch: qwen3_1.7b
checkpoint: "hg://qwen/qwen3-1.7b"
tokenizer: "hg://qwen/qwen3-1.7b"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_4b
model_family: qwen
model_arch: qwen3_4b
checkpoint: "hg://qwen/qwen3-4b-base"
tokenizer: "hg://qwen/qwen3-4b-base"
tokenizer_family: qwen

---

name: qwen3_4b_instruct
model_family: qwen
model_arch: qwen3_4b
checkpoint: "hg://qwen/qwen3-4b"
tokenizer: "hg://qwen/qwen3-4b"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_8b
model_family: qwen
model_arch: qwen3_8b
checkpoint: "hg://qwen/qwen3-8b-base"
tokenizer: "hg://qwen/qwen3-8b-base"
tokenizer_family: qwen

---

name: qwen3_8b_instruct
model_family: qwen
model_arch: qwen3_8b
checkpoint: "hg://qwen/qwen3-8b"
tokenizer: "hg://qwen/qwen3-8b"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_14b
model_family: qwen
model_arch: qwen3_14b
checkpoint: "hg://qwen/qwen3-14b-base"
tokenizer: "hg://qwen/qwen3-14b-base"
tokenizer_family: qwen

---

name: qwen3_14b_instruct
model_family: qwen
model_arch: qwen3_14b
checkpoint: "hg://qwen/qwen3-14b"
tokenizer: "hg://qwen/qwen3-14b"
tokenizer_family: qwen
use_im_end: true

---

name: qwen3_32b_instruct
model_family: qwen
model_arch: qwen3_32b
checkpoint: "hg://qwen/qwen3-32b"
tokenizer: "hg://qwen/qwen3-32b"
tokenizer_family: qwen
use_im_end: true
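
Each `---`-separated document above is one asset card; the `hg://` scheme points at a Hugging Face Hub repository, and `model_config: _set_:` overrides individual fields of the named architecture. A minimal sketch of reading such a multi-document card file (assuming only PyYAML and a local path; fairseq2's asset store does the real resolution and validation):

```python
# Sketch: parse the multi-document card file and resolve an "hg://" checkpoint
# URI to a Hugging Face repo id. The file path is an assumption.
from pathlib import Path

import yaml

CARD_FILE = Path("src/fairseq2/assets/cards/models/qwen.yaml")

cards: dict[str, dict] = {}

with CARD_FILE.open() as fp:
    for doc in yaml.safe_load_all(fp):
        if doc:  # skip empty documents between `---` separators
            cards[doc["name"]] = doc

card = cards["qwen25_7b_instruct"]

repo_id = card["checkpoint"].removeprefix("hg://")  # "qwen/qwen2.5-7b-instruct"
max_seq_len = card["model_config"]["_set_"]["max_seq_len"]  # 32768

print(repo_id, max_seq_len)
```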
8 changes: 0 additions & 8 deletions src/fairseq2/data/text/tokenizers/llama.py
@@ -254,10 +254,6 @@ def load_llama3_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
eoh_token="<|end_header_id|>",
special_tokens=special_tokens,
)
except ValueError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' asset card does not contain a valid text tokenizer configuration of the '{LLAMA_TOKENIZER_FAMILY}' family. See the nested exception for details." # fmt: skip
) from ex
except (OSError, RuntimeError) as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' text tokenizer model cannot be loaded. See the nested exception for details." # fmt: skip
@@ -286,10 +282,6 @@ def load_llama3_hg_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
boh_token="<|start_header_id|>",
eoh_token="<|end_header_id|>",
)
except ValueError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' asset card does not contain a valid text tokenizer configuration of the '{LLAMA_TOKENIZER_FAMILY}' family. See the nested exception for details." # fmt: skip
) from ex
except (OSError, RuntimeError) as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' text tokenizer model cannot be loaded. See the nested exception for details." # fmt: skip
131 changes: 131 additions & 0 deletions src/fairseq2/data/text/tokenizers/qwen.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from pathlib import Path
from typing import Final, final

from typing_extensions import override

from fairseq2.assets import AssetCard, AssetCardError, AssetCardFieldNotFoundError
from fairseq2.data import VocabularyInfo
from fairseq2.data.text.tokenizers import (
    TextTokenDecoder,
    TextTokenEncoder,
    TextTokenizer,
    TextTokenizerLoadError,
    text_tokenizer_asset_card_error,
)
from fairseq2.data.text.tokenizers.hg import (
    HuggingFaceTokenDecoder,
    HuggingFaceTokenEncoder,
    HuggingFaceTokenModel,
    load_hg_token_model,
)
from fairseq2.device import Device


@final
class QwenTokenizer(TextTokenizer):
    """Represents a Qwen tokenizer."""

    _model: HuggingFaceTokenModel
    _eos_token: str

    def __init__(self, model: HuggingFaceTokenModel, eos_token: str) -> None:
        self._model = model

        self._eos_token = eos_token

    @override
    def create_encoder(
        self,
        *,
        task: str | None = None,
        lang: str | None = None,
        mode: str | None = None,
        device: Device | None = None,
        pin_memory: bool = False,
    ) -> TextTokenEncoder:
        if task is not None:
            raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

        if lang is not None:
            raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

        match mode:
            case None | "default":
                suffix_tokens = [self._eos_token]
            case "prompt":
                # In prompt mode, we expect the generator to finish the sequence.
                suffix_tokens = []
            case "prompt_response":
                suffix_tokens = [self._eos_token]
            case "as_is":
                suffix_tokens = []
            case _:
                raise ValueError(
                    f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is"
                )

        return HuggingFaceTokenEncoder(
            self._model,
            prefix_tokens=[],
            suffix_tokens=suffix_tokens,
            device=device,
            pin_memory=pin_memory,
        )

    @override
    def create_raw_encoder(
        self, *, device: Device | None = None, pin_memory: bool = False
    ) -> TextTokenEncoder:
        return HuggingFaceTokenEncoder(
            self._model, device=device, pin_memory=pin_memory
        )

    @override
    def create_decoder(self, *, skip_special_tokens: bool = False) -> TextTokenDecoder:
        return HuggingFaceTokenDecoder(
            self._model, skip_special_tokens=skip_special_tokens
        )

    @property
    @override
    def vocab_info(self) -> VocabularyInfo:
        return self._model.vocab_info


QWEN_TOKENIZER_FAMILY: Final = "qwen"


def load_qwen_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
    try:
        use_im_end = card.field("use_im_end").as_(bool)
    except AssetCardFieldNotFoundError:
        use_im_end = False
    except AssetCardError as ex:
        raise text_tokenizer_asset_card_error(card.name) from ex

    eos_token = "<|im_end|>" if use_im_end else "<|endoftext|>"

    try:
        model = load_hg_token_model(
            path,
            unk_token=None,
            bos_token=None,
            eos_token=eos_token,
            pad_token="<|endoftext|>",
            boh_token=None,
            eoh_token=None,
        )
    except (OSError, RuntimeError) as ex:
        raise TextTokenizerLoadError(
            card.name, f"The '{card.name}' text tokenizer model cannot be loaded. See the nested exception for details."  # fmt: skip
        ) from ex

    return QwenTokenizer(model, eos_token)
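
A hypothetical usage sketch of the loader and encoder modes above; `path` and `card` are assumptions (in practice both are supplied by the asset store):

```python
from pathlib import Path

# Assumed: `card` is the resolved "qwen25_7b_instruct" asset card, and `path`
# points at a locally downloaded copy of the Hugging Face tokenizer files.
path = Path("/data/tokenizers/qwen2.5-7b-instruct")

tokenizer = load_qwen_tokenizer(path, card)

# "prompt" mode appends no suffix, leaving the sequence open for generation.
prompt_encoder = tokenizer.create_encoder(mode="prompt")

# "default" mode appends the EOS token: "<|im_end|>" when the card sets
# `use_im_end`, "<|endoftext|>" otherwise.
default_encoder = tokenizer.create_encoder(mode="default")

decoder = tokenizer.create_decoder(skip_special_tokens=True)
print(decoder(default_encoder("Hello, world!")))
```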
2 changes: 1 addition & 1 deletion src/fairseq2/models/jepa/_factory.py
@@ -226,7 +226,7 @@ def create_self_attention(self, layer_idx: int) -> MultiheadAttention:
        return StandardMultiheadAttention(
            config.model_dim,
            config.num_encoder_attn_heads,
-            sdpa=sdpa,
+            sdpa,
            bias=config.qkv_bias,
            output_proj=output_proj,
        )
2 changes: 1 addition & 1 deletion src/fairseq2/models/jepa/classifier/_factory.py
@@ -124,7 +124,7 @@ def create_cross_attention(self) -> MultiheadAttention:
        return StandardMultiheadAttention(
            encoder_config.model_dim,
            encoder_config.num_encoder_attn_heads,
-            sdpa=sdpa,
+            sdpa,
            bias=encoder_config.qkv_bias,
            output_proj=output_proj,
        )
8 changes: 7 additions & 1 deletion src/fairseq2/models/llama/_checkpoint.py
@@ -69,7 +69,13 @@ def permute_rotary(w: Tensor, num_heads: int) -> Tensor:
            # fmt: on
        }

-        return convert_checkpoint(checkpoint, key_map)
+        checkpoint = convert_checkpoint(checkpoint, key_map)
+
+        # Safetensors does not support shared tensors.
+        if config.tie_embeddings:
+            checkpoint["final_proj.weight"] = checkpoint["decoder_frontend.embed.weight"]  # fmt: skip
+
+        return checkpoint

    if "tok_embeddings.weight" in checkpoint:  # reference
        key_map = {
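
Context for the `tie_embeddings` branch above: safetensors refuses to serialize tensors that share storage, which is why tied Hugging Face checkpoints typically ship without a separate `lm_head.weight`, and the output projection has to be restored from the embedding on load. A minimal repro of that limitation (assumes `torch` and `safetensors` are installed):

```python
import torch
from safetensors.torch import save_file

embed = torch.randn(8, 4)

# Tied checkpoint: both keys alias the same storage.
state = {"decoder_frontend.embed.weight": embed, "final_proj.weight": embed}

try:
    save_file(state, "/tmp/ckpt.safetensors")
except RuntimeError as err:
    print(err)  # "Some tensors share memory ...": the save is rejected.

# Breaking the aliasing with a copy makes the file writable again.
state["final_proj.weight"] = embed.clone()
save_file(state, "/tmp/ckpt.safetensors")
```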
9 changes: 7 additions & 2 deletions src/fairseq2/models/llama/_config.py
@@ -79,6 +79,9 @@ class LLaMAConfig:
    encoder, aiming to increase the context length.
    """

+    dropout_p: float = 0.0
+    """The dropout probability on outputs of Transformer layers."""
+
    init_std: float | None = None
    """
    If not ``None``, the standard deviation to initialize input embeddings and
@@ -93,8 +96,8 @@ class LLaMAConfig:
    the decoder.
    """

-    dropout_p: float = 0.0
-    """The dropout probability on outputs of Transformer layers."""
+    shard_embed_dim: bool = True
+    """If ``True``, shards the embedding dimension for tensor parallelism."""


@dataclass
@@ -201,6 +204,7 @@ def llama3_8b() -> LLaMAConfig:
    config.ffn_inner_dim_multiplier = 1.3
    config.ffn_inner_dim_multiple_of = 1024
    config.rope_theta = 500_000.0
+    config.shard_embed_dim = False

    return config

@@ -212,6 +216,7 @@ def llama3_70b() -> LLaMAConfig:
    config.vocab_size = 128_256
    config.pad_idx = 128_004
    config.rope_theta = 500_000.0
+    config.shard_embed_dim = False

    return config

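A rough illustration of what the new `shard_embed_dim` flag selects between (a sketch, not fairseq2's tensor-parallel code; sizes match `llama3_8b`):

```python
import torch

vocab_size, model_dim, tp_size = 128_256, 4_096, 8

weight = torch.randn(vocab_size, model_dim)

# shard_embed_dim=True: each rank keeps a slice of the embedding columns.
dim_shards = weight.chunk(tp_size, dim=1)    # 8 x (128_256, 512)

# shard_embed_dim=False (set for LLaMA 3 above): each rank keeps a slice
# of the vocabulary rows instead.
vocab_shards = weight.chunk(tp_size, dim=0)  # 8 x (16_032, 4_096)

print(dim_shards[0].shape, vocab_shards[0].shape)
```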