[BUG: MistralCommonTokenizer from transformers is not supported by trl SFT] #148

@DzmitryPihulski

Description

Python -VV

Python 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]

Pip Freeze

absl-py==2.3.1
accelerate==1.11.0
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.12.14
aiosignal==1.4.0
airportsdata==20250706
annotated-doc==0.0.3
annotated-types==0.7.0
anyio==4.9.0
astor==0.8.1
async-timeout==5.0.1
attrs==25.4.0
Authlib==1.6.5
bitsandbytes==0.48.1
blake3==1.0.5
Brotli==1.1.0
cachetools==6.1.0
cbor2==5.7.1
certifi==2025.10.5
cffi==2.0.0
chardet==5.2.0
charset-normalizer==3.4.4
click==8.2.1
cloudpickle==3.1.1
colorama==0.4.6
compressed-tensors==0.10.2
cryptography==46.0.3
cupy-cuda12x==13.5.1
DataProperty==1.1.0
datasets==4.0.0
deepspeed==0.18.1
depyf==0.18.0
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
dnspython==2.8.0
einops==0.8.1
email_validator==2.2.0
evaluate==0.4.6
exceptiongroup==1.3.0
fastapi==0.116.1
fastapi-cli==0.0.8
fastapi-cloud-cli==0.1.4
fastrlock==0.8.3
ffmpy==0.6.4
filelock==3.18.0
flash_attn==2.8.3
frozendict==2.4.6
frozenlist==1.7.0
fsspec==2025.3.0
gguf==0.17.1
gitdb==4.0.12
GitPython==3.1.45
gradio==5.49.1
gradio_client==1.13.3
groovy==0.1.2
grpcio==1.76.0
h11==0.16.0
hf-xet==1.1.5
hjson==3.1.0
httpcore==1.0.9
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.33.4
idna==3.11
interegular==0.3.3
itsdangerous==2.2.0
Jinja2==3.1.6
jiter==0.10.0
joblib==1.5.2
jsonlines==4.0.0
jsonschema==4.24.1
jsonschema-specifications==2025.9.1
lark==1.2.2
liger_kernel==0.6.2
llguidance==0.7.30
llvmlite==0.44.0
lm-format-enforcer==0.10.11
lm_eval==0.4.9.1
lxml==6.0.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mbstrdecoder==1.1.4
mdurl==0.1.2
mistral_common==1.8.1
more-itertools==10.8.0
mpmath==1.3.0
msgpack==1.1.1
msgspec==0.19.0
multidict==6.6.3
multiprocess==0.70.16
narwhals==2.9.0
nest-asyncio==1.6.0
networkx==3.3
ninja==1.11.1.4
nltk==3.9.2
numba==0.61.2
numexpr==2.14.1
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
openai==1.90.0
openai-harmony==0.0.4
opencv-python-headless==4.12.0.88
orjson==3.11.4
outlines==0.1.11
outlines_core==0.1.26
packaging==25.0
pandas==2.3.1
partial-json-parser==0.2.1.1.post6
pathvalidate==3.3.1
peft==0.17.1
pillow==11.3.0
platformdirs==4.5.0
plotly==6.3.1
portalocker==3.2.0
prometheus-fastapi-instrumentator==7.1.0
prometheus_client==0.22.1
propcache==0.3.2
protobuf==6.31.1
psutil==7.1.2
py-cpuinfo==9.0.0
pyarrow==21.0.0
pybase64==1.4.1
pybind11==3.0.1
pycountry==24.6.1
pycparser==2.23
pydantic==2.11.7
pydantic-extra-types==2.10.5
pydantic_core==2.33.2
pydub==0.25.1
Pygments==2.19.2
pytablewriter==1.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.1.1
python-json-logger==3.3.0
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.3
pyzmq==27.0.0
ray==2.47.1
referencing==0.37.0
regex==2024.11.6
requests==2.32.4
rich==14.0.0
rich-toolkit==0.14.8
rignore==0.6.2
rouge_score==0.1.2
rpds-py==0.28.0
ruff==0.14.2
sacrebleu==2.5.1
safehttpx==0.1.7
safetensors==0.5.3
scikit-learn==1.7.2
scipy==1.15.3
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.33.0
setproctitle==1.3.7
shellingham==1.5.4
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soundfile==0.13.1
soxr==1.0.0
sqlitedict==2.1.0
starlette==0.47.1
sympy==1.14.0
tabledata==1.3.4
tabulate==0.9.0
tcolorpy==0.1.7
threadpoolctl==3.6.0
tiktoken==0.9.0
tokenizers==0.21.2
tomlkit==0.13.3
torch==2.7.0
torchaudio==2.7.0
torchvision==0.22.0
tqdm==4.67.1
tqdm-multiprocess==0.0.11
trackio==0.7.0
transformers==4.53.2
triton==3.3.0
trl==0.24.0
typepy==1.3.4
typer==0.16.0
typing-inspection==0.4.1
typing_extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
uvloop==0.21.0
vllm==0.9.2
wandb==0.22.2
watchfiles==1.1.0
websockets==15.0.1
word2number==1.1
xformers==0.0.30
xgrammar==0.1.19
xxhash==3.5.0
yarl==1.20.1
zstandard==0.25.0

Reproduction Steps

  1. Run inside a SLURM job, using an Apptainer .sif image.
  2. Try to run SFTTrainer with mistralai/Magistral-Small-2509:

import torch
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from trl import SFTConfig, SFTTrainer

...
model_name = "mistralai/Magistral-Small-2509"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=True,
    tokenizer_type="mistral",
)

model = Mistral3ForConditionalGeneration.from_pretrained(
    model_name,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

sft_config = SFTConfig(...)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=sft_config,
    ...
)

Output:

...
INFO:mistral_common.tokens.tokenizers.tekken:Vocab size: 150000
INFO:mistral_common.tokens.tokenizers.tekken:Cutting vocab to first 130072 tokens.
...
Traceback (most recent call last):
  File "/lustre/tmp/slurm/4191161/mount/src/training.py", line 167, in <module>
    main()
  File "/lustre/tmp/slurm/4191161/mount/src/training.py", line 155, in main
    launch_sft_train(
  File "/lustre/tmp/slurm/4191161/mount/src/training.py", line 101, in launch_sft_train
    trainer = SFTTrainer(
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py", line 626, in __init__
    raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
TypeError: The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`
srun: error: r12-10: task 0: Exited with exit code 1
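
The failing check can also be reproduced in isolation, without SLURM or loading the model weights. A minimal sketch, assuming mistral_common is installed and AutoTokenizer resolves this repository to MistralCommonTokenizer as in the traceback above:

from transformers import AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin

# Load only the tokenizer; the isinstance check below mirrors the one in
# trl/trainer/sft_trainer.py that raises the TypeError.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Magistral-Small-2509")

print(type(tokenizer).__name__)  # expected: MistralCommonTokenizer
print(isinstance(tokenizer, (PreTrainedTokenizerBase, ProcessorMixin)))  # expected: False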

Expected Behavior

Expected the training to start with the provided model and tokenizer.

Additional Context

The tokenizer class for "mistralai/Magistral-Small-2509" is implemented in https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_mistral_common.py
But it looks like it does not inherit from either of the expected classes, PreTrainedTokenizerBase or ProcessorMixin, only from PushToHubMixin.
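
A quick way to confirm this is to inspect the class hierarchy directly. A sketch, assuming the class is importable under the module path linked above in transformers 4.53.2:

from transformers import PreTrainedTokenizerBase
from transformers.tokenization_mistral_common import MistralCommonTokenizer

# Show the full method resolution order; PushToHubMixin appears,
# PreTrainedTokenizerBase does not.
print([cls.__name__ for cls in MistralCommonTokenizer.__mro__])
print(issubclass(MistralCommonTokenizer, PreTrainedTokenizerBase))  # expected: False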

So the main question here is: why is it implemented that way?

As I understand it, this issue was also spotted by Unsloth in their version of the SFT code: they use their own copy of the model, https://huggingface.co/unsloth/Magistral-Small-2509-GGUF, with a tokenizer created as transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.

That tokenizer works fine with the model; as I understand it, it preserves the text → token_id mapping of the original tokenizer.

Suggested Solutions

The tokenizer class could perhaps be rewritten to inherit from a common base class (e.g. PreTrainedTokenizerBase), so that it becomes usable with other tooling such as SFT.
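
Alternatively, trl could relax its check to accept this tokenizer explicitly. A hypothetical sketch of such a relaxation (not actual trl code; the import guard covers transformers versions that do not ship the class):

from transformers import PreTrainedTokenizerBase, ProcessorMixin

try:
    from transformers.tokenization_mistral_common import MistralCommonTokenizer
    _ALLOWED_PROCESSING_CLASSES = (PreTrainedTokenizerBase, ProcessorMixin, MistralCommonTokenizer)
except ImportError:
    _ALLOWED_PROCESSING_CLASSES = (PreTrainedTokenizerBase, ProcessorMixin)

def check_processing_class(processing_class):
    # Hypothetical helper mirroring the check in trl/trainer/sft_trainer.py,
    # extended to also accept MistralCommonTokenizer.
    if not isinstance(processing_class, _ALLOWED_PROCESSING_CLASSES):
        raise TypeError(
            "The `processing_class` must be a `PreTrainedTokenizerBase`, "
            "`ProcessorMixin`, or `MistralCommonTokenizer`"
        )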
