Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c3e42da

Browse files
committed Mar 13, 2025
phi4 config and style fixes
1 parent dab36d2 commit c3e42da

14 files changed

+90
-39
lines changed
 

‎docs/source/api_ref_models.rst

+20
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,26 @@ To download the Qwen2 1.5B model, for example:
299299
qwen2.lora_qwen2_7b
300300
qwen2.qwen2_tokenizer
301301

302+
phi-4
303+
-----
304+
305+
Models from the `Phi-4 family <https://arxiv.org/abs/2412.08905>`_.
306+
307+
To download the Phi-4 instruct model:
308+
309+
.. code-block:: bash
310+
311+
tune download microsoft/phi-4 --hf-token <HF_TOKEN>
312+
313+
.. autosummary::
314+
:toctree: generated/
315+
:nosignatures:
316+
317+
phi4.phi4_14b
318+
phi4.lora_phi4_14b
319+
phi4.qlora_phi4_14b
320+
phi4.phi4_tokenizer
321+
302322
phi-3
303323
-----
304324

‎recipes/configs/phi4/evaluation.yaml

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@ checkpointer:
1414
_component_: torchtune.training.FullModelHFCheckpointer
1515
checkpoint_dir: /tmp/phi-4
1616
checkpoint_files: [
17-
model-00001-of-00002.safetensors,
18-
model-00002-of-00002.safetensors
17+
model-00001-of-00006.safetensors,
18+
model-00002-of-00006.safetensors,
19+
model-00003-of-00006.safetensors,
20+
model-00004-of-00006.safetensors,
21+
model-00005-of-00006.safetensors,
22+
model-00006-of-00006.safetensors,
1923
]
2024
recipe_checkpoint: null
2125
output_dir: ${output_dir}
22-
model_type: PHI3_MINI
26+
model_type: PHI4
2327
resume_from_checkpoint: False
2428

2529
# Tokenizer
2630
tokenizer:
27-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
31+
_component_: torchtune.models.phi4.phi4_tokenizer
2832
vocab_path: /tmp/phi-4/vocab.json
2933
merges_path: /tmp/phi-4/merges.txt
3034
max_seq_len: null

‎recipes/configs/phi4/full.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ model:
2525

2626
# Tokenizer
2727
tokenizer:
28-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
28+
_component_: torchtune.models.phi4.phi4_tokenizer
2929
vocab_path: /tmp/phi-4/vocab.json
3030
merges_path: /tmp/phi-4/merges.txt
3131
max_seq_len: null
@@ -44,7 +44,7 @@ checkpointer:
4444
]
4545
recipe_checkpoint: null
4646
output_dir: ${output_dir}
47-
model_type: PHI3_MINI
47+
model_type: PHI4
4848
resume_from_checkpoint: False
4949

5050
# Dataset

‎recipes/configs/phi4/full_low_memory.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ model:
2727

2828
# Tokenizer
2929
tokenizer:
30-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
30+
_component_: torchtune.models.phi4.phi4_tokenizer
3131
vocab_path: /tmp/phi-4/vocab.json
3232
merges_path: /tmp/phi-4/merges.txt
3333
max_seq_len: null
@@ -46,7 +46,7 @@ checkpointer:
4646
]
4747
recipe_checkpoint: null
4848
output_dir: ${output_dir}
49-
model_type: PHI3_MINI
49+
model_type: PHI4
5050
resume_from_checkpoint: False
5151

5252
# Dataset

‎recipes/configs/phi4/lora.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ model:
3131

3232
# Tokenizer
3333
tokenizer:
34-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
34+
_component_: torchtune.models.phi4.phi4_tokenizer
3535
vocab_path: /tmp/phi-4/vocab.json
3636
merges_path: /tmp/phi-4/merges.txt
3737
max_seq_len: null
@@ -50,7 +50,7 @@ checkpointer:
5050
]
5151
recipe_checkpoint: null
5252
output_dir: ${output_dir}
53-
model_type: PHI3_MINI
53+
model_type: PHI4
5454
resume_from_checkpoint: False
5555
save_adapter_weights_only: False
5656

‎recipes/configs/phi4/lora_single_device.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ model:
2929

3030
# Tokenizer
3131
tokenizer:
32-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
32+
_component_: torchtune.models.phi4.phi4_tokenizer
3333
vocab_path: /tmp/phi-4/vocab.json
3434
merges_path: /tmp/phi-4/merges.txt
3535
max_seq_len: null
@@ -48,7 +48,7 @@ checkpointer:
4848
]
4949
recipe_checkpoint: null
5050
output_dir: ${output_dir}
51-
model_type: PHI3_MINI
51+
model_type: PHI4
5252
resume_from_checkpoint: False
5353
save_adapter_weights_only: False
5454

