diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py
index d27ee1d02..3ff490f79 100644
--- a/unsloth_zoo/dataset_utils.py
+++ b/unsloth_zoo/dataset_utils.py
@@ -484,10 +484,25 @@ def _standardize_dataset(examples):
     }
 
     if not isinstance(dataset, IterableDataset):
-        from multiprocessing import cpu_count
-
-        if num_proc is None or type(num_proc) is not int:
-            num_proc = cpu_count()
+        import os
+
+        if num_proc is None or type(num_proc) is not int:
+            # Memory-aware default to prevent OOM with large datasets: start from the
+            # CPU count (os.cpu_count() may return None, hence the "or 1"), capped at 64.
+            num_proc = min((os.cpu_count() or 1) + 4, 64)
+            try:
+                # psutil is imported lazily so a missing install degrades to the CPU-based default.
+                import psutil
+                memory_gb_left = psutil.virtual_memory().available / 1024 / 1024 / 1024
+            except Exception:
+                # Best effort: available memory is unknown, so keep the CPU-based default.
+                pass
+            else:
+                if memory_gb_left < 4:
+                    num_proc = 1 # Too risky, so set to 1
+                else:
+                    # Limit based on available memory (assume ~1GB per worker)
+                    num_proc = min(num_proc, max(1, int(memory_gb_left)))
 
         dataset_map_kwargs['num_proc'] = num_proc
         dataset_map_kwargs['desc'] = "Unsloth: Standardizing formats"