Skip to content

Commit 81583ca

Browse files
committed
[datasets] Add Hermes trace SFT integration
Normalize Hermes tool responses into Marin's chat format while preserving raw think/tool-call assistant turns. Register the glm-5.1 and kimi subsets, add focused fixtures and regression tests, and introduce a trace-focused pilot SFT experiment.
1 parent ec2abf6 commit 81583ca

File tree

9 files changed

+597
-0
lines changed

9 files changed

+597
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Trace-focused Hermes SFT pilot built on the SmolTalk2 + Nemotron recipe."""
5+
6+
import dataclasses
7+
import math
8+
import re
9+
10+
from levanter.data.text import ChatLmDatasetFormat
11+
12+
from experiments.defaults import default_sft, default_tokenize
13+
from experiments.evals.evals import default_sft_eval
14+
from experiments.llama import llama_8b
15+
from experiments.marin_models import marin_tokenizer
16+
from experiments.posttrain.instruction_datasets import INSTRUCTION_DATASET_NAME_TO_CONFIG, get_instruction_dataset
17+
from experiments.simple_sft_config import SimpleSFTConfig
18+
from fray.cluster import ResourceConfig
19+
from marin.execution.executor import ExecutorStep, executor_main
20+
from marin.processing.tokenize import lm_mixture_data_config
21+
22+
SLUGIFY_PATTERN = re.compile(r"[^a-z0-9]+")
# Aim for three passes over the full mixture at the global batch size below.
TARGET_EPOCHS = 3
TRAIN_BATCH_SIZE = 2048

# Short name -> (instruction-dataset identifier, row count).  Row counts were
# captured on 2026-04-16 from the Hugging Face dataset page / datasets-server
# and double as the example-proportional mixture weights.
TRACE_PILOT_DATASETS = {
    "smoltalk2_smolagents_toolcalling_traces_think": (
        "HuggingFaceTB/smoltalk2/smolagents_toolcalling_traces_think",
        9079,
    ),
    "smoltalk2_hermes_function_calling_v1_no_think": (
        "HuggingFaceTB/smoltalk2/hermes_function_calling_v1_no_think",
        8961,
    ),
    "smoltalk2_xlam_traces_no_think": (
        "HuggingFaceTB/smoltalk2/xlam_traces_no_think",
        59962,
    ),
    "nemotron_v2_chat": ("nvidia/Nemotron-Post-Training-Dataset-v2/chat", 627720),
    "nemotron_v2_code": ("nvidia/Nemotron-Post-Training-Dataset-v2/code", 175000),
    "hermes_glm_5_1": ("lambda/hermes-agent-reasoning-traces/glm-5.1", 7055),
    "hermes_kimi": ("lambda/hermes-agent-reasoning-traces/kimi", 7646),
}
45+
46+
47+
def _slugify(value: str) -> str:
48+
slug = SLUGIFY_PATTERN.sub("_", value.lower()).strip("_")
49+
return slug or "dataset"
50+
51+
52+
def create_tokenization_step(dataset_identifier: str, short_name: str) -> ExecutorStep:
    """Build the tokenization step for one registered instruction dataset.

    Looks up the dataset's registered config, downloads/prepares the raw
    instruction dataset for its configured splits, then tokenizes every
    ``*.jsonl.gz`` shard with the Marin tokenizer in chat format.
    """
    registered_config = INSTRUCTION_DATASET_NAME_TO_CONFIG[dataset_identifier]
    raw_dataset = get_instruction_dataset(dataset_identifier, splits=registered_config.splits)
    shard_glob = raw_dataset / "**/*.jsonl.gz"
    return default_tokenize(
        name=f"{short_name}_marin_tokenizer",
        dataset=shard_glob,
        tokenizer=marin_tokenizer,
        format=ChatLmDatasetFormat(),
    )
61+
62+
63+
# Slugified short name -> dataset identifier, and the matching row-count
# mixture weight, built in a single pass over the registry.
dataset_ids: dict[str, str] = {}
mixture_weights: dict[str, int] = {}
for _short_name, (_identifier, _row_count) in TRACE_PILOT_DATASETS.items():
    _slug = _slugify(_short_name)
    dataset_ids[_slug] = _identifier
    mixture_weights[_slug] = _row_count

tokenized_datasets = {
    slug: create_tokenization_step(identifier, slug) for slug, identifier in dataset_ids.items()
}

# Every tokenized dataset must carry a mixture weight (and vice versa).
assert set(tokenized_datasets.keys()) == set(mixture_weights.keys())

# Enough optimizer steps for TARGET_EPOCHS passes over the whole mixture.
total_examples = sum(mixture_weights.values())
num_train_steps = math.ceil(TARGET_EPOCHS * total_examples / TRAIN_BATCH_SIZE)

pilot_sft_config = SimpleSFTConfig(
    train_batch_size=TRAIN_BATCH_SIZE,
    num_train_steps=num_train_steps,
    learning_rate=1e-5,
    resources=ResourceConfig.with_tpu("v4-128"),
    tokenizer=marin_tokenizer,
    initialize_from_hf="marin-community/marin-8b-base",
    max_seq_len=8192,
    seed=0,
)

# Example-count-weighted mixture; datasets without weights become validation.
pilot_mixture = lm_mixture_data_config(
    tokenized_datasets,
    mixture_weights,
    shuffle=True,
    missing_weights_are_validation=True,
)

# Stretch the base Llama-8B config to the 8k SFT sequence length.
llama_8b_8k = dataclasses.replace(llama_8b, max_seq_len=8192)

marin_8b_sft_hermes_trace_pilot = default_sft(
    name="marin_8b_sft_hermes_trace_pilot",
    tokenized=pilot_mixture,
    model_config=llama_8b_8k,
    sft_config=pilot_sft_config,
    tags=["llama", "smoltalk2", "nemotron_v2", "hermes_trace", "sft"],
)

marin_8b_sft_hermes_trace_pilot_evals = default_sft_eval(
    marin_8b_sft_hermes_trace_pilot,
    use_levanter_inference=True,
    resource_config=ResourceConfig.with_tpu("v4-8"),
)


if __name__ == "__main__":
    executor_main(steps=[marin_8b_sft_hermes_trace_pilot, *marin_8b_sft_hermes_trace_pilot_evals])

experiments/posttrain/instruction_datasets.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
20. nvidia/Nemotron-Post-Training-Dataset-v2
3636
21. HuggingFaceH4/no_robots
3737
22. open-thoughts/OpenThoughts3-1.2M # Original OT3 dataset; smoltalk2 uses a slightly different version
38+
23. lambda/hermes-agent-reasoning-traces
3839
"""
3940