‎recipes/configs/phi4/qlora_single_device.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ model:
2929

3030
# Tokenizer
3131
tokenizer:
32-
_component_: torchtune.models.phi4.phi4_14b_tokenizer
32+
_component_: torchtune.models.phi4.phi4_tokenizer
3333
vocab_path: /tmp/phi-4/vocab.json
3434
merges_path: /tmp/phi-4/merges.txt
3535
max_seq_len: null
@@ -48,7 +48,7 @@ checkpointer:
4848
]
4949
recipe_checkpoint: null
5050
output_dir: ${output_dir}
51-
model_type: PHI3_MINI
51+
model_type: PHI4
5252
resume_from_checkpoint: False
5353
save_adapter_weights_only: False
5454

‎tests/torchtune/models/phi4/test_phi4_tokenizer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
from tests.common import ASSETS
1212
from torchtune.data import Message
13-
from torchtune.models.phi4 import phi4_14b_tokenizer
13+
from torchtune.models.phi4 import phi4_tokenizer
1414

1515

16-
class TestPhi4MiniTokenizer:
16+
class TestPhi4Tokenizer:
1717
@pytest.fixture
1818
def tokenizer(self):
1919
# GPT2BaseTokenizer
20-
return phi4_14b_tokenizer(
20+
return phi4_tokenizer(
2121
vocab_path=(ASSETS / "vocab.json"),
2222
merges_path=(ASSETS / "merges.txt"),
2323
)

‎torchtune/_recipe_registry.py

+4
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,10 @@ class Recipe:
522522
name="gemma/evaluation",
523523
file_path="gemma/evaluation.yaml",
524524
),
525+
Config(
526+
name="phi4/evaluation",
527+
file_path="phi4/evaluation.yaml",
528+
),
525529
Config(
526530
name="phi3/evaluation",
527531
file_path="phi3/evaluation.yaml",

‎torchtune/models/phi4/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._model_builders import lora_phi4_14b, phi4_14b, phi4_14b_tokenizer # noqa
7+
from ._model_builders import lora_phi4_14b, phi4_14b, phi4_tokenizer # noqa
88

99
__all__ = [
1010
"phi4_14b",
11-
"phi4_14b_tokenizer",
11+
"phi4_tokenizer",
1212
"lora_phi4_14b",
1313
]

‎torchtune/models/phi4/_model_builders.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
from functools import partial
12
from typing import List, Optional
23

3-
from torchtune.models.phi3._component_builders import phi3, lora_phi3
4-
from torchtune.models.phi4._tokenizer import Phi4MiniTokenizer
4+
from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType
5+
6+
from torchtune.models.phi3._component_builders import lora_phi3, phi3
7+
from torchtune.models.phi4._tokenizer import Phi4Tokenizer
58

69
from torchtune.modules import TransformerDecoder
710
from torchtune.modules.peft import LORA_ATTN_MODULES
8-
from functools import partial
911
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
10-
from torchtune.data._prompt_templates import _TemplateType
11-
from torchtune.data._prompt_templates import _get_prompt_template
1212

1313

1414
"""
@@ -36,13 +36,21 @@ def phi4_14b() -> TransformerDecoder:
3636
norm_eps=1e-5,
3737
)
3838

39-
def phi4_14b_tokenizer(vocab_path: str = None, merges_path: str = None, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None, truncation_type: str = "right") -> Phi4MiniTokenizer:
40-
"""Phi4 (14B) tokenizer.
39+
40+
def phi4_tokenizer(
41+
vocab_path: str = None,
42+
merges_path: str = None,
43+
special_tokens_path: Optional[str] = None,
44+
max_seq_len: Optional[int] = None,
45+
prompt_template: Optional[_TemplateType] = None,
46+
truncation_type: str = "right",
47+
) -> Phi4Tokenizer:
48+
"""Phi4 tokenizer.
4149
Args:
4250
vocab_path (str): Path to vocab.json.
4351
merges_path (str): Path to merges.txt.
4452
special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face
45-
model files that contains all registered special tokens, or a local json file
53+
model files that contains all registered special tokens, or a local json file
4654
structured similarly. Default is None to use the canonical Phi4 special tokens.
4755
max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages,
4856
after which the input will be truncated. Default is None.
@@ -54,11 +62,24 @@ def phi4_14b_tokenizer(vocab_path: str = None, merges_path: str = None, special_
5462
Default is "right".
5563
5664
Returns:
57-
Phi4MiniTokenizer: Instantiation of the Phi-4 (14B) tokenizer.
65+
Phi4Tokenizer: Instantiation of the Phi-4 (14B) tokenizer.
5866
"""
59-
special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None
60-
template = _get_prompt_template(prompt_template) if prompt_template is not None else None
61-
return Phi4MiniTokenizer(vocab_path=vocab_path, merges_path=merges_path, special_tokens=special_tokens, max_seq_len=max_seq_len, prompt_template=template, truncation_type=truncation_type)
67+
special_tokens = (
68+
parse_hf_tokenizer_json(special_tokens_path)
69+
if special_tokens_path is not None
70+
else None
71+
)
72+
template = (
73+
_get_prompt_template(prompt_template) if prompt_template is not None else None
74+
)
75+
return Phi4Tokenizer(
76+
vocab_path=vocab_path,
77+
merges_path=merges_path,
78+
special_tokens=special_tokens,
79+
max_seq_len=max_seq_len,
80+
prompt_template=template,
81+
truncation_type=truncation_type,
82+
)
6283

6384

6485
def lora_phi4_14b(

‎torchtune/models/phi4/_tokenizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa
3838

3939

40-
class Phi4MiniTokenizer(ModelTokenizer, Transform):
40+
class Phi4Tokenizer(ModelTokenizer, Transform):
4141
"""
4242
TikToken tokenizer configured with Phi4 (14B) special tokens.
4343

‎torchtune/training/checkpointing/_checkpointer.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -530,10 +530,10 @@ def load_checkpoint(self) -> Dict[str, Any]:
530530
# delete the state_dict to free up memory; TODO check if this del is needed
531531
del state_dict
532532
gc.collect()
533-
if self._model_type == ModelType.PHI3_MINI:
533+
if self._model_type in (ModelType.PHI3_MINI, ModelType.PHI4):
534534
log_rank_zero(
535535
logger=logger,
536-
msg="Converting Phi-3 Mini weights from HF format."
536+
msg="Converting Phi weights from HF format."
537537
"Note that conversion of adapter weights into PEFT format is not supported.",
538538
)
539539
from torchtune.models.phi3._convert_weights import phi3_hf_to_tune
@@ -661,7 +661,7 @@ def save_checkpoint(
661661
"""
662662
# convert the state_dict back to hf format; do this inplace
663663
if not adapter_only:
664-
if self._model_type == ModelType.PHI3_MINI:
664+
if self._model_type in (ModelType.PHI3_MINI, ModelType.PHI4):
665665
from torchtune.models.phi3._convert_weights import phi3_tune_to_hf
666666

667667
state_dict[training.MODEL_KEY] = phi3_tune_to_hf(
@@ -817,9 +817,9 @@ def save_checkpoint(
817817
f"saved to {output_path}"
818818
)
819819

820-
if self._model_type == ModelType.PHI3_MINI:
820+
if self._model_type in (ModelType.PHI3_MINI, ModelType.PHI4):
821821
logger.warning(
822-
"Saving Phi-3 Mini adapter weights to PEFT format is not supported, saving to torchtune format instead"
822+
"Saving Phi adapter weights to PEFT format is not supported, saving to torchtune format instead"
823823
)
824824
elif self._model_type == ModelType.LLAMA3_VISION:
825825
logger.warning(
@@ -860,9 +860,9 @@ def save_checkpoint(
860860
)
861861

862862
if training.ADAPTER_CONFIG in state_dict:
863-
if self._model_type == ModelType.PHI3_MINI:
863+
if self._model_type in (ModelType.PHI3_MINI, ModelType.PHI4):
864864
logger.warning(
865-
"PEFT integration for Phi-3 Mini is not supported, skipping adapter config save"
865+
"PEFT integration for Phi is not supported, skipping adapter config save"
866866
)
867867
elif self._model_type == ModelType.LLAMA3_VISION:
868868
logger.warning(

‎torchtune/training/checkpointing/_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class ModelType(Enum):
9090
LLAMA3_VISION (str): LLama3 vision family of models. See :func:`~torchtune.models.llama3_2_vision.llama3_2_vision_decoder`
9191
MISTRAL (str): Mistral family of models. See :func:`~torchtune.models.mistral.mistral`
9292
PHI3_MINI (str): Phi-3 family of models. See :func:`~torchtune.models.phi3.phi3`
93+
PHI4 (str): Phi-4 family of models. See :func:`~torchtune.models.phi4.phi4_14b`
9394
REWARD (str): A Llama2, Llama3, or Mistral model with a classification head projecting
9495
to a single class for reward modelling.
9596
See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b`
@@ -113,6 +114,7 @@ class ModelType(Enum):
113114
LLAMA3_VISION: str = "llama3_vision"
114115
MISTRAL: str = "mistral"
115116
PHI3_MINI: str = "phi3_mini"
117+
PHI4: str = "phi4"
116118
REWARD: str = "reward"
117119
QWEN2: str = "qwen2"
118120
CLIP_TEXT: str = "clip_text"

0 commit comments

Comments
 (0)
Please sign in to comment.