Commit df2eecf

add tools and reasoning to marin tokenizer template (#1876)
* enhance ChatProcessor to support system prompts and per-example chat template kwargs, needed to support smoltalk and others
* pre-commit
* tests
* pre-commit
* bleck
* oops
* a more complete chat template for marin
* uv
* logging
* simplify simplify
* remap thinking tokens
* more datasets
* tweak chat template to support tool calls
* correct place
* loss masking
1 parent fc6bdd3 commit df2eecf
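
The headline change is a chat template that accepts per-example kwargs. As orientation before the diffs, a minimal usage sketch (an illustration, not code from this commit; the kwarg names enable_thinking, xml_tools, and python_tools and the Hub path are taken from the diffs below):

    from transformers import AutoTokenizer

    # Assumes the rebuilt tokenizer has been pushed to the Hub path set in marin_models.py.
    tok = AutoTokenizer.from_pretrained("marin-community/marin-tokenizer")

    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "What is 2 + 2?"}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,  # rendered as "Reasoning: /think" in the system header
    )
    print(prompt)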

8 files changed

Lines changed: 1018 additions & 520 deletions

experiments/create_marin_tokenizer.py

Lines changed: 74 additions & 9 deletions
@@ -126,6 +126,11 @@ def load_llama3_tokenizer() -> PreTrainedTokenizer:
     {"role": "assistant", "content": "Great!"},
 ]
 
+SIMPLE_CONVERSATION = [
+    {"role": "user", "content": "What is 2 + 2?"},
+    {"role": "assistant", "content": "The answer is 4."},
+]
+
 
 def chat_template_checks(marin_tokenizer: PreTrainedTokenizer):
     """Test that chat template is correctly set."""
@@ -148,10 +153,13 @@ def chat_template_checks(marin_tokenizer: PreTrainedTokenizer):
     out = marin_tokenizer.apply_chat_template(
         TEST_CONVERSATION, tokenize=True, return_dict=True, return_assistant_tokens_mask=True
     )
-    expected_length = len(marin_tokenizer(REASONING_TRACE_EXAMPLE + "I'm doing well, thanks!")["input_ids"]) + len(
-        marin_tokenizer("Great!")["input_ids"]
+    expected_length = (
+        len(marin_tokenizer(REASONING_TRACE_EXAMPLE + "I'm doing well, thanks!")["input_ids"])
+        + len(marin_tokenizer("Great!")["input_ids"])
     )
-    assert np.sum(out["assistant_masks"]) == expected_length
+    assert (
+        np.sum(out["assistant_masks"]) == expected_length
+    ), f"Expected {expected_length} assistant tokens, got {np.sum(out['assistant_masks'])}"
 
     """Test that decoding of assistant tokens is correct."""
     out = marin_tokenizer.apply_chat_template(
@@ -161,11 +169,64 @@ def chat_template_checks(marin_tokenizer: PreTrainedTokenizer):
     expected_text = REASONING_TRACE_EXAMPLE + "I'm doing well, thanks!<|eot_id|>Great!<|eot_id|>"
     assert marin_tokenizer.decode(ids[np.array(out["assistant_masks"]).astype(bool)]) == expected_text
 
-    """Test that add_generation_prompt adds the final newline."""
     assert marin_tokenizer.apply_chat_template(TEST_CONVERSATION, tokenize=False, add_generation_prompt=True).endswith(
         "<|start_header_id|>assistant<|end_header_id|>\n"
     )
 
+    print(marin_tokenizer.apply_chat_template(TEST_CONVERSATION, tokenize=False, add_generation_prompt=True))
+
+    rendered = marin_tokenizer.apply_chat_template(
+        SIMPLE_CONVERSATION,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=True,
+    )
+    assert "Reasoning: /think" in rendered
+    assert "The answer is 4." in rendered
+
+    rendered = marin_tokenizer.apply_chat_template(
+        SIMPLE_CONVERSATION,
+        tokenize=False,
+        add_generation_prompt=False,
+        enable_thinking=False,
+    )
+    assert "Reasoning: /nothink" in rendered
+
+    rendered = marin_tokenizer.apply_chat_template(
+        SIMPLE_CONVERSATION,
+        tokenize=False,
+        add_generation_prompt=False,
+        enable_thinking="experimental",
+    )
+    assert "Reasoning: experimental" in rendered
+
+    rendered = marin_tokenizer.apply_chat_template(
+        SIMPLE_CONVERSATION,
+        tokenize=False,
+        add_generation_prompt=False,
+        xml_tools=[
+            '{"type": "function", "function": {"name": "final_answer", "description": "Provides final answers."}}',
+        ],
+        python_tools=[
+            '{"type": "function", "function": {"name": "python_exec", "description": "Execute Python code."}}',
+        ],
+        enable_thinking=True,
+    )
+    assert "### Tools" in rendered
+    assert "You may call one or more functions" in rendered
+    assert "<tools>" in rendered
+    assert "final_answer" in rendered
+    assert "When you send a message containing Python code" in rendered
+    assert "python_exec" in rendered
+    print(rendered)
+    rendered_tokens = marin_tokenizer.tokenize(rendered)
+    # print individual tokens with their ids for debugging
+    for token in rendered_tokens:
+        token_id = marin_tokenizer.convert_tokens_to_ids(token)
+        print(f"Token: {token} | ID: {token_id}")
+    print(len(rendered_tokens))
+    assert len(rendered_tokens) < 512, "Rendered template is too long!"
+
 
 def special_tokens_injection_check(marin_tokenizer: PreTrainedTokenizer):
     """Test that special tokens are correctly replaced."""
@@ -181,7 +242,7 @@ def run_all_tests(marin_tokenizer: PreTrainedTokenizer):
 
 
 # ============ Main function ============
-def main():
+def main(dry_run: bool = False):
     """
     Create and save a modified version of the llama3 tokenizer.
 
@@ -205,12 +266,16 @@ def main():
     marin_tokenizer.save_pretrained(temp_path)
     marin_tokenizer = AutoTokenizer.from_pretrained(temp_path, local_files_only=True)
 
-    # Run tests to make sure that the tokenizer is modified correctly
     run_all_tests(marin_tokenizer)
 
-    # Push to huggingface
-    marin_tokenizer.push_to_hub(marin_tokenizer_hf_path)
+    if not dry_run:
+        marin_tokenizer.push_to_hub(marin_tokenizer_hf_path)
 
 
 if __name__ == "__main__":
-    main()
+    import sys
+
+    if "--dry-run" in sys.argv:
+        main(dry_run=True)
+    else:
+        main()
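
Note on the assistant_masks checks above: they work because the template wraps assistant turns in {% generation %} ... {% endgeneration %} tags, which is what lets apply_chat_template(..., return_assistant_tokens_mask=True) return a 0/1 mask over the rendered tokens. A hedged sketch (not from this commit) of how such a mask is typically consumed for loss masking during supervised fine-tuning:

    import numpy as np
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("marin-community/marin-tokenizer")
    out = tok.apply_chat_template(
        [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "The answer is 4."},
        ],
        tokenize=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    input_ids = np.array(out["input_ids"])
    mask = np.array(out["assistant_masks"]).astype(bool)
    # Train only on assistant tokens: everything else gets the usual ignore index.
    labels = np.where(mask, input_ids, -100)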

experiments/marin_models.py

Lines changed: 142 additions & 7 deletions
@@ -11,31 +11,166 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# flake8: noqa
 
 """
 Various models and templates for Marin.
 """
 
-marin_tokenizer = "stanford-crfm/marin-tokenizer"
+marin_tokenizer = "marin-community/marin-tokenizer"
 """
 The HF Hub name for the Marin tokenizer.
 The Marin tokenizer is (currently) just the Llama 3 tokenizer with a custom chat template (MARIN_CHAT_TEMPLATE).
 """
 
-# to be clear this is the Olmo 2 template except we use llama3's special tokens
+# inspired by the smollm3 template and the Olmo 2 template, using llama3's special tokens
 MARIN_CHAT_TEMPLATE = """
 {{ bos_token }}
+{%- if enable_thinking is defined -%}
+{%- if enable_thinking is sameas true -%}
+{%- set _reasoning_mode = "/think" -%}
+{%- elif enable_thinking is sameas false -%}
+{%- set _reasoning_mode = "/nothink" -%}
+{%- else -%}
+{%- set _reasoning_mode = enable_thinking -%}
+{%- endif -%}
+{%- else -%}
+{%- set _reasoning_mode = none -%}
+{%- endif -%}
+{%- set _custom_instructions = custom_instructions | default(None, true) -%}
+{%- set _xml_tools_list = xml_tools | default([], true) -%}
+{%- if tools is defined and tools -%}
+{%- set _xml_tools_list = tools -%}
+{%- endif -%}
+{%- set _python_tools = python_tools | default([], true) -%}
+{%- set _has_aux_header = (_reasoning_mode is not none) or _custom_instructions or (_xml_tools_list) or (_python_tools) -%}
+{%- if _has_aux_header -%}
+<|start_header_id|>system<|end_header_id|>
+{%- if _reasoning_mode is not none -%}
+Reasoning: {{ _reasoning_mode }}
+{%- endif %}
+{%- if _custom_instructions %}
+{{ _custom_instructions | trim }}
+{%- endif %}
+{% if _xml_tools_list or _python_tools %}
+{{ "\n### Tools\n" }}
+You may call one or more functions to assist with the user query.
+{% if _xml_tools_list %}
+You are provided with function signatures within <tools> </tools> tags:
+
+<tools>
+{% for tool in _xml_tools_list %}
+{{ tool | string }}{% if not loop.last %}
+{% endif %}
+{% endfor %}
+</tools>
+
+For each function call, pass a json object with function name and arguments within <tool_call> </tool_call> tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+
+{% endif %}
+{% if _python_tools %}
+When you send a message containing Python code between <|python_tag|> and <|eom_id|> tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output.
+
+You can use the following tools in your python code like regular functions:
+<tools>
+{% for tool in _python_tools %}
+{{ tool | string }}{% if not loop.last %}
+{% endif %}
+{% endfor %}
+</tools>
+{% endif %}
+{% endif %}
+<|eot_id|>
+{%- endif -%}
 {%- for message in messages -%}
-{%- if message['role'] == 'assistant' -%}
-<|start_header_id|>{{ message['role'] }}<|end_header_id|>
-{% generation %}{{- message['content'] | trim }}<|eot_id|>{% endgeneration %}\n
+{%- set has_tool_calls = message.get('tool_calls') is not none and message.get('tool_calls') -%}
+{%- if not (message.get('role') in ['tool', 'ipython'] or has_tool_calls) -%}
+{%- if message.get('role') == 'assistant' -%}
+<|start_header_id|>assistant<|end_header_id|>
+{% set content = message.get('content') %}
+{% if content is string %}
+{% generation %}{{- content | trim }}<|eot_id|>{% endgeneration %}
+{% elif content is mapping %}
+{% generation %}{{- content.get('text', '') | trim }}<|eot_id|>{% endgeneration %}
+{% elif content is iterable %}
+{% generation %}
+{%- for chunk in content -%}
+{%- if chunk.get('type') == 'text' -%}
+{{ chunk.get('text', '') | trim }}
+{%- endif -%}
+{%- endfor -%}
+<|eot_id|>
+{% endgeneration %}
 {% else %}
+{% generation %}{% endgeneration %}<|eot_id|>
+{% endif %}
+{%- else -%}
 <|start_header_id|>{{ message['role'] }}<|end_header_id|>
-{{ message['content'] | trim }}<|eot_id|>
+{% set content = message.get('content') %}
+{% if content is string %}
+{{ content | trim }}<|eot_id|>
+{% elif content is mapping %}
+{{ content.get('text', '') | trim }}<|eot_id|>
+{% elif content is iterable %}
+{%- for chunk in content -%}
+{%- if chunk.get('type') == 'text' -%}
+{{ chunk.get('text', '') | trim }}
+{%- endif -%}
+{%- endfor -%}<|eot_id|>
+{% else %}
+<|eot_id|>
+{% endif %}
+{%- endif -%}
+
+{%- elif message.get('role') == 'tool' -%}
+{%- set _tool_name = message.get('name') -%}
+{%- set _tool_id = message.get('tool_call_id') -%}
+{%- set _attr_name = ' name=\"' ~ _tool_name ~ '\"' if _tool_name else '' -%}
+{%- set _attr_id = ' id=\"' ~ _tool_id ~ '\"' if _tool_id else '' -%}
+<|start_header_id|>tool<|end_header_id|>
+<tool_response{{ _attr_name }}{{ _attr_id }}>
+{%- set tool_content = message.get('content') -%}
+{%- if tool_content is mapping or (tool_content is iterable and tool_content is not string) -%}
+{{- tool_content | tojson }}
+{%- else -%}
+{{- tool_content if tool_content is not none else '' }}
+{%- endif -%}
+</tool_response><|eot_id|>
+{{- "\n" -}}
+{%- elif message.get('role') == 'ipython' -%}
+<|start_header_id|>ipython<|end_header_id|>
+{% set ipy_content = message.get('content') %}
+{% if ipy_content is string %}
+{{- { "output": ipy_content } | tojson -}}
+{% elif ipy_content is iterable %}
+{%- for chunk in ipy_content -%}
+{%- if chunk.get('type') == 'text' -%}
+{{- { "output": chunk.get('text', '') } | tojson -}}
+{%- endif -%}
+{%- endfor -%}
+{% else %}
+{{- { "output": ipy_content } | tojson -}}
 {% endif %}
+<|eot_id|>
+{% elif has_tool_calls -%}
+{%- if message.tool_calls|length != 1 -%}
+{{- raise_exception("This template expects exactly one tool call per assistant turn.") -}}
+{%- endif -%}
+{%- set tool_call = message.tool_calls[0].function -%}
+<|start_header_id|>assistant<|end_header_id|>
+{% generation %}
+{{- '{\"name\": \"' + tool_call.name + '\", \"arguments\": ' -}}
+{{- tool_call.arguments | tojson -}}
+{{- \"}\" -}}<|eot_id|>
+{% endgeneration %}
+{%- endif -%}
 {%- endfor -%}
 {%- if add_generation_prompt -%}
-<|start_header_id|>assistant<|end_header_id|>\n{% endif -%}
+<|start_header_id|>assistant<|end_header_id|>
+{% endif -%}
 """.strip()
 
 """
