JAX DataLoader Quick Reference Guide
Common Issues and Quick Fixes
Memory Issues
| Issue | Quick Fix |
|---|---|
| Out of Memory | use_mmap=True, batch_size=16, num_workers=1 |
| Memory Leaks | cleanup_interval=1000, debug=True |
| High Memory Usage | use_pinned_memory=False, prefetch_size=1 |
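The memory fixes above all shrink the same quantity: bytes held in batches that have been loaded but not yet consumed. The back-of-the-envelope sketch below assumes each worker holds one batch and the prefetch queue holds prefetch_size more; the loader's real buffering may differ. estimate_loader_memory_mb is a hypothetical helper, not part of the library.

```python
import math

# Bytes per element for a few common dtypes.
DTYPE_BYTES = {"float32": 4, "float16": 2, "int32": 4, "uint8": 1}


def estimate_loader_memory_mb(sample_shape, dtype, batch_size, num_workers, prefetch_size):
    """Rough host-memory estimate (MiB) for batches in flight.

    Hypothetical model: each worker holds one batch and the prefetch queue
    holds prefetch_size more, so num_workers + prefetch_size batches coexist.
    """
    sample_bytes = math.prod(sample_shape) * DTYPE_BYTES[dtype]
    batch_bytes = sample_bytes * batch_size
    batches_in_flight = num_workers + prefetch_size
    return batch_bytes * batches_in_flight / 2**20  # bytes -> MiB


# 224x224 RGB float32 images: a larger setup vs. the "Out of Memory" fix.
before = estimate_loader_memory_mb((224, 224, 3), "float32", 32, 4, 2)
after = estimate_loader_memory_mb((224, 224, 3), "float32", 16, 1, 1)
print(f"{before:.1f} MiB -> {after:.1f} MiB")
```

Under these assumptions, halving batch_size and going from 4 workers and 2 prefetched batches to 1 and 1 cuts in-flight batch memory from about 110 MiB to about 18 MiB.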
Performance Issues
| Issue | Quick Fix |
|---|---|
| Slow Loading | num_workers=4, prefetch_size=2, use_mmap=True |
| High CPU Usage | num_workers=2, worker_type='thread' |
| Disk I/O Bottleneck | use_mmap=True, cache_size=1000 |
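Most of the performance fixes rely on preparing the next batches while the current one is in use. A minimal sketch of that overlap with one background thread and a bounded queue, where the queue's maxsize plays the role of prefetch_size; prefetch_iter is hypothetical and is not how the library implements its workers.

```python
import queue
import threading

_SENTINEL = object()  # marks the end of the source iterator


def prefetch_iter(source, prefetch_size=2):
    """Yield items from source, preparing up to prefetch_size ahead in a thread.

    Minimal sketch: exceptions raised by the source are not re-raised here.
    """
    buf = queue.Queue(maxsize=prefetch_size)  # bounded: caps memory held ahead

    def producer():
        try:
            for item in source:
                buf.put(item)  # blocks while the queue is full
        finally:
            buf.put(_SENTINEL)

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        item = buf.get()
        if item is _SENTINEL:
            break
        yield item
    thread.join()


def load_batches(n):
    # Stand-in for slow reads and decoding.
    for i in range(n):
        yield [i * 10 + j for j in range(3)]


batches = list(prefetch_iter(load_batches(4), prefetch_size=2))
```

Threads suit I/O-bound loading because blocking disk and network reads release the GIL; CPU-heavy decoding written in pure Python usually needs processes instead.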
GPU Issues
| Issue | Quick Fix |
|---|---|
| CUDA OOM | gpu_memory_fraction=0.8, batch_size=16 |
| Device Not Found | device='cpu' |
| GPU Sync Issues | sync_every_batch=True |
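If CUDA OOM persists after lowering gpu_memory_fraction, JAX's own allocator is worth checking: JAX preallocates a large share of GPU memory (75% by default in recent releases) when its GPU backend starts. Its documented environment variables XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE control this and must be set before JAX initializes. How the loader's gpu_memory_fraction relates to them is an open assumption, so treat them as a separate knob; configure_jax_gpu_memory is a hypothetical helper.

```python
import os


def configure_jax_gpu_memory(fraction=None, preallocate=True):
    """Set JAX's GPU memory environment variables. Call before `import jax`."""
    if not preallocate:
        # Allocate GPU memory on demand instead of reserving a block up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        # Fraction of total GPU memory that JAX preallocates.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


configure_jax_gpu_memory(fraction=0.8)
# import jax  # only after the environment is configured
```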
Data Loading Issues
| Issue | Quick Fix |
|---|---|
| Shape Mismatch | dynamic_shapes=True, padding_value=0 |
| Type Errors | dtype=jnp.float32, convert_types=True |
| Batch Size Issues | validate_shapes(dataset, batch_size) |
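dynamic_shapes=True with padding_value=0 means variable-length samples are padded to a common length before they are stacked into a batch. A plain-Python sketch of that step, assuming padding to the longest sample in each batch and returning a mask that marks real (1) versus padded (0) positions; pad_batch is a hypothetical name.

```python
def pad_batch(samples, padding_value=0):
    """Pad 1-D sequences to the batch's maximum length.

    Returns (padded, mask); mask[i][j] is 1 for real values and 0 for padding.
    """
    max_len = max(len(s) for s in samples)
    padded, mask = [], []
    for s in samples:
        pad = max_len - len(s)
        padded.append(list(s) + [padding_value] * pad)
        mask.append([1] * len(s) + [0] * pad)
    return padded, mask


padded, mask = pad_batch([[5, 3, 8], [2], [7, 1]])
```

The padded lists can then be converted with jnp.asarray(padded, dtype=jnp.int32).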
Distributed Training Issues
| Issue | Quick Fix |
|---|---|
| Sync Errors | sync_every_batch=True, barrier_timeout=30 |
| Load Imbalance | dynamic_balancing=True, balance_interval=100 |
| Node Communication | num_nodes=2, node_rank=0 |
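In multi-node runs each node typically reads a disjoint slice of the dataset chosen by node_rank and num_nodes. A sketch of one common scheme: every node builds the same seeded permutation, then takes every num_nodes-th index starting at its rank. shard_indices is hypothetical and the library's actual partitioning may differ.

```python
import random


def shard_indices(dataset_size, num_nodes, node_rank, seed=0, epoch=0):
    """Return the sample indices assigned to node_rank.

    All nodes shuffle with the same (seed, epoch), so they agree on the
    permutation and the strided slices are disjoint.
    """
    if not 0 <= node_rank < num_nodes:
        raise ValueError("node_rank must be in [0, num_nodes)")
    indices = list(range(dataset_size))
    random.Random(seed + epoch).shuffle(indices)  # identical on every node
    return indices[node_rank::num_nodes]


shards = [shard_indices(10, num_nodes=2, node_rank=r) for r in range(2)]
```

When dataset_size is not a multiple of num_nodes, shard sizes differ by one, which can leave a node waiting at a synchronization barrier; dropping or padding the remainder avoids that.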
Augmentation Issues
| Issue | Quick Fix |
|---|---|
| Slow Augmentation | jit=True, parallel=True |
| Memory Issues | cache_augmented=True |
| Non-reproducible Results | seed=42 |
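A fixed seed only makes augmentation repeatable if every random decision is derived from it. The stdlib sketch below assumes one generator per sample keyed on (seed, epoch, index), so results do not depend on which worker handles a sample or in what order; random_flip is a hypothetical helper, not the JAXDataAugmentation implementation.

```python
import random


def random_flip(row, seed, epoch, index, probability=0.5):
    """Reverse a 1-D row with the given probability, reproducibly."""
    # Keyed on (seed, epoch, index): the same sample gets the same decision
    # regardless of worker scheduling.
    rng = random.Random(f"{seed}-{epoch}-{index}")
    return row[::-1] if rng.random() < probability else list(row)


first = [random_flip([1, 2, 3], seed=42, epoch=0, index=i) for i in range(5)]
second = [random_flip([1, 2, 3], seed=42, epoch=0, index=i) for i in range(5)]
```

In JAX-native augmentation code the same idea is usually expressed by deriving per-sample keys with jax.random.fold_in from a single jax.random.PRNGKey(42).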
Progress Tracking Issues
| Issue | Quick Fix |
|---|---|
| No Progress Bar | show_progress=True |
| Inaccurate Progress | progress_interval=1 |
| Progress Bar Slows Training | update_interval=10 |
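update_interval trades display accuracy for overhead: redrawing on every batch costs time, redrawing every N batches is cheap. A minimal sketch of the throttling logic, assuming the final batch is always reported; progress_steps and iterate_with_progress are hypothetical helpers.

```python
import sys


def progress_steps(total, update_interval=10):
    """Batch steps at which the progress display is refreshed."""
    # Every update_interval batches, plus once at the final batch.
    return [step for step in range(1, total + 1)
            if step % update_interval == 0 or step == total]


def iterate_with_progress(batches, total, update_interval=10, stream=sys.stderr):
    """Yield batches, writing a progress line only at the throttled steps."""
    refresh = set(progress_steps(total, update_interval))
    for step, batch in enumerate(batches, start=1):
        yield batch
        if step in refresh:
            print(f"\r{step}/{total} batches", end="", file=stream)
    print(file=stream)  # finish the progress line
```

With update_interval=1 (the "Inaccurate Progress" fix) every step is reported.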
Caching Issues
| Issue | Quick Fix |
|---|---|
| Cache Memory | cache_size=1000, eviction_policy='lru' |
| Cache Thrashing | cleanup_threshold=0.8 |
| Prefetch Issues | cache_prefetch=True |
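cache_size bounds how many items are kept, and an LRU policy decides which item is evicted when the cache is full. A minimal LRU cache on collections.OrderedDict illustrates both; LRUCache and its load_fn hook are assumptions for illustration, not the library's cache.

```python
from collections import OrderedDict


class LRUCache:
    """Keep at most cache_size items, evicting the least recently used."""

    def __init__(self, cache_size):
        self.cache_size = cache_size
        self._items = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, key, load_fn):
        """Return the cached value for key, calling load_fn(key) on a miss."""
        if key in self._items:
            self._items.move_to_end(key)  # mark as most recently used
            self.hits += 1
            return self._items[key]
        self.misses += 1
        value = load_fn(key)
        if self.cache_size > 0:  # cache_size=0 disables caching
            self._items[key] = value
            if len(self._items) > self.cache_size:
                self._items.popitem(last=False)  # evict least recently used
        return value


cache = LRUCache(cache_size=2)
for key in [1, 2, 1, 3, 2, 1]:
    cache.get(key, load_fn=lambda k: k * 100)
```

This access pattern cycles through three keys with room for only two, so five of six lookups miss. That is cache thrashing; a cache_size at least as large as the working set avoids it.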
Common Error Messages and Solutions
# Imports used by the snippets below (the JAXDataLoader import path is assumed)
import jax.numpy as jnp
from jax_dataloader import JAXDataLoader

# data stands for your dataset object in every snippet

# CUDA out of memory
loader = JAXDataLoader(
    dataset=data,
    batch_size=16,  # Reduce batch size
    gpu_memory_fraction=0.8,
)

# System out of memory
loader = JAXDataLoader(
    dataset=data,
    use_mmap=True,  # Map files from disk instead of loading them into RAM
    num_workers=1,
)

# Slow data loading
loader = JAXDataLoader(
    dataset=data,
    num_workers=4,
    prefetch_size=2,
)

# High CPU usage
loader = JAXDataLoader(
    dataset=data,
    num_workers=2,
    use_mmap=True,
)

# Shape mismatch
loader = JAXDataLoader(
    dataset=data,
    dynamic_shapes=True,
)

# Type errors
loader = JAXDataLoader(
    dataset=data,
    dtype=jnp.float32,
    convert_types=True,
)
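Shape mismatches are cheaper to find before the first batch is built. The check below is a hypothetical stand-in for the validate_shapes call in the quick-fix table, not the library function itself: it compares every sample's shape with the first sample's and flags a dataset size that is not a multiple of batch_size.

```python
def shape_of(x):
    """Shape of a nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(x, (list, tuple)):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def check_shapes(dataset, batch_size):
    """Return human-readable problems (an empty list if none are found)."""
    if not dataset:
        return ["dataset is empty"]
    problems = []
    expected = shape_of(dataset[0])
    for i, sample in enumerate(dataset):
        if shape_of(sample) != expected:
            problems.append(f"sample {i}: shape {shape_of(sample)} != {expected}")
    remainder = len(dataset) % batch_size
    if remainder:
        problems.append(f"last batch would hold {remainder} of {batch_size} samples")
    return problems


issues = check_shapes([[1, 2], [3, 4], [5]], batch_size=2)
```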
Quick Configuration Templates
# Basic
loader = JAXDataLoader(
    dataset=data,
    batch_size=32,
    num_workers=2,
    prefetch_size=1,
)

# Memory-constrained
loader = JAXDataLoader(
    dataset=data,
    batch_size=16,
    use_mmap=True,
    use_pinned_memory=False,  # Pinned (page-locked) host memory cannot be swapped out
    num_workers=1,
)

# Performance-oriented
loader = JAXDataLoader(
    dataset=data,
    batch_size=64,
    num_workers=4,
    prefetch_size=2,
    cache_size=1000,
)

# GPU training
loader = JAXDataLoader(
    dataset=data,
    batch_size=32,
    gpu_memory_fraction=0.8,
    device='cuda:0',
)
Advanced Configuration Templates
Large Dataset Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=128,
    num_workers=8,
    prefetch_size=4,
    use_mmap=True,
    cache_size=5000,
    cache_strategy='lru',
)
Small Dataset Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=8,
    num_workers=1,
    prefetch_size=1,
    use_pinned_memory=True,
    cache_size=100,
)
Half-Precision Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=32,
    dtype=jnp.float16,  # Half precision: 2 bytes per value instead of 4
    convert_types=True,
    gpu_memory_fraction=0.9,
)
Multi-GPU Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=64,
    num_workers=4,
    device='cuda:0',
    gpu_memory_fraction=0.7,
    sync_every_batch=True,
)
Streaming Data Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=32,
    num_workers=2,
    prefetch_size=2,
    use_mmap=True,
    streaming=True,
    cache_size=0,  # Disable caching for streaming
)
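With streaming=True and cache_size=0 the full dataset never needs to sit in memory. The generator below shows the core idea with itertools.islice: batches are cut from any iterable (a file, a socket, a generator) as records arrive. stream_batches is a hypothetical helper.

```python
from itertools import islice


def stream_batches(records, batch_size, drop_last=False):
    """Group a possibly unbounded iterable into lists of batch_size items."""
    it = iter(records)
    while True:
        batch = list(islice(it, batch_size))  # pulls at most batch_size items
        if not batch:
            return
        if drop_last and len(batch) < batch_size:
            return
        yield batch


def read_lines():
    # Stand-in for an unbounded source such as a file being tailed.
    for i in range(7):
        yield f"record-{i}"


batches = list(stream_batches(read_lines(), batch_size=3))
```

On Python 3.12+, itertools.batched does the same grouping (yielding tuples).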
1. Computer Vision
loader = JAXDataLoader(
    dataset=image_dataset,
    batch_size=32,
    num_workers=4,
    prefetch_size=2,
    augmenter=JAXDataAugmentation(
        augmentations=['random_flip', 'random_rotation', 'color_jitter'],
        probability=0.5,
    ),
    cache_size=1000,
)
2. Natural Language Processing
loader = JAXDataLoader(
    dataset=text_dataset,
    batch_size=64,
    num_workers=2,
    dynamic_shapes=True,
    padding_value=0,
    dtype=jnp.int32,
)
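Padding each batch to its own longest sequence produces a new array shape almost every batch, and jax.jit compiles a separate version of a function for each new input shape. A common mitigation, sketched here as an assumption rather than a documented loader option, rounds lengths up to a few fixed bucket sizes so only a handful of shapes reach the compiled model. bucket_length and pad_to_bucket are hypothetical helpers.

```python
def bucket_length(length, buckets=(32, 64, 128, 256, 512)):
    """Round a sequence length up to the smallest bucket that fits it."""
    for size in buckets:
        if length <= size:
            return size
    raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")


def pad_to_bucket(batch, padding_value=0, buckets=(32, 64, 128, 256, 512)):
    """Pad every sequence in batch to the bucket of its longest sequence."""
    target = bucket_length(max(len(seq) for seq in batch), buckets)
    return [list(seq) + [padding_value] * (target - len(seq)) for seq in batch]


lengths_seen = [17, 40, 33, 90, 12, 61]
shapes = {bucket_length(n) for n in lengths_seen}
```

Here six distinct lengths collapse to three padded lengths, so a jitted model would compile at most three times for this data instead of six.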
3. Video Processing
loader = JAXDataLoader(
    dataset=video_dataset,
    batch_size=8,  # Smaller batches for video
    num_workers=4,
    use_mmap=True,
    cache_size=100,
    gpu_memory_fraction=0.8,
)
4. Reinforcement Learning
loader = JAXDataLoader(
    dataset=replay_buffer,
    batch_size=128,
    num_workers=2,
    shuffle=True,
    cache_size=0,  # Buffer contents change as the agent acts, so don't cache
)
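A replay buffer keeps changing as the agent acts, which is why caching is disabled and shuffling is on. A fixed-capacity ring buffer with uniform random sampling, sketched in plain Python; ReplayBuffer is hypothetical and stands in for whatever object is passed as replay_buffer.

```python
import random


class ReplayBuffer:
    """Fixed-capacity buffer: new transitions overwrite the oldest ones."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self._data = []
        self._next = 0  # slot to overwrite once the buffer is full
        self._rng = random.Random(seed)

    def __len__(self):
        return len(self._data)

    def add(self, transition):
        if len(self._data) < self.capacity:
            self._data.append(transition)
        else:
            self._data[self._next] = transition
        self._next = (self._next + 1) % self.capacity

    def sample(self, batch_size):
        """Draw a batch uniformly at random, with replacement."""
        return [self._rng.choice(self._data) for _ in range(batch_size)]


buffer = ReplayBuffer(capacity=3, seed=0)
for step in range(5):
    buffer.add(step)  # transitions 0 and 1 are overwritten by 3 and 4
batch = buffer.sample(4)
```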
5. Federated Learning
loader = JAXDataLoader(
    dataset=client_data,
    batch_size=32,
    num_workers=2,
    distributed_config=DistributedConfig(
        num_nodes=10,
        node_rank=0,  # Set per node: 0 to num_nodes - 1
        sync_every_batch=True,
    ),
)
Troubleshooting Flowchart
graph TD
A[Problem] --> B{Memory Issue?}
B -->|Yes| C[Reduce batch size]
C --> D[Enable memory mapping]
D --> E[Monitor memory usage]
B -->|No| F{Performance Issue?}
F -->|Yes| G[Increase workers]
G --> H[Enable caching]
H --> I[Use memory mapping]
F -->|No| J{GPU Issue?}
J -->|Yes| K[Check device availability]
K --> L[Limit GPU memory]
L --> M[Enable sync]
J -->|No| N{Data Loading Issue?}
N -->|Yes| O[Validate shapes]
O --> P[Enable type conversion]
P --> Q[Use dynamic shapes]
N -->|No| R{Distributed Issue?}
R -->|Yes| S[Configure sync]
S --> T[Enable load balancing]
T --> U[Monitor workers]
R -->|No| V{Augmentation Issue?}
V -->|Yes| W[Enable JIT]
W --> X[Use parallel processing]
X --> Y[Cache augmented data]
V -->|No| Z{Progress Issue?}
Z -->|Yes| AA[Set interval]
AA --> BB[Use custom format]
BB --> CC[Monitor performance]
Z -->|No| DD{Caching Issue?}
DD -->|Yes| EE[Set cache size]
EE --> FF[Enable prefetch]
FF --> GG[Monitor memory]
DD -->|No| HH[Check error logs]
HH --> II[Enable debug mode]
II --> JJ[Contact support]
Expanded Parameter Reference Tables
Memory Parameters
| Parameter | Value | Effect | Recommended For |
|---|---|---|---|
| batch_size | 16-64 | Lower = less memory | Small GPUs, OOM errors |
| num_workers | 1-4 | Lower = less memory | Memory-constrained systems |
| prefetch_size | 1-2 | Lower = less memory | Streaming data |
| use_mmap | True | Reduces memory usage | Large datasets |
| cache_size | 0-1000 | Lower = less memory | Memory-sensitive tasks |
| use_pinned_memory | False | Reduces memory pressure | CPU-only systems |
| cleanup_interval | 100-1000 | More frequent (smaller value) = less memory | Long-running tasks |
| gpu_memory_fraction | 0.5-0.8 | Lower = more headroom | Multi-GPU systems |
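Performance Parameters
| Parameter | Value | Effect | Recommended For |
|---|---|---|---|
| num_workers | 2-8 | Higher = more samples loaded in parallel, up to the number of CPU cores | Multi-core CPUs |
| prefetch_size | 2-4 | Higher = more batches prepared ahead (costs memory) | Fast storage |
| cache_size | 100-5000 | Higher = fewer repeated reads (costs memory) | Repeated access |
| use_mmap | True | Better I/O performance | Large files |
| worker_type | 'thread' | Lower overhead than processes | I/O-bound tasks |
| cache_strategy | 'lru'/'fifo' | LRU usually hits more often when some samples are reused | Mixed access patterns |
| sync_every_batch | True | Better consistency | Distributed training |
| dynamic_balancing | True | Better load distribution | Uneven workloads |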
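The LRU-versus-FIFO claim holds when some samples are reused more than others. A small simulation makes it concrete; hit_rate is a hypothetical helper and the access pattern is invented for illustration.

```python
from collections import OrderedDict


def hit_rate(accesses, capacity, policy="lru"):
    """Simulate a cache holding capacity items and return its hit fraction."""
    if policy not in ("lru", "fifo"):
        raise ValueError("policy must be 'lru' or 'fifo'")
    cache = OrderedDict()  # front of the dict = next entry to evict
    hits = 0
    for key in accesses:
        if key in cache:
            hits += 1
            if policy == "lru":
                cache.move_to_end(key)  # LRU refreshes recency on a hit
            continue  # FIFO leaves the order unchanged
        cache[key] = True
        if len(cache) > capacity:
            cache.popitem(last=False)  # evict the front entry
    return hits / len(accesses)


# One "hot" sample (A) interleaved with samples that are read once.
pattern = ["A", "B", "A", "C", "A", "D", "A", "E"]
lru = hit_rate(pattern, 2, "lru")
fifo = hit_rate(pattern, 2, "fifo")
```

LRU keeps A resident because every hit moves it to the back of the eviction order, giving 3 hits out of 8; FIFO evicts A once it becomes the oldest entry and gets 2.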
GPU Parameters
| Parameter | Value | Effect | Recommended For |
|---|---|---|---|
| gpu_memory_fraction | 0.7-0.9 | Higher = more GPU memory available | Single-GPU systems |
| batch_size | 32-128 | Higher = better utilization | Large GPUs |
| sync_every_batch | True | Better synchronization | Multi-GPU training |
| device | 'cuda:0' | Selects a specific GPU | Multi-GPU systems |
| use_pinned_memory | True | Faster host-to-GPU transfer | GPU-intensive tasks |
| gpu_prefetch | True | Better GPU utilization | High-throughput tasks |
| gpu_cache_size | 100-1000 | GPU-side caching | Repeated GPU operations |
| gpu_streams | 2-4 | Concurrent GPU operations | Multi-stream workloads |
Data Loading Parameters
| Parameter | Value | Effect | Recommended For |
|---|---|---|---|
| dynamic_shapes | True | Allows samples of different shapes (padded per batch) | Variable-length data |
| padding_value | 0 or -1 | Value used to pad batches | NLP tasks |
| dtype | jnp.float32 | Target data type | General purpose |
| convert_types | True | Automatic type conversion | Mixed data types |
| shuffle | True | Random sample order | Training data |
| drop_last | True | Drops the final incomplete batch | Fixed-size models |
| collate_fn | custom | Custom batching | Specialized tasks |
| sampler | custom | Custom sampling | Imbalanced data |
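For imbalanced data, one common sampling approach, sketched here as an assumption about what a custom sampler would do, draws indices with replacement using weights inversely proportional to class frequency, so each class is sampled about equally often. balanced_weights and sample_indices are hypothetical helpers.

```python
import random
from collections import Counter


def balanced_weights(labels):
    """Per-sample weights inversely proportional to class frequency."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def sample_indices(labels, num_samples, seed=0):
    """Draw indices with replacement so every class has equal total weight."""
    rng = random.Random(seed)
    weights = balanced_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


labels = [0, 0, 0, 0, 0, 0, 1, 1]  # class 1 is under-represented
indices = sample_indices(labels, num_samples=1000)
```

Each class's weights sum to 1, so class 1 makes up roughly half of the drawn indices despite being a quarter of the data.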
Detailed Troubleshooting Steps
Identify Memory Usage
from jax_dataloader.memory import MemoryTracker

tracker = MemoryTracker(verbose=True)
with tracker:
    loader = JAXDataLoader(dataset=data)
    for batch in loader:  # Iterate one epoch so batch allocations are tracked
        pass
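For a dependency-free cross-check of host-side Python allocations, the standard library's tracemalloc reports current and peak traced memory. It only sees allocations made through Python's allocators (and extensions that report to it), not GPU memory. load_epoch is a hypothetical stand-in for iterating a loader.

```python
import tracemalloc


def load_epoch(num_batches=50, batch_size=32, features=256):
    """Stand-in for `for batch in loader: ...`: builds and drops host batches."""
    seen = 0
    for _ in range(num_batches):
        batch = [[0.0] * features for _ in range(batch_size)]
        seen += len(batch)
    return seen


tracemalloc.start()
samples = load_epoch()
current, peak = tracemalloc.get_traced_memory()  # bytes: now, and the maximum so far
tracemalloc.stop()
print(f"samples={samples} current={current / 2**20:.2f} MiB peak={peak / 2**20:.2f} MiB")
```

A peak well above the post-epoch current value points at transient batch buffers (batch_size, prefetch_size); a current value that keeps rising across epochs points at a leak.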
Reduce Memory Footprint
Decrease batch size by 50%
Reduce number of workers
Disable prefetching
Enable memory mapping
Monitor Memory Growth
from jax_dataloader.memory import MemoryMonitor

monitor = MemoryMonitor(
    check_interval=100,
    threshold_mb=1024,
    alert=True,
)
Profile Data Loading
from jax_dataloader.profile import DataLoaderProfiler

profiler = DataLoaderProfiler(
    track_memory=True,
    track_performance=True,
    track_errors=True,
)
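A quick way to tell whether loading is the bottleneck, independent of DataLoaderProfiler: time how long each step waits for the next batch versus how long the step's compute takes. If waiting dominates, the configuration steps below apply. The sketch uses only time.perf_counter; slow_batches and the lambda train step are placeholders.

```python
import time


def profile_loader(batches, train_step):
    """Return seconds spent waiting for data vs. running train_step."""
    wait = compute = 0.0
    it = iter(batches)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent blocked on the loader
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    total = max(wait + compute, 1e-12)
    return {"wait_s": wait, "compute_s": compute, "wait_fraction": wait / total}


def slow_batches(n, delay=0.02):
    for i in range(n):
        time.sleep(delay)  # simulated I/O
        yield i


stats = profile_loader(slow_batches(5), train_step=lambda b: time.sleep(0.001))
```

JAX dispatches work asynchronously, so a real train_step should call .block_until_ready() on its output; otherwise compute time is under-reported and the loader looks slower than it is.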
Optimize Configuration
Increase workers gradually
Enable caching
Use memory mapping
Adjust prefetch size
Monitor System Resources
from jax_dataloader.monitor import SystemMonitor

monitor = SystemMonitor(
    track_cpu=True,
    track_memory=True,
    track_disk=True,
)
Check GPU Availability
from jax_dataloader.utils import get_available_devices

devices = get_available_devices()
if not devices:
    print("No GPU available")
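To confirm what JAX itself can see when the loader reports no GPU, jax.devices("gpu") lists GPU devices and raises RuntimeError when no GPU backend is available. The wrapper below is a sketch that also tolerates JAX not being installed; jax_gpu_devices is a hypothetical name.

```python
import importlib.util


def jax_gpu_devices():
    """Return JAX's GPU devices, or [] if JAX or a GPU backend is unavailable."""
    if importlib.util.find_spec("jax") is None:
        return []
    import jax

    try:
        return jax.devices("gpu")  # raises RuntimeError without a GPU backend
    except RuntimeError:
        return []


gpus = jax_gpu_devices()
print(f"{len(gpus)} GPU device(s) visible to JAX")
```

An empty list often means the CPU-only jax package is installed or the CUDA plugin failed to load; device='cpu' from the quick-fix table is the fallback.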
Monitor GPU Usage
from jax_dataloader.gpu import GPUMonitor

monitor = GPUMonitor(
    device='cuda:0',
    track_memory=True,
    track_utilization=True,
)
Optimize GPU Memory
Reduce batch size
Enable memory mapping
Use pinned memory
Limit GPU memory fraction
Performance Benchmarks
Small Dataset (10K samples)
| Configuration | Memory (MB) | Time/epoch (s) | GPU Util (%) |
|---|---|---|---|
| Basic | 256 | 1.2 | 45 |
| Memory-opt | 128 | 1.5 | 40 |
| Performance-opt | 512 | 0.8 | 65 |
| GPU-opt | 384 | 0.6 | 85 |
Medium Dataset (100K samples)
| Configuration | Memory (MB) | Time/epoch (s) | GPU Util (%) |
|---|---|---|---|
| Basic | 1024 | 12.5 | 50 |
| Memory-opt | 512 | 15.0 | 45 |
| Performance-opt | 2048 | 8.2 | 70 |
| GPU-opt | 1536 | 6.5 | 90 |
Large Dataset (1M samples)
| Configuration | Memory (MB) | Time/epoch (s) | GPU Util (%) |
|---|---|---|---|
| Basic | 4096 | 125.0 | 55 |
| Memory-opt | 2048 | 150.0 | 50 |
| Performance-opt | 8192 | 82.0 | 75 |
| GPU-opt | 6144 | 65.0 | 95 |
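The timing column is easier to compare as throughput. The snippet divides sample count by time per epoch using only the 1M-sample figures reported above; no new measurements are involved.

```python
# (samples, seconds per epoch) from the Large Dataset (1M samples) table.
results = {
    "Basic": (1_000_000, 125.0),
    "Memory-opt": (1_000_000, 150.0),
    "Performance-opt": (1_000_000, 82.0),
    "GPU-opt": (1_000_000, 65.0),
}

baseline = results["Basic"][1]
for name, (samples, seconds) in results.items():
    throughput = samples / seconds  # samples per second
    speedup = baseline / seconds    # relative to Basic
    print(f"{name:16s} {throughput:8.0f} samples/s  {speedup:.2f}x")
```

For the 1M-sample run, GPU-opt is about 1.9x faster than Basic (roughly 15,400 vs 8,000 samples/s), while Memory-opt gives up about 17% of throughput in exchange for half the memory.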
Basic Configuration
loader = JAXDataLoader(
    dataset=data,
    batch_size=32,
    num_workers=2,
    prefetch_size=1,
)
Memory-Optimized
loader = JAXDataLoader(
    dataset=data,
    batch_size=16,
    use_mmap=True,
    num_workers=1,
    prefetch_size=1,
)
Performance-Optimized
loader = JAXDataLoader(
    dataset=data,
    batch_size=64,
    num_workers=4,
    prefetch_size=2,
    cache_size=1000,
)
GPU-Optimized
loader = JAXDataLoader(
    dataset=data,
    batch_size=128,
    gpu_memory_fraction=0.9,
    use_pinned_memory=True,
    gpu_prefetch=True,
)
Test Environment
GPU: NVIDIA RTX 3080
CPU: AMD Ryzen 9 5900X
Memory: 32 GB DDR4
Storage: NVMe SSD
All times are averages over 10 epochs.
GPU utilization was measured with nvidia-smi.
Memory usage includes both CPU and GPU memory.