Skip to content

Commit 1e65e17

Browse files
committed
deploy: 4e9c2fd
1 parent 9f9713b commit 1e65e17

22 files changed

+107
-72
lines changed

.doctrees/environment.pickle

0 Bytes
Binary file not shown.
104 Bytes
Binary file not shown.
290 Bytes
Binary file not shown.
812 Bytes
Binary file not shown.

_sources/get_started/customization.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Below is a summary of all available customization interfaces and their purposes.
4040

4141
**Signature**:
4242
```python
43-
async def generate_rollout(args, rollout_id, *, evaluation=False) -> RolloutFnTrainOutput | RolloutFnEvalOutput
43+
def generate_rollout(args, rollout_id, data_source, evaluation=False) -> RolloutFnTrainOutput | RolloutFnEvalOutput
4444
```
4545

4646
**Use Cases**:
@@ -140,7 +140,7 @@ class DynamicFilterOutput:
140140

141141
**Signature**:
142142
```python
143-
def buffer_filter(samples: list[list[Sample]]) -> list[list[Sample]]
143+
def buffer_filter(args, rollout_id, buffer: list[list[Sample]], num_samples: int) -> list[list[Sample]]
144144
```
145145

146146
**Use Cases**:
@@ -177,7 +177,7 @@ def filter_function(args, samples: list[Sample]) -> None
177177

178178
**Signature**:
179179
```python
180-
def process_function(args, samples: list[list[Sample]]) -> None
180+
def process_function(args, samples: list[list[Sample]], data_source) -> None
181181
```
182182

183183
**Use Cases**:

_sources/get_started/quick_start.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,12 @@ The filtering function `check_reward_nonzero_std` in the example will check whet
359359

360360
```python
361361
def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
362-
rewards = [sample.reward for sample in samples]
363-
return torch.tensor(rewards, dtype=torch.float).std() > 0.0
362+
rewards = [sample.get_reward_value(args) for sample in samples]
363+
keep = torch.tensor(rewards, dtype=torch.float).std() > 0.0
364+
return DynamicFilterOutput(
365+
keep=keep,
366+
reason=None if keep else f"zero_std_{round(rewards[0], 1)}",
367+
)
364368
```
365369

366370
If the filtering function is very strict, causing a large number of prompt groups to be discarded, the system will monitor the number of pending tasks in `remaining_batch_size`. Once the number of pending tasks drops below the target number (32) due to too many being discarded, the system will automatically trigger a new round of oversampling, requesting `over_sampling_batch_size` (64) new prompts again to repeat the above process.

