Skip to content

Commit ec2abf6

Browse files
committed
[transform] Add conversation trace hooks
This adds minimal postprocessing and row-ID hook seams for SFT conversation transforms so Hermes-style trace normalization can land without reshaping the existing adapter stack. It also preserves transformed dataset hash stability when the new hooks are unset and adds regression tests for hook ordering and signature behavior.
1 parent 4c71b82 commit ec2abf6

File tree

4 files changed

+192
-19
lines changed

4 files changed

+192
-19
lines changed

experiments/posttrain/instruction_datasets.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
22. open-thoughts/OpenThoughts3-1.2M # Original OT3 dataset; smoltalk2 uses a slightly different version
3838
"""
3939

40-
import dataclasses
4140
import hashlib
4241
import json
4342
from collections.abc import Sequence
@@ -57,7 +56,13 @@
5756
ConversationToDolmaConfig,
5857
convert_conversation_to_dolma,
5958
)
60-
from marin.transform.conversation.adapters import InputDatasetFormat, TransformAdapter
59+
from marin.transform.conversation.adapters import (
60+
InputDatasetFormat,
61+
MessagePostprocessFn,
62+
RowIdFn,
63+
TransformAdapter,
64+
transform_adapter_signature,
65+
)
6166
from marin.transform.conversation.transform_conversation import (
6267
TransformSFTDatasetConfig,
6368
transform_hf_dataset,
@@ -142,6 +147,8 @@ def multi_turn_adapter(
142147
metadata_remap: dict[str, str] | None = None,
143148
replacements: dict[str, str] | None = None,
144149
extra_metadata_fn=None,
150+
message_postprocess_fn: MessagePostprocessFn | None = None,
151+
row_id_fn: RowIdFn | None = None,
145152
) -> TransformAdapter:
146153
return TransformAdapter(
147154
dataset_format=InputDatasetFormat.SINGLE_COLUMN_MULTI_TURN,
@@ -154,6 +161,8 @@ def multi_turn_adapter(
154161
metadata_remap=metadata_remap or {},
155162
replacements=replacements,
156163
extra_metadata_fn=extra_metadata_fn,
164+
message_postprocess_fn=message_postprocess_fn,
165+
row_id_fn=row_id_fn,
157166
)
158167

159168

@@ -166,6 +175,8 @@ def instruction_response_adapter(
166175
metadata_remap: dict[str, str] | None = None,
167176
replacements: dict[str, str] | None = None,
168177
extra_metadata_fn=None,
178+
message_postprocess_fn: MessagePostprocessFn | None = None,
179+
row_id_fn: RowIdFn | None = None,
169180
) -> TransformAdapter:
170181
return TransformAdapter(
171182
dataset_format=InputDatasetFormat.INSTRUCTION_RESPONSE,
@@ -176,6 +187,8 @@ def instruction_response_adapter(
176187
metadata_remap=metadata_remap or {},
177188
replacements=replacements,
178189
extra_metadata_fn=extra_metadata_fn,
190+
message_postprocess_fn=message_postprocess_fn,
191+
row_id_fn=row_id_fn,
179192
)
180193

181194

@@ -186,6 +199,8 @@ def instruct_column_response_adapter(
186199
metadata_remap: dict[str, str] | None = None,
187200
replacements: dict[str, str] | None = None,
188201
extra_metadata_fn=None,
202+
message_postprocess_fn: MessagePostprocessFn | None = None,
203+
row_id_fn: RowIdFn | None = None,
189204
) -> TransformAdapter:
190205
return TransformAdapter(
191206
dataset_format=InputDatasetFormat.INSTRUCT_COLUMN_RESPONSE,
@@ -195,6 +210,8 @@ def instruct_column_response_adapter(
195210
metadata_remap=metadata_remap or {},
196211
replacements=replacements,
197212
extra_metadata_fn=extra_metadata_fn,
213+
message_postprocess_fn=message_postprocess_fn,
214+
row_id_fn=row_id_fn,
198215
)
199216

200217

@@ -210,6 +227,8 @@ def instruct_msg_response_adapter(
210227
metadata_remap: dict[str, str] | None = None,
211228
replacements: dict[str, str] | None = None,
212229
extra_metadata_fn=None,
230+
message_postprocess_fn: MessagePostprocessFn | None = None,
231+
row_id_fn: RowIdFn | None = None,
213232
) -> TransformAdapter:
214233
return TransformAdapter(
215234
dataset_format=InputDatasetFormat.INSTRUCT_MSG_RESPONSE,
@@ -223,6 +242,8 @@ def instruct_msg_response_adapter(
223242
metadata_remap=metadata_remap or {},
224243
replacements=replacements,
225244
extra_metadata_fn=extra_metadata_fn,
245+
message_postprocess_fn=message_postprocess_fn,
246+
row_id_fn=row_id_fn,
226247
)
227248

228249

@@ -561,26 +582,18 @@ def get_directory_friendly_dataset_name(hf_dataset_id: str) -> str:
561582
return dataset_name
562583

563584

585+
def get_adapter_signature_string(adapter: TransformAdapter) -> str:
586+
"""Return the stable JSON signature used to version transformed instruction datasets."""
587+
return json.dumps(transform_adapter_signature(adapter), sort_keys=True)
588+
589+
564590
def transform_dataset_step(dataset_cfg: InstructionDatasetConfig) -> ExecutorStep:
565591
"""ExecutorStep that preprocesses the input dataset into a canonicalized format for SFT training."""
566592
adapter = dataset_cfg.adapter
567593
output_name = dataset_cfg.name if dataset_cfg.name is not None else dataset_cfg.hf_dataset_id
568594
dataset_name = get_directory_friendly_dataset_name(output_name)
569595

570-
adapter_dict = dataclasses.asdict(adapter)
571-
adapter_dict["dataset_format"] = adapter_dict["dataset_format"].value
572-
573-
def canonicalize(value):
574-
if isinstance(value, dict):
575-
return {k: canonicalize(v) for k, v in sorted(value.items())}
576-
if isinstance(value, list):
577-
return [canonicalize(x) for x in value]
578-
if callable(value):
579-
return f"{value.__module__}.{value.__qualname__}"
580-
return value
581-
582-
adapter_signature = canonicalize(adapter_dict)
583-
adapter_signature_str = json.dumps(adapter_signature, sort_keys=True)
596+
adapter_signature_str = get_adapter_signature_string(adapter)
584597

585598
config_str = f"{dataset_name}-\
586599
{dataset_cfg.revision}\

lib/marin/src/marin/transform/conversation/adapters.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,36 @@ class InputDatasetFormat(str, Enum):
5555
INSTRUCT_MSG_RESPONSE: str = "instruct_msg_response"
5656

5757

58+
MessagePostprocessFn = Callable[[list[OpenAIChatMessage], dict[str, Any]], list[OpenAIChatMessage]]
59+
RowIdFn = Callable[[dict[str, Any], list[dict[str, Any]]], str]
60+
61+
_OPTIONAL_SIGNATURE_FIELDS = frozenset({"message_postprocess_fn", "row_id_fn"})
62+
63+
64+
def _canonicalize_signature_value(value: Any) -> Any:
65+
if isinstance(value, dict):
66+
return {k: _canonicalize_signature_value(v) for k, v in sorted(value.items())}
67+
if isinstance(value, list):
68+
return [_canonicalize_signature_value(x) for x in value]
69+
if callable(value):
70+
return f"{value.__module__}.{value.__qualname__}"
71+
return value
72+
73+
74+
def transform_adapter_signature(adapter: "TransformAdapter") -> dict[str, Any]:
75+
"""Return a stable, JSON-serializable signature for a transform adapter.
76+
77+
Newly added optional trace hooks are omitted when unset so existing dataset output hashes
78+
remain stable.
79+
"""
80+
adapter_dict = dataclasses.asdict(adapter)
81+
adapter_dict["dataset_format"] = adapter_dict["dataset_format"].value
82+
adapter_dict = {
83+
key: value for key, value in adapter_dict.items() if not (key in _OPTIONAL_SIGNATURE_FIELDS and value is None)
84+
}
85+
return _canonicalize_signature_value(adapter_dict)
86+
87+
5888
@dataclass
5989
class TransformAdapter:
6090
dataset_format: InputDatasetFormat = InputDatasetFormat.INSTRUCTION_RESPONSE
@@ -87,6 +117,8 @@ class TransformAdapter:
87117
metadata_remap: dict[str, str] = field(default_factory=dict)
88118
replacements: dict[str, str] | None = None
89119
extra_metadata_fn: Callable[[dict[str, Any]], dict[str, Any]] | None = None
120+
message_postprocess_fn: MessagePostprocessFn | None = None
121+
row_id_fn: RowIdFn | None = None
90122

91123
def transform_conversation_to_openai_format(
92124
self,

lib/marin/src/marin/transform/conversation/transform_conversation.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ def transform_row(row: dict, cfg: TransformSFTDatasetConfig, adapter: TransformA
126126
logger.warning(f"{source} returning no valid messages")
127127
return None
128128

129-
transformed_row_messages = [message.model_dump() for message in transformed_row_messages]
129+
if adapter.message_postprocess_fn:
130+
transformed_row_messages = adapter.message_postprocess_fn(transformed_row_messages, row)
130131

131-
# Create a unique ID for the row based on the text
132-
row_idx = generate_hash_from_messages(transformed_row_messages)
132+
transformed_row_messages = [message.model_dump() for message in transformed_row_messages]
133133
metadata_columns = unwrap_versioned_value(cfg.metadata_columns)
134134
metadata_remap = adapter.metadata_remap or {}
135135
replacements = adapter.replacements if adapter.replacements is not None else DEFAULT_TEXT_REPLACEMENTS
@@ -154,6 +154,13 @@ def transform_row(row: dict, cfg: TransformSFTDatasetConfig, adapter: TransformA
154154
transformed_row_messages = [_normalize_tool_structures(message) for message in transformed_row_messages]
155155
else:
156156
transformed_row_messages = [_normalize_tool_structures(message) for message in transformed_row_messages]
157+
158+
if adapter.row_id_fn:
159+
row_idx = adapter.row_id_fn(row, transformed_row_messages)
160+
else:
161+
# Create a unique ID for the row based on the transformed text.
162+
row_idx = generate_hash_from_messages(transformed_row_messages)
163+
157164
if adapter.extra_metadata_fn:
158165
extra_from_fn = adapter.extra_metadata_fn(row)
159166
if extra_from_fn:

tests/transform/test_conversation.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33

44
"""Tests for conversation data transformation scripts."""
55

6+
import dataclasses
7+
import json
68
from pathlib import Path
79

10+
from experiments.posttrain.instruction_datasets import get_adapter_signature_string, instruction_response_adapter
11+
from marin.core.conversation import OpenAIChatMessage
812
from marin.transform.conversation.adapters import InputDatasetFormat, TransformAdapter
913
from marin.transform.conversation.conversation_to_dolma import transform_conversation_to_dolma
1014
from marin.transform.conversation.preference_data_adapters import PreferenceTransformAdapter
@@ -49,6 +53,35 @@
4953
}
5054

5155

56+
def _replace_assistant_message(messages: list[OpenAIChatMessage], row: dict[str, str]) -> list[OpenAIChatMessage]:
57+
return [
58+
messages[0],
59+
OpenAIChatMessage(role="assistant", content=row["postprocessed_response"]),
60+
]
61+
62+
63+
def _row_id_from_source(row: dict[str, str], _messages: list[dict[str, object]]) -> str:
64+
return row["custom_row_id"]
65+
66+
67+
def _legacy_adapter_signature_string(adapter: TransformAdapter) -> str:
68+
adapter_dict = dataclasses.asdict(adapter)
69+
adapter_dict["dataset_format"] = adapter_dict["dataset_format"].value
70+
adapter_dict.pop("message_postprocess_fn", None)
71+
adapter_dict.pop("row_id_fn", None)
72+
73+
def canonicalize(value):
74+
if isinstance(value, dict):
75+
return {key: canonicalize(inner_value) for key, inner_value in sorted(value.items())}
76+
if isinstance(value, list):
77+
return [canonicalize(item) for item in value]
78+
if callable(value):
79+
return f"{value.__module__}.{value.__qualname__}"
80+
return value
81+
82+
return json.dumps(canonicalize(adapter_dict), sort_keys=True)
83+
84+
5285
class TestTransformAdapters:
5386
"""Test the different adapter formats."""
5487

@@ -127,6 +160,94 @@ def test_transform_with_replacements(self):
127160
assert "<|end_think|>" in response_message.content
128161
assert "<think>" not in response_message.content
129162

163+
def test_transform_applies_message_postprocess_before_replacements(self):
164+
"""Test message postprocessing runs before text replacements."""
165+
adapter = TransformAdapter(
166+
dataset_format=InputDatasetFormat.INSTRUCTION_RESPONSE,
167+
instruction_column="instruction",
168+
response_column="response",
169+
replacements={"<think>": "<|start_think|>", "</think>": "<|end_think|>"},
170+
message_postprocess_fn=_replace_assistant_message,
171+
)
172+
173+
row = {
174+
"instruction": "Solve this",
175+
"response": "placeholder",
176+
"postprocessed_response": "<think>Use the replacement path</think>",
177+
}
178+
179+
cfg = TransformSFTDatasetConfig(
180+
source="test/dataset",
181+
revision="main",
182+
output_path="/tmp/output",
183+
metadata_columns=[],
184+
adapter=adapter,
185+
)
186+
187+
result = transform_row(row, cfg, adapter)
188+
189+
assert result is not None
190+
response_message = result.messages[1]
191+
assert response_message.content == "<|start_think|>Use the replacement path<|end_think|>"
192+
193+
def test_transform_uses_row_id_hook(self):
194+
"""Test row ids can come from a source-provided identifier."""
195+
adapter = TransformAdapter(
196+
dataset_format=InputDatasetFormat.INSTRUCTION_RESPONSE,
197+
instruction_column="instruction",
198+
response_column="response",
199+
row_id_fn=_row_id_from_source,
200+
)
201+
202+
row = {
203+
"instruction": "Question",
204+
"response": "Answer",
205+
"custom_row_id": "trace-123",
206+
}
207+
208+
cfg = TransformSFTDatasetConfig(
209+
source="test/dataset",
210+
revision="main",
211+
output_path="/tmp/output",
212+
metadata_columns=[],
213+
adapter=adapter,
214+
)
215+
216+
result = transform_row(row, cfg, adapter)
217+
218+
assert result is not None
219+
assert result.id == "trace-123"
220+
221+
222+
class TestInstructionDatasetAdapterSignatures:
223+
"""Test instruction dataset adapter signature stability."""
224+
225+
def test_signature_omits_unset_trace_hooks(self):
226+
adapter = instruction_response_adapter(
227+
instruction_column="instruction",
228+
response_column="response",
229+
)
230+
231+
signature_string = get_adapter_signature_string(adapter)
232+
signature = json.loads(signature_string)
233+
234+
assert "message_postprocess_fn" not in signature
235+
assert "row_id_fn" not in signature
236+
assert signature_string == _legacy_adapter_signature_string(adapter)
237+
238+
def test_signature_includes_set_trace_hooks(self):
239+
adapter = instruction_response_adapter(
240+
instruction_column="instruction",
241+
response_column="response",
242+
message_postprocess_fn=_replace_assistant_message,
243+
row_id_fn=_row_id_from_source,
244+
)
245+
246+
signature = json.loads(get_adapter_signature_string(adapter))
247+
248+
assert signature["message_postprocess_fn"].endswith("._replace_assistant_message")
249+
assert signature["row_id_fn"].endswith("._row_id_from_source")
250+
130251

131252
class TestPreferenceDataTransform:
132253
"""Test preference data (DPO) transformation."""

0 commit comments

Comments (0)