Describe the bug
Hi team, first off, I love the datasets library! 🥰
I'm encountering a potential memory leak / increasing memory usage when training a model on a very large DatasetDict.
Setup: I have a DatasetDict containing 362 distinct datasets, which together contain ~2.8 billion rows.
Training Task: I'm performing contrastive learning with SentenceTransformer and Accelerate on a single node with 4 H100 GPUs, which requires me to sample from only one dataset at a time.
Training Loop: At each training step, I sample ~16,000 examples from a single dataset, and then switch to a different dataset for the next step. I iterate through all 362 datasets this way.
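For concreteness, here is a minimal sketch of the sampling pattern at toy scale (the dataset names, sizes, and batch size are illustrative, not my actual training code):

```python
import random

from datasets import Dataset, DatasetDict

# Toy stand-in for my real DatasetDict (362 datasets, ~2.8B rows in total).
dataset_dict = DatasetDict({
    f"ds_{i}": Dataset.from_dict({"text": [f"example {j}" for j in range(1_000)]})
    for i in range(3)
})

def sample_step(ds, k):
    # Draw k random rows from a single dataset for one training step.
    indices = random.sample(range(len(ds)), k)
    return ds.select(indices)

for step, name in enumerate(dataset_dict):         # one dataset per step, round-robin
    batch = sample_step(dataset_dict[name], k=16)  # ~16,000 in the real run
    # ... feed `batch` into the SentenceTransformer / Accelerate training step ...
```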
Problem: The process's memory usage continuously increases over time, eventually causing the run to stall with the GPUs sitting idle. It seems memory from previously sampled datasets isn't being released. I've set num_workers=0 for all experiments.
Chart 1: Standard DatasetDict. The memory usage (RSS) grows steadily until the training stalls.
Chart 2: IterableDatasetDict. I also tried IterableDatasetDict and IterableDataset. The memory curve is "smoother," but the result is the same: it grows indefinitely and the training stalls.
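Roughly, the iterable variant I tried looks like the sketch below (reusing the toy `dataset_dict` from the sketch above; the `to_iterable_dataset()` conversion and per-step slicing are illustrative, not my exact code):

```python
from itertools import islice

from datasets import IterableDatasetDict

# Convert each map-style dataset to its streaming counterpart,
# then pull a fixed-size slice from one dataset per training step.
iterable_dict = IterableDatasetDict({
    name: ds.to_iterable_dataset() for name, ds in dataset_dict.items()
})

for step, (name, ids) in enumerate(iterable_dict.items()):
    batch = list(islice(iter(ids), 16))  # ~16,000 examples per step in the real run
    # ... training step on `batch` ...
```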
Any feedback or guidance on how to manage this memory would be greatly appreciated!
Steps to reproduce the bug
WIP: I'll add code that reproduces this error, but it's not straightforward to isolate.
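In the meantime, here is a rough skeleton of the kind of script I expect the repro to look like, at a much smaller scale, using psutil to log RSS per step (the sizes, names, and constants are illustrative):

```python
import os
import random

import psutil
from datasets import Dataset, DatasetDict

NUM_DATASETS = 50           # 362 in my real setup
ROWS_PER_DATASET = 100_000  # ~2.8B rows in total in my real setup
BATCH = 16_000

# Build many in-memory Arrow datasets to stand in for the real DatasetDict.
dsets = DatasetDict({
    f"ds_{i}": Dataset.from_dict({"text": [f"row {j}" for j in range(ROWS_PER_DATASET)]})
    for i in range(NUM_DATASETS)
})

proc = psutil.Process(os.getpid())
for step, name in enumerate(dsets):
    ds = dsets[name]
    batch = ds.select(random.sample(range(len(ds)), BATCH))
    _ = batch["text"]  # touch the data so the sampled rows are actually read
    print(f"step {step:3d} ({name}): RSS = {proc.memory_info().rss / 1e9:.2f} GB")
```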
Expected behavior
The memory usage should remain relatively constant or plateau after a few steps. Memory used for sampling one dataset should be released before or during the sampling of the next dataset.
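As a concrete (illustrative) check, reusing `dsets` from the sketch above, I'd expect RSS to return roughly to its pre-sampling level once the sampled batch is dropped and garbage-collected, rather than creep upward step after step:

```python
import gc
import os

import psutil

proc = psutil.Process(os.getpid())

def rss_gb():
    return proc.memory_info().rss / 1e9

before = rss_gb()
batch = dsets["ds_0"].select(range(16_000))  # `dsets` from the repro sketch above
_ = batch["text"]                            # actually read the sampled rows
del batch
gc.collect()
print(f"RSS before sampling: {before:.2f} GB, after del + gc.collect(): {rss_gb():.2f} GB")
```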
Environment info
Python: 3.12
Datasets: 4.3.0
SentenceTransformers: 5.1.1