_sources/get_started/usage.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When
186186
- `gspo` ([https://arxiv.org/abs/2507.18071](https://arxiv.org/abs/2507.18071))
187187
- `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262))
188188
- `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347))
189-
- `on_policy_distillation`
189+
190+
Note: On-policy distillation (OPD) is now orthogonal to the advantage estimator. Use `--use-opd` and `--opd-kl-coef` to enable OPD on top of any estimator.
190191
- `--calculate-per-token-loss`: By default, slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`.
191192
- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl).
192193
- `--true-on-policy-mode`: Enable True On-Policy mode, which strictly ensures that data is generated by the current policy during training.
@@ -266,19 +267,19 @@ slime supports customizing data generation (rollout) to various degrees.
266267
- You can completely replace the `generate_rollout` in sglang\_example.py by using the `--rollout-function-path` parameter. You just need to ensure that the function signature passed via `--rollout-function-path` is as follows:
267268

268269
```python
269-
def generate_rollout(args, rollout_id, data_buffer, evaluation=False) -> list[list[Sample]]:
270+
def generate_rollout(args, rollout_id, data_source, evaluation=False) -> RolloutFnTrainOutput | RolloutFnEvalOutput:
270271
"""
271272
Args:
272273
args: the whole args
273274
rollout_id: int, the id of the rollout, used for deterministic data generation
274-
data_buffer: the data buffer to store the generated samples
275+
data_source: the data source to get and store samples
275276
evaluation: bool, whether the rollout is for evaluation or not
276277
277278
Returns:
278-
list[list[Sample]]: a list of samples generated by the rollout
279+
RolloutFnTrainOutput | RolloutFnEvalOutput: the output of the rollout
279280
"""
280281
...
281-
return samples
282+
return output
282283
```
283284

284285
Where:
@@ -287,7 +288,7 @@ slime supports customizing data generation (rollout) to various degrees.
287288

288289
- `rollout_id`: The ID of the current data generation round, used to ensure data order when resuming training.
289290

290-
- `data_buffer`: A globally unique data buffer in slime, which can be used to get initial prompts, data IDs, and store partially generated samples for later use.
291+
- `data_source`: A globally unique data source in slime, which can be used to get initial prompts, data IDs, and store partially generated samples for later use.
291292

292293
- `evaluation`: A boolean indicating if the rollout is for evaluation. You can configure a separate evaluation function using `--eval-function-path`.
293294

@@ -296,10 +297,7 @@ slime supports customizing data generation (rollout) to various degrees.
296297
- `tokens`: The tokens for the prompt + response.
297298
- `response_length`: The total length of the response. For multi-turn tasks, this is the length of the tokens remaining after the first-turn prompt.
298299
- `reward`: The reward for this data sample.
299-
- `truncated`: Whether this data sample was truncated, similar to `finish_reason == length` in SGLang.
300-
301-
And if there are scenarios like tool calls or multi-turn usage, ensure the `loss_mask` is correct:
302-
300+
- `status`: The status of this data sample (e.g., `Sample.Status.COMPLETED`, `Sample.Status.TRUNCATED`, `Sample.Status.ABORTED`, `Sample.Status.FAILED`).
303301
- `loss_mask` should be the same length as `response_length`, with `1` for tokens that should be included in the loss calculation and `0` for those that should be masked out.
304302

305303
- In some cases, you may only need to replace the data generation logic. You can do this using `--custom-generate-function-path`. A simplified implementation of this function is as follows:
@@ -325,9 +323,14 @@ slime supports customizing data generation (rollout) to various degrees.
325323
# set sample
326324
sample.tokens = prompt_tokens_ids + response_token_ids
327325
sample.response_length = len(response_token_ids)
328-
sample.truncated = output["meta_info"]["finish_reason"]["type"] == "length"
326+
finish_reason = output["meta_info"]["finish_reason"]["type"]
327+
if finish_reason == "length":
328+
sample.status = Sample.Status.TRUNCATED
329+
elif finish_reason == "abort":
330+
sample.status = Sample.Status.ABORTED
331+
else:
332+
sample.status = Sample.Status.COMPLETED
329333
sample.response = output["text"]
330-
sample.aborted = output["meta_info"]["finish_reason"]["type"] == "abort"
331334

332335
return sample
333336
```

get_started/customization.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ <h3>1. Rollout Function (<code class="docutils literal notranslate"><span class=
574574
<p><strong>Default</strong>: <code class="docutils literal notranslate"><span class="pre">slime.rollout.sglang_rollout.generate_rollout</span></code></p>
575575
<p><strong>Purpose</strong>: Override the entire rollout generation logic.</p>
576576
<p><strong>Signature</strong>:</p>
577-
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">async</span> <span class="k">def</span><span class="w"> </span><span class="nf">generate_rollout</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">rollout_id</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">evaluation</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RolloutFnTrainOutput</span> <span class="o">|</span> <span class="n">RolloutFnEvalOutput</span>
577+
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">generate_rollout</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">rollout_id</span><span class="p">,</span> <span class="n">data_source</span><span class="p">,</span> <span class="n">evaluation</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RolloutFnTrainOutput</span> <span class="o">|</span> <span class="n">RolloutFnEvalOutput</span>
578578
</pre></div>
579579
</div>
580580
<p><strong>Use Cases</strong>:</p>
@@ -662,7 +662,7 @@ <h3>5. Buffer Filter (<code class="docutils literal notranslate"><span class="pr
662662
<p><strong>Default</strong>: <code class="docutils literal notranslate"><span class="pre">None</span></code></p>
663663
<p><strong>Purpose</strong>: Filter samples in the rollout buffer before training.</p>
664664
<p><strong>Signature</strong>:</p>
665-
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">buffer_filter</span><span class="p">(</span><span class="n">samples</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]]</span>
665+
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">buffer_filter</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">rollout_id</span><span class="p">,</span> <span class="n">buffer</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]],</span> <span class="n">num_samples</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]]</span>
666666
</pre></div>
667667
</div>
668668
<p><strong>Use Cases</strong>:</p>
@@ -694,7 +694,7 @@ <h3>7. Rollout All Samples Process (<code class="docutils literal notranslate"><
694694
<p><strong>Default</strong>: <code class="docutils literal notranslate"><span class="pre">None</span></code></p>
695695
<p><strong>Purpose</strong>: Process all samples (including filtered ones) after rollout.</p>
696696
<p><strong>Signature</strong>:</p>
697-
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">process_function</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="kc">None</span>
697+
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">process_function</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">]],</span> <span class="n">data_source</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span>
698698
</pre></div>
699699
</div>
700700
<p><strong>Use Cases</strong>:</p>

get_started/quick_start.html

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,12 @@ <h3>Dynamic Sampling<a class="headerlink" href="#dynamic-sampling" title="Link t
847847
<p>Then each sampling will directly sample 64 prompts, and each prompt will be sampled 8 times. Because slime performs asynchronous sampling internally, we will successively obtain 8 responses for each prompt. When receiving responses, the function corresponding to <code class="docutils literal notranslate"><span class="pre">dynamic_sampling_filter_path</span></code> will be used for filtering. If it passes, these 8 pieces of data will be kept; otherwise, they will be discarded.</p>
848848
<p>The filtering function <code class="docutils literal notranslate"><span class="pre">check_reward_nonzero_std</span></code> in the example will check whether the standard deviation of rewards for a group of samples is greater than zero, ensuring that the reward scores of each group of samples left have differences, thereby avoiding overly homogeneous data and improving data diversity.</p>
849849
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">check_reward_nonzero_std</span><span class="p">(</span><span class="n">args</span><span class="p">,</span> <span class="n">samples</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Sample</span><span class="p">],</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
850-
<span class="n">rewards</span> <span class="o">=</span> <span class="p">[</span><span class="n">sample</span><span class="o">.</span><span class="n">reward</span> <span class="k">for</span> <span class="n">sample</span> <span class="ow">in</span> <span class="n">samples</span><span class="p">]</span>
851-
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">rewards</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.0</span>
850+
<span class="n">rewards</span> <span class="o">=</span> <span class="p">[</span><span class="n">sample</span><span class="o">.</span><span class="n">get_reward_value</span><span class="p">(</span><span class="n">args</span><span class="p">)</span> <span class="k">for</span> <span class="n">sample</span> <span class="ow">in</span> <span class="n">samples</span><span class="p">]</span>
851+
<span class="n">keep</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">rewards</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.0</span>
852+
<span class="k">return</span> <span class="n">DynamicFilterOutput</span><span class="p">(</span>
853+
<span class="n">keep</span><span class="o">=</span><span class="n">keep</span><span class="p">,</span>
854+
<span class="n">reason</span><span class="o">=</span><span class="kc">None</span> <span class="k">if</span> <span class="n">keep</span> <span class="k">else</span> <span class="sa">f</span><span class="s2">&quot;zero_std_</span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">rewards</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="w"> </span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
855+
<span class="p">)</span>
852856
</pre></div>
853857
</div>
854858
<p>If the filtering function is very strict, causing a large number of prompt groups to be discarded, the system will monitor the number of pending tasks in <code class="docutils literal notranslate"><span class="pre">remaining_batch_size</span></code>. Once the number of pending tasks drops below the target number (32) due to too many being discarded, the system will automatically trigger a new round of oversampling, requesting <code class="docutils literal notranslate"><span class="pre">over_sampling_batch_size</span></code> (64) new prompts again to repeat the above process.</p>

0 commit comments

Comments
 (0)