4041
import hashlib
@@ -67,6 +68,10 @@
6768
TransformSFTDatasetConfig,
6869
transform_hf_dataset,
6970
)
71+
from marin.transform.conversation.trace_normalization import (
72+
hermes_trace_row_id,
73+
normalize_hermes_trace_messages,
74+
)
7075

7176
SMOLTALK2_SPLITS = [
7277
"LongAlign_64k_Qwen3_32B_yarn_131k_think",
@@ -109,6 +114,7 @@
109114
]
110115

111116
NEMOTRON_V1_SPLITS = ["chat", "code", "math", "stem", "tool_calling"]
117+
HERMES_TRACE_REVISION = "aa7c93605c71578869938359075b1765cf1b26e1"
112118

113119

114120
@dataclass(frozen=True)
@@ -309,6 +315,42 @@ def __call__(self, row: dict[str, Any]) -> dict[str, Any]:
309315
metadata_columns=["id", "category", "source"],
310316
name="teknium/OpenHermes-2.5",
311317
),
318+
"lambda/hermes-agent-reasoning-traces/glm-5.1": InstructionDatasetConfig(
319+
hf_dataset_id="lambda/hermes-agent-reasoning-traces",
320+
revision=HERMES_TRACE_REVISION,
321+
adapter=multi_turn_adapter(
322+
conversation_column="conversations",
323+
role_key="from",
324+
user_value="human",
325+
assistant_value="gpt",
326+
system_value="system",
327+
content_key="value",
328+
message_postprocess_fn=normalize_hermes_trace_messages,
329+
row_id_fn=hermes_trace_row_id,
330+
),
331+
metadata_columns=["category", "subcategory", "task"],
332+
name="lambda/hermes-agent-reasoning-traces/glm-5.1",
333+
subsets=["glm-5.1"],
334+
splits=["train"],
335+
),
336+
"lambda/hermes-agent-reasoning-traces/kimi": InstructionDatasetConfig(
337+
hf_dataset_id="lambda/hermes-agent-reasoning-traces",
338+
revision=HERMES_TRACE_REVISION,
339+
adapter=multi_turn_adapter(
340+
conversation_column="conversations",
341+
role_key="from",
342+
user_value="human",
343+
assistant_value="gpt",
344+
system_value="system",
345+
content_key="value",
346+
message_postprocess_fn=normalize_hermes_trace_messages,
347+
row_id_fn=hermes_trace_row_id,
348+
),
349+
metadata_columns=["category", "subcategory", "task"],
350+
name="lambda/hermes-agent-reasoning-traces/kimi",
351+
subsets=["kimi"],
352+
splits=["train"],
353+
),
312354
"allenai/tulu-v2-sft-mixture-olmo-4096": InstructionDatasetConfig(
313355
hf_dataset_id="allenai/tulu-v2-sft-mixture-olmo-4096",
314356
revision="7a7c388",
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Normalization helpers for trace-like conversation datasets."""
5+
6+
import hashlib
7+
import json
8+
import re
9+
from typing import Any
10+
11+
from marin.core.conversation import OpenAIChatMessage
12+
13+
_HERMES_TOOL_RESPONSE_RE = re.compile(
14+
r"^\s*<tool_response(?P<attrs>[^>]*)>\s*(?P<body>.*?)\s*</tool_response>\s*$",
15+
re.DOTALL,
16+
)
17+
_HERMES_TOOL_RESPONSE_ATTR_RE = re.compile(r"""(?P<key>name|id)\s*=\s*(?P<quote>["'])(?P<value>.*?)(?P=quote)""")
18+
19+
20+
def _hash_messages(messages: list[dict[str, Any]]) -> str:
21+
return hashlib.sha256(str(messages).encode()).hexdigest()
22+
23+
24+
def _string_or_none(value: Any) -> str | None:
25+
return value if isinstance(value, str) and value else None
26+
27+
28+
def _parse_tool_response_attrs(attrs: str) -> tuple[str | None, str | None]:
    """Pull ``name`` and ``id`` attributes out of a `<tool_response>` tag.

    Returns ``(name, tool_call_id)``; either slot is ``None`` when the
    corresponding attribute is absent.  When an attribute appears more than
    once, the last occurrence wins (matching the left-to-right scan).
    """
    found: dict[str, str] = {}
    for attr_match in _HERMES_TOOL_RESPONSE_ATTR_RE.finditer(attrs):
        found[attr_match.group("key")] = attr_match.group("value")
    return found.get("name"), found.get("id")
41+
42+
43+
def _parse_tool_response_body(
44+
body: str,
45+
*,
46+
name: str | None,
47+
tool_call_id: str | None,
48+
) -> tuple[str | None, str | None, Any] | None:
49+
try:
50+
payload = json.loads(body)
51+
except json.JSONDecodeError:
52+
return None
53+
54+
if not isinstance(payload, dict):
55+
return None
56+
57+
normalized_name = name or _string_or_none(payload.get("name"))
58+
normalized_tool_call_id = tool_call_id or _string_or_none(payload.get("tool_call_id"))
59+
if "content" in payload:
60+
return normalized_name, normalized_tool_call_id, payload["content"]
61+
62+
return normalized_name, normalized_tool_call_id, payload
63+
64+
65+
def _normalize_hermes_tool_response_message(message: OpenAIChatMessage) -> OpenAIChatMessage:
    """Strip the `<tool_response>` wrapper from a tool turn when it parses cleanly.

    Non-tool turns, non-string contents, contents without a full wrapper, and
    bodies that are not JSON objects are all returned unchanged so the raw
    source text is preserved.
    """
    if message.role != "tool":
        return message
    if not isinstance(message.content, str):
        return message

    wrapper = _HERMES_TOOL_RESPONSE_RE.fullmatch(message.content)
    if wrapper is None:
        return message

    attr_name, attr_id = _parse_tool_response_attrs(wrapper.group("attrs"))
    parsed = _parse_tool_response_body(
        wrapper.group("body").strip(),
        name=attr_name,
        tool_call_id=attr_id,
    )
    if parsed is None:
        return message

    final_name, final_id, final_content = parsed
    updates = {
        "content": final_content,
        "name": final_name,
        "tool_call_id": final_id,
    }
    return message.model_copy(update=updates)
90+
91+
92+
def normalize_hermes_trace_messages(
    messages: list[OpenAIChatMessage],
    row: dict[str, Any],
) -> list[OpenAIChatMessage]:
    """Normalize Hermes trace messages for Marin's conversation pipeline.

    Assistant turns already carry the desired `<think>` and `<tool_call>`
    blocks, so they pass through untouched.  Tool turns arrive wrapped in
    `<tool_response>` tags; because Marin's chat template would wrap them
    again, only the outer wrapper is stripped when the payload parses
    cleanly.  When wrapper parsing fails, the raw source content is kept
    unchanged.  *row* is accepted to satisfy the post-process hook signature
    but is not consulted.
    """
    normalized: list[OpenAIChatMessage] = []
    for message in messages:
        normalized.append(_normalize_hermes_tool_response_message(message))
    return normalized
105+
106+
107+
def hermes_trace_row_id(row: dict[str, Any], messages: list[dict[str, Any]]) -> str:
    """Return the source trace ID when available, else a hash of the messages."""
    source_id = row.get("id")
    if not isinstance(source_id, str) or not source_id:
        return _hash_messages(messages)
    return source_id

tests/test_marin_chat_template.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright The Marin Authors
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import json
45
import tempfile
56
from collections.abc import Sequence
7+
from pathlib import Path
68

79
import pytest
810

@@ -11,8 +13,16 @@
1113
load_llama3_tokenizer,
1214
run_all_tests,
1315
)
16+
from experiments.posttrain.instruction_datasets import INSTRUCTION_DATASET_NAME_TO_CONFIG
1417
from levanter.data.text import ChatProcessor
1518
from levanter.tokenizers import load_tokenizer
19+
from marin.transform.conversation.transform_conversation import TransformSFTDatasetConfig, transform_row
20+
21+
FIXTURE_DIR = Path(__file__).parent / "transform" / "fixtures" / "agent_traces"


def _load_agent_trace_fixture(name: str) -> dict:
    """Read and decode a JSON fixture row from the agent-trace fixture dir."""
    fixture_path = FIXTURE_DIR / name
    return json.loads(fixture_path.read_text(encoding="utf-8"))
1626

1727

1828
@pytest.fixture()
@@ -113,3 +123,32 @@ def test_marin_chat_template_ipython_output(fresh_marin_tokenizer):
113123
assert "<|start_header_id|>ipython<|end_header_id|>" in rendered
114124
assert '{"output": "4\\n"}' in rendered
115125
assert result["assistant_masks"].sum() > 0
126+
127+
128+
def test_marin_chat_template_normalizes_hermes_tool_responses(fresh_marin_tokenizer):
    """Hermes tool turns render with name/id on the tag, not double-wrapped."""
    processor = ChatProcessor(fresh_marin_tokenizer, mask_user_turns=True)

    dataset_cfg = INSTRUCTION_DATASET_NAME_TO_CONFIG["lambda/hermes-agent-reasoning-traces/glm-5.1"]
    fixture_row = _load_agent_trace_fixture("hermes_glm_sample.json")
    transform_cfg = TransformSFTDatasetConfig(
        source=dataset_cfg.hf_dataset_id,
        revision=dataset_cfg.revision,
        output_path="/tmp/output",
        metadata_columns=dataset_cfg.metadata_columns,
        adapter=dataset_cfg.adapter,
        subsets=dataset_cfg.subsets,
        splits=dataset_cfg.splits,
    )

    transformed = transform_row(fixture_row, transform_cfg, dataset_cfg.adapter)
    assert transformed is not None

    batch = [{"messages": [msg.model_dump() for msg in transformed.messages]}]
    processed = processor(batch)[0]
    rendered = decode_sequence(fresh_marin_tokenizer, processed["input_ids"])

    # Think block survives; tool response is unwrapped with attrs on the tag.
    assert "<|start_think|>" in rendered
    assert '<tool_response name="write_file" id="glm-tool-call-001">' in rendered
    assert '"bytes_written": 15' in rendered
    assert '<tool_response>\n{"tool_call_id": "glm-tool-call-001"' not in rendered
    assert processed["assistant_masks"].sum() > 0
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"id": "glm-trace-001",
3+
"conversations": [
4+
{
5+
"from": "system",
6+
"value": "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags.\n<tools>\n[{\"name\": \"write_file\", \"description\": \"Write content to a file\", \"parameters\": {\"type\": \"object\", \"properties\": {\"path\": {\"type\": \"string\"}, \"content\": {\"type\": \"string\"}}, \"required\": [\"path\", \"content\"]}}]\n</tools>"
7+
},
8+
{
9+
"from": "human",
10+
"value": "Write a tiny Python script that prints hello."
11+
},
12+
{
13+
"from": "gpt",
14+
"value": "<think>\nI should write the file first.\n</think>\n<tool_call>\n{\"name\": \"write_file\", \"arguments\": {\"path\": \"hello.py\", \"content\": \"print('hello')\\n\"}}\n</tool_call>"
15+
},
16+
{
17+
"from": "tool",
18+
"value": "<tool_response>\n{\"tool_call_id\": \"glm-tool-call-001\", \"name\": \"write_file\", \"content\": {\"bytes_written\": 15, \"dirs_created\": false}}\n</tool_response>"
19+
},
20+
{
21+
"from": "gpt",
22+
"value": "<think>\nThe file was written successfully.\n</think>\nThe script is ready."
23+
}
24+
],
25+
"tools": "[{\"name\": \"write_file\", \"description\": \"Write content to a file\", \"parameters\": {\"type\": \"object\", \"properties\": {\"path\": {\"type\": \"string\"}, \"content\": {\"type\": \"string\"}}, \"required\": [\"path\", \"content\"]}}]",
26+
"category": "Terminal & Coding",
27+
"subcategory": "Terminal Tasks",
28+
"task": "Write a tiny Python script that prints hello."
29+
}

0 commit comments

Comments
 (0)