
Conversation

@joellidin (Collaborator)

What happened

Training was randomly crashing with `IndexError: index -1 is out of bounds for dimension 1 with size 0` at `trainer.py:798`.

Tracked it down: the dataset was returning empty sequences, literally
`torch.Size([1, 0])` with zero tokens.

The problem

The math that broke:

  • Shard 13 has 107,374,182,400 tokens
  • With seqlen=2048, that's exactly 52,428,800 complete sequences
  • But sample_ids file had 52,428,816 entries
  • Those last 16 "samples" were phantom - they pointed past the end of the token
    array
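Quick sanity check of those numbers:

```python
>>> 107_374_182_400 // 2_048   # complete sequences in shard 13
52428800
>>> 52_428_816 - 52_428_800    # extra entries in sample_ids
16
```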

When the sampler picked one of those phantom indices:

start = 52,428,801 × 2,048 = 107,374,184,448

But the tokens only go up to 107,374,182,400, so numpy slices past the end of the array and silently returns an empty array.
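A minimal repro of that failure mode (array sizes and the downstream indexing are illustrative, not the actual trainer code):

```python
import numpy as np
import torch

seq_len = 2048
tokens = np.zeros(10 * seq_len, dtype=np.uint32)  # stand-in for a shard's token array

start = 12 * seq_len                   # phantom start offset, past the end of the array
chunk = tokens[start:start + seq_len]  # numpy slices past the end return empty arrays
print(chunk.shape)                     # (0,)

batch = torch.from_numpy(chunk.astype(np.int64)).unsqueeze(0)
print(batch.shape)                     # torch.Size([1, 0])
batch[:, -1]                           # IndexError: index -1 is out of bounds
                                       # for dimension 1 with size 0
```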

Why it happened

Bug in preprocessing (`02_consolidate_shards.py`):

```python
raw_idx = np.arange(0, tok_u32.shape[0] + 1, seq_len)  # that +1 is wrong
```

This created one extra boundary. If the shard didn't divide evenly by seqlen,
we'd get phantom sample IDs for the incomplete chunk at the end.
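A toy illustration of the off-by-one, assuming the `raw_idx` entries end up used directly as per-sample start offsets (that downstream conversion is my assumption; it isn't shown in this PR):

```python
import numpy as np

seq_len = 4
tok = np.arange(8, dtype=np.uint32)  # toy shard that divides evenly: 2 sequences

raw_idx = np.arange(0, tok.shape[0] + 1, seq_len)  # [0, 4, 8] -- one entry too many
for start in raw_idx:
    print(start, tok[start:start + seq_len])       # start=8 yields an empty slice
```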

Runtime trusted the wrong thing:

```python
self.total_samples = int(self.sample_ids.shape[0])  # blindly trusted sample_ids
```

It should have been checking against the actual token count.

The fix

Runtime (`sharded_dataset.py`):

```python
actual_samples = total_tokens // self.seqlen
sample_ids_count = int(self.sample_ids.shape[0])
self.total_samples = min(actual_samples, sample_ids_count)  # use the smaller one
```

Now it logs a warning if there's a mismatch and uses the safe count.
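A self-contained sketch of that guard (the function name and logging setup are mine, not the PR's):

```python
import logging

logger = logging.getLogger(__name__)

def safe_sample_count(total_tokens: int, sample_ids_count: int, seqlen: int) -> int:
    """Clamp the sample count to what the token array can actually cover."""
    actual_samples = total_tokens // seqlen
    if sample_ids_count != actual_samples:
        logger.warning(
            "sample_ids has %d entries but tokens cover %d complete sequences; "
            "using the smaller count",
            sample_ids_count, actual_samples,
        )
    return min(actual_samples, sample_ids_count)
```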

Preprocessing (`02_consolidate_shards.py`):

```python
num_complete_samples = total_tokens // seq_len  # only count complete sequences
raw_idx = np.arange(0, num_complete_samples * seq_len + 1, seq_len)
```

Also fixed the dtype loading: it now uses `np.load` for `.npy` files instead
of reinterpreting uint16 reads as uint32.
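For the dtype part, the difference looks roughly like this (the file name is hypothetical, and the buggy line is my guess at the pattern the PR describes):

```python
import numpy as np

path = "shard_000013.npy"  # hypothetical shard file saved by step 01 as uint32

# Buggy pattern: a raw read ignores the .npy header and the embedded dtype,
# so the bytes get reinterpreted as the wrong type.
# tok = np.fromfile(path, dtype=np.uint16).view(np.uint32)

# Correct: np.load parses the .npy header and respects the embedded dtype.
tok = np.load(path, mmap_mode="r")
assert tok.dtype == np.uint32
```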

Impact

  • Runtime fix prevents crashes immediately (just ignores phantom samples)
  • Preprocessing fix means new shards won't have this issue
  • Lost only 16 samples out of ~52M per affected shard, which is negligible

Notes

The preprocessing script also had a dtype issue: it was loading files as
uint16 and then viewing them as uint32. Since step 01 saves tokens as uint32
in `.npy` format, we should just load them with `np.load`, which respects the
embedded dtype.

Prevent IndexError at trainer.py:798 caused by phantom sample IDs
created for incomplete sequences at the end of shards. The bug occurred
when preprocessing created more sample IDs than actual complete
sequences, causing out-of-bounds token access.
Fix the root cause of the phantom sample IDs that led to empty sequences
at the end of shards.

- Only create sample IDs for complete sequences (no partial)
- Calculate num_complete_samples = total_tokens // seq_len
- Load .npy files using np.load to respect embedded dtype
- Warn when partial sequences are discarded at shard end
- Remove off-by-one error in raw_idx range calculation

This prevents the creation of invalid sample IDs that reference beyond
the actual token data, eliminating the empty sequence bug that caused
NaN losses during training.
@joellidin merged commit 072cf98 into dev on Nov 20, 2025 (3 of 5 checks passed)
@joellidin deleted the fix/invalid-samples branch on November 20, 2025 at 15:34