Skip to content

Commit 051d7df

Browse files
authored
Merge pull request #61 from awslabs/feat/slime-data-contract
feat(slime): support arbitrary agent payload shapes in the training backend
2 parents 1a34823 + b420d84 commit 051d7df

5 files changed

Lines changed: 130 additions & 48 deletions

File tree

src/agentcore_rl_toolkit/backends/slime/SETUP.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,22 @@ ds = load_dataset('openai/gsm8k', 'main', split='train')
112112
with open('/path/to/gsm8k_tiny.jsonl', 'w') as f:
113113
for i, row in enumerate(ds):
114114
if i >= 64: break
115+
question = row['question']
115116
answer = row['answer'].split('####')[-1].strip()
116-
f.write(json.dumps({'prompt': row['question'], 'label': answer}) + '\n')
117+
# Top-level 'prompt' is read by slime (tokenization, length filter).
118+
# 'metadata' is the agent payload verbatim — shape it however the agent expects.
119+
f.write(json.dumps({
120+
'prompt': question,
121+
'metadata': {'prompt': question, 'answer': answer},
122+
}) + '\n')
117123
"
118124
```
119125

126+
The agent-visible payload is exactly the contents of ``metadata``, so
127+
different agents can use different payload shapes (e.g. ``{'task_id': ...}``
128+
for AppWorld, ``{'repo_uri': ..., 'metadata_uri': ..., ...}`` for
129+
migration) without any slime-side changes.
130+
120131
### 3.3 Configure deployment settings
121132

122133
```bash

src/agentcore_rl_toolkit/backends/slime/examples/math_agent/train.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ ray job submit --address="http://127.0.0.1:8265" \
8080
--tensor-model-parallel-size ${TP_SIZE} \
8181
--rollout-num-gpus-per-engine ${ROLLOUT_GPUS_PER_ENGINE} \
8282
--input-key prompt \
83-
--label-key label \
8483
--rollout-batch-size 32 \
8584
--n-samples-per-prompt 8 \
8685
--rollout-max-response-len 1024 \

src/agentcore_rl_toolkit/backends/slime/integration/rollout.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -209,52 +209,20 @@ def _ensure_initialized(args: Namespace):
209209

210210

211211
def _sample_to_payload(sample) -> dict:
212-
"""Convert a slime Sample to an ART invocation payload.
213-
214-
Extracts all non-None public fields from the Sample into the payload.
215-
The agent's @rollout_entrypoint receives this dict as `payload`.
212+
"""The agent payload is the JSONL row's ``metadata`` dict, verbatim.
213+
214+
slime's Dataset reads the JSONL row's ``metadata`` field into
215+
``Sample.metadata``; we hand that dict to the agent unchanged. The JSONL's
216+
top-level ``prompt`` field is for slime (tokenization, length filtering);
217+
the agent's payload shape is entirely defined by whatever the data author
218+
put in ``metadata``. A shallow copy isolates the agent's view from
219+
downstream mutations to ``Sample.metadata`` (e.g. ``task_metadata``
220+
injection in ``_process_one_episode``). Returns ``{}`` when ``Sample.metadata`` is missing or not a dict.
216221
"""
217-
payload = {}
218-
219-
if hasattr(sample, "prompt") and sample.prompt:
220-
payload["prompt"] = sample.prompt
221-
if hasattr(sample, "label") and sample.label is not None:
222-
payload["answer"] = sample.label
223-
if hasattr(sample, "metadata") and sample.metadata:
224-
payload["metadata"] = sample.metadata
225-
226-
if hasattr(sample, "to_dict"):
227-
for key, value in sample.to_dict().items():
228-
if (
229-
key not in payload
230-
and value is not None
231-
and key
232-
not in (
233-
"tokens",
234-
"rollout_log_probs",
235-
"loss_mask",
236-
"teacher_log_probs",
237-
"rollout_routed_experts",
238-
"multimodal_inputs",
239-
"multimodal_train_inputs",
240-
"group_index",
241-
"index",
242-
"status",
243-
"session_id",
244-
"spec_info",
245-
"prefix_cache_info",
246-
"response_length",
247-
"response",
248-
"weight_versions",
249-
"remove_sample",
250-
"non_generation_time",
251-
"generate_function_path",
252-
"train_metadata",
253-
)
254-
):
255-
payload[key] = value
256-
257-
return payload
222+
metadata = getattr(sample, "metadata", None)
223+
if isinstance(metadata, dict):
224+
return dict(metadata)
225+
return {}
258226

259227

260228
# ---------------------------------------------------------------------------
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Tests for the slime backend's agent-payload conversion.
2+
3+
``_sample_to_payload`` is the contract between slime's ``Sample`` object
4+
and the agent's ``@rollout_entrypoint`` payload. The rule: the agent
5+
receives ``Sample.metadata`` verbatim (shallow-copied).
6+
"""
7+
8+
from types import SimpleNamespace
9+
10+
from agentcore_rl_toolkit.backends.slime.integration.rollout import _sample_to_payload
11+
12+
13+
def test_metadata_is_returned_verbatim():
14+
"""Core contract: the agent payload is Sample.metadata as-is.
15+
16+
Fields on Sample outside metadata (prompt, label) are slime's own
17+
concern and must not leak into the payload.
18+
"""
19+
sample = SimpleNamespace(
20+
prompt="slime-side prompt",
21+
label="slime-side label",
22+
metadata={"task_id": "t1", "answer": "42"},
23+
)
24+
assert _sample_to_payload(sample) == {"task_id": "t1", "answer": "42"}
25+
26+
27+
def test_returned_dict_is_a_shallow_copy():
28+
"""Mutations to the payload must not leak back into Sample.metadata.
29+
30+
``_process_one_episode`` later injects keys into ``Sample.metadata``
31+
(e.g. ``task_metadata``); because the payload is a separate top-level dict, neither side sees keys the other adds later.
32+
"""
33+
metadata = {"prompt": "hi"}
34+
sample = SimpleNamespace(metadata=metadata)
35+
36+
payload = _sample_to_payload(sample)
37+
payload["injected"] = True
38+
39+
assert "injected" not in metadata
40+
41+
42+
def test_missing_or_invalid_metadata_returns_empty_dict():
43+
"""Defensive fallback when metadata is absent or not a dict."""
44+
for sample in [
45+
SimpleNamespace(), # attribute absent
46+
SimpleNamespace(metadata=None), # explicit None
47+
SimpleNamespace(metadata="not a dict"), # wrong type
48+
]:
49+
assert _sample_to_payload(sample) == {}

uv.lock

Lines changed: 56 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)