-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy pathtest_marin_tokenizer.py
More file actions
40 lines (30 loc) · 1.29 KB
/
test_marin_tokenizer.py
File metadata and controls
40 lines (30 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
import tempfile
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizer
from experiments.create_marin_tokenizer import (
create_marin_tokenizer,
load_llama3_tokenizer,
special_tokens_injection_check,
)
@pytest.fixture
def marin_tokenizer():
    """Provide a marin tokenizer that has survived a save/load roundtrip.

    The base llama3 tokenizer lives in a gated Hugging Face repo. When the
    current environment lacks credentials (or network access), skip rather
    than fail - this test exercises our tokenizer surgery, not HF auth.
    """
    try:
        base_tokenizer = load_llama3_tokenizer()
    except Exception as e:
        pytest.skip(f"Llama 3 tokenizer is unavailable (gated repo or no network): {e}")

    built = create_marin_tokenizer(base_tokenizer)

    # Write to disk and read back so the fixture reflects exactly what a
    # downstream consumer would load from a saved checkpoint.
    with tempfile.TemporaryDirectory() as save_dir:
        built.save_pretrained(save_dir)
        reloaded = AutoTokenizer.from_pretrained(save_dir, local_files_only=True)
    return reloaded
def test_special_tokens_injection(marin_tokenizer: PreTrainedTokenizer) -> None:
    """Test that special tokens are correctly replaced."""
    # Delegates all assertions to the project helper; presumably it raises
    # (or asserts) when the expected marin special tokens are missing from
    # the roundtripped tokenizer — verify against create_marin_tokenizer.
    special_tokens_injection_check(marin_tokenizer)