Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions skyrl/train/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def compute_prompt_mini_batch_boundaries(
train_batch_size: int,
is_stepwise: bool,
n_samples_per_prompt: int,
is_training: bool = True,
) -> List[Tuple[int, int]]:
"""Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list.

Expand All @@ -206,10 +207,12 @@ def compute_prompt_mini_batch_boundaries(
train_batch_size: Number of prompts in a training batch. For sanity check.
is_stepwise: Whether the training is step-wise. For sanity check.
n_samples_per_prompt: how many samples per prompt. For sanity check.
is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches).
Defaults to True for backward compatibility.
Returns:
List of (start, end) indices of the mini-batches. The length of the list is the number of
mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether
the training is step-wise or not.
mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ
during evaluation if the final batch is partial.

Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly
``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible
Expand Down Expand Up @@ -244,23 +247,35 @@ def compute_prompt_mini_batch_boundaries(
prompt_end_indices.append(i)
prompt_end_indices.append(len(uids))

# seen_uids should equal to the number of prompts and equal to `train_batch_size`
# Check that num_prompts matches expected batch size
num_prompts = len(prompt_end_indices)
assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size
assert train_batch_size % mini_batch_size == 0

# Compute boundaries.
if is_training:
assert (
num_prompts == train_batch_size and len(seen_uids) == train_batch_size
), f"Expected {train_batch_size} prompts in training batch, got {num_prompts}."
assert train_batch_size % mini_batch_size == 0
else:
if num_prompts != train_batch_size:
logger.info(
f"Partial batch detected during eval: got {num_prompts} prompts but "
f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries."
)

# Compute boundaries. Handle partial batches during eval.
boundaries: List[Tuple[int, int]] = []
start_seq = 0

for i in range(0, num_prompts, mini_batch_size):
end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index
end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1)
end_seq = prompt_end_indices[end_prompt_idx]
boundaries.append((start_seq, end_seq))
start_seq = end_seq
assert len(boundaries) == train_batch_size // mini_batch_size

if is_training:
assert len(boundaries) == train_batch_size // mini_batch_size

# Assert that the mini-batch boundaries are uniform for non-step-wise training.
if not is_stepwise:
if not is_stepwise and is_training:
expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size
for i, (start, end) in enumerate(boundaries):
assert start == i * expected_num_seq_in_mini_batch
Expand Down
20 changes: 17 additions & 3 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,17 @@ def init_weight_sync_state(self):
self.dispatch.init_weight_sync_state(self.inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")

def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
def convert_to_training_input(
self, generator_output: GeneratorOutput, uids: List[str], is_training: bool = True
) -> TrainingInputBatch:
"""Converts lists to a padded batch of tensors for training

Args:
generator_output (GeneratorOutput): Generated rollouts and associated data.
uids (List[str]): List of prompt-unique identifiers for each generator output in the same
order as `generator_output`. Used to identify which prompt each generated rollout belongs to.
is_training (bool): Whether this batch is for training (strict batch size) or evaluation
(allows partial batches). Defaults to True for backward compatibility.
Returns:
training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
order of `generator_output` and hence `uids`.
Expand Down Expand Up @@ -680,11 +684,21 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
is_stepwise = self.cfg.generator.step_wise_trajectories
training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
uids,
self.cfg.trainer.policy_mini_batch_size,
train_batch_size,
is_stepwise,
n_samples_per_prompt,
is_training=is_training,
)
if self.cfg.trainer.critic.model.path is not None:
training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
uids,
self.cfg.trainer.critic_mini_batch_size,
train_batch_size,
is_stepwise,
n_samples_per_prompt,
is_training=is_training,
)

# 5. Record metadata and metrics.
Expand Down
80 changes: 79 additions & 1 deletion tests/train/test_prompt_mini_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,88 @@ def test_same_step_count_as_non_stepwise(self):
)

assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2

# Non-step-wise boundaries should be uniform
assert non_stepwise_bounds == [(0, 640), (640, 1280)]

def test_eval_partial_batch_nonstepwise(self):
"""Test eval mode with partial batches during non-stepwise training.

This addresses the issue where evaluation crashes when val set size is
not divisible by train_batch_size. With is_training=False, partial
batches should be allowed.
"""
train_batch_size = 4
spp = 2
is_stepwise = False
mini_batch_size = 2

# Only 3 prompts instead of 4 (partial batch)
uids = ["p0", "p0", "p1", "p1", "p2", "p2"]

# Should work fine with is_training=False
boundaries = compute_prompt_mini_batch_boundaries(
uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
)
# With 3 prompts and mini_batch_size=2, we get 2 mini-batches:
# First mini-batch: prompts 0-1 (sequences 0-4)
# Second mini-batch: prompt 2 (sequences 4-6)
assert boundaries == [(0, 4), (4, 6)]

def test_eval_partial_batch_single_minibatch(self):
"""Test eval mode with partial batch that fits in single mini-batch."""
train_batch_size = 4
spp = 2
is_stepwise = False
mini_batch_size = 2

# Only 1 prompt instead of 4 (very partial batch)
uids = ["p0", "p0"]

boundaries = compute_prompt_mini_batch_boundaries(
uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
)
# With 1 prompt and mini_batch_size=2, we get 1 mini-batch
assert boundaries == [(0, 2)]

def test_eval_rejects_noncontiguous_uids(self):
"""Test that eval mode still enforces contiguous uids."""
train_batch_size = 4
spp = 2
is_stepwise = False
mini_batch_size = 2
# Non-contiguous uids: p0 appears at index 0-1 and 4-5
uids = ["p0", "p0", "p1", "p1", "p0", "p0"]

with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"):
compute_prompt_mini_batch_boundaries(
uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
)

def test_eval_stepwise_partial_batch(self):
"""Test eval mode with stepwise training and partial batch."""
mini_batch_size = 2
train_batch_size = 4
spp = 2
is_stepwise = True

# Only 3 prompts instead of 4
uids = _make_uids_stepwise(
[
("p0", 2, [3, 2]), # 5 seqs
("p1", 2, [1, 4]), # 5 seqs
("p2", 2, [2, 1]), # 3 seqs
]
)

# Should work fine with is_training=False
boundaries = compute_prompt_mini_batch_boundaries(
uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False
)
# With 3 prompts and mini_batch_size=2, we get 2 mini-batches:
# First: prompts 0-1 (sequences 0-10)
# Second: prompt 2 (sequences 10-13)
assert boundaries == [(0, 10), (10, 13)]


# ---------------------------------------------------------------------------
# Tests for MeshDispatch.stage_chunks
Expand Down
Loading