@@ -9,66 +9,82 @@ Data loading
99Functions
1010---------
1111
12- .. autofunction :: image_classification_tools.pytorch.data.make_data_loaders
12+ .. autofunction:: image_classification_tools.pytorch.data.load_datasets
13+
14+ .. autofunction:: image_classification_tools.pytorch.data.prepare_splits
15+
16+ .. autofunction:: image_classification_tools.pytorch.data.create_dataloaders
17+
18+ .. autofunction:: image_classification_tools.pytorch.data.generate_augmented_data
1319
1420Overview
1521--------
1622
17- The data module provides flexible data loading capabilities with support for:
23+ The data module provides a flexible three-step data loading workflow:
24+
25+ 1. **Load datasets**: Load train/test datasets from PyTorch dataset classes or directories
26+ 2. **Prepare splits**: Split data into train/val(/test) with configurable ratios
27+ 3. **Create dataloaders**: Create DataLoaders with optional memory preloading strategies
1828
19- * torchvision datasets (CIFAR-10, MNIST, ImageFolder, etc.)
20- * Custom train/eval transforms
21- * Configurable batch sizes
22- * Optional GPU preloading
23- * Automatic train/validation splitting (default 80/20)
29+ Key features:
30+
31+ * Support for torchvision datasets (CIFAR-10, MNIST, etc.) and custom ImageFolder datasets
32+ * Separate train and evaluation transforms
33+ * Flexible splitting: 2-way (train/val) or 3-way (train/val/test)
34+ * Three memory strategies: lazy loading, CPU preloading, or GPU preloading
35+ * Data augmentation with chunking for large datasets
36+ * Configurable batch sizes and workers
2437
2538Example usage
2639-------------
2740
28- MNIST dataset :
41+ Basic workflow (CIFAR-10 with GPU preloading):
2942
3043.. code-block :: python
3144
3245 from pathlib import Path
46+ import torch
3347 from torchvision import datasets, transforms
34- from image_classification_tools.pytorch.data import make_data_loaders
48+ from image_classification_tools.pytorch.data import (
49+ load_datasets, prepare_splits, create_dataloaders
50+ )
3551
52+ # Define transforms
3653 transform = transforms.Compose([
3754 transforms.ToTensor(),
38- transforms.Normalize((0.5 ,), (0.5 ,))
55+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
3956 ])
4057
41- train_loader, val_loader, test_loader = make_data_loaders(
42- data_dir = Path(' ./data' ),
43- dataset_class = datasets.MNIST ,
44- batch_size = 128 ,
58+ # Step 1: Load datasets
59+ train_dataset, test_dataset = load_datasets(
60+ data_source = datasets.CIFAR10 ,
4561 train_transform = transform,
46- eval_transform = transform
62+ eval_transform = transform,
63+ download = True ,
64+ root = Path('./data/cifar10')
4765 )
4866
49- CIFAR-10 dataset:
50-
51- .. code-block :: python
52-
53- from torchvision import datasets
54-
55- transform = transforms.Compose([
56- transforms.ToTensor(),
57- transforms.Normalize((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))
58- ])
67+ # Step 2: Prepare splits (2-way: train/val from train_dataset)
68+ train_dataset, val_dataset, test_dataset = prepare_splits(
69+ train_dataset = train_dataset,
70+ test_dataset = test_dataset,
71+ train_val_split = 0.8 # 80% train, 20% val
72+ )
5973
60- train_loader, val_loader, test_loader = make_data_loaders(
61- data_dir = Path(' ./data' ),
62- dataset_class = datasets.CIFAR10 ,
74+ # Step 3: Create dataloaders with GPU preloading
75+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76+ train_loader, val_loader, test_loader = create_dataloaders(
77+ train_dataset, val_dataset, test_dataset,
6378 batch_size = 128 ,
64- train_transform = transform ,
65- eval_transform = transform
79+ preload_to_memory = True ,
80+ device = device
6681 )
6782
68- With data augmentation:
83+ With data augmentation (lazy loading):
6984
7085.. code-block :: python
7186
87+ # Define separate transforms for training and evaluation
7288 train_transform = transforms.Compose([
7389 transforms.RandomHorizontalFlip(),
7490 transforms.RandomRotation(15 ),
@@ -81,25 +97,48 @@ With data augmentation:
8197 transforms.Normalize((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))
8298 ])
8399
84- train_loader, val_loader, test_loader = make_data_loaders(
85- data_dir = Path(' ./data' ),
86- dataset_class = datasets.CIFAR10 ,
87- batch_size = 128 ,
100+ # Load with different transforms
101+ train_dataset, test_dataset = load_datasets(
102+ data_source = datasets.CIFAR10 ,
88103 train_transform = train_transform,
89104 eval_transform = eval_transform,
90- device = None # Keep on CPU for on-the-fly augmentation
105+ root = Path('./data/cifar10')
91106 )
92107
93- Custom dataset with ImageFolder:
108+ # Prepare splits
109+ train_dataset, val_dataset, test_dataset = prepare_splits(
110+ train_dataset = train_dataset,
111+ test_dataset = test_dataset,
112+ train_val_split = 0.8
113+ )
94114
95- .. code-block :: python
115+ # Create dataloaders with lazy loading (no preloading)
116+ train_loader, val_loader, test_loader = create_dataloaders(
117+ train_dataset, val_dataset, test_dataset,
118+ batch_size = 128 ,
119+ preload_to_memory = False , # Lazy loading for augmentation
120+ num_workers = 4 ,
121+ pin_memory = True
122+ )
96123
97- from torchvision.datasets import ImageFolder
124+ 3-way split (no separate test set):
125+
126+ .. code-block:: python
98127
99- train_loader, val_loader, test_loader = make_data_loaders(
100- data_dir = Path(' ./my_dataset' ),
101- dataset_class = ImageFolder,
102- batch_size = 64 ,
128+ # Load only training data (no test set available)
129+ train_dataset, _ = load_datasets(
130+ data_source = datasets.ImageFolder,
103131 train_transform = transform,
104- eval_transform = transform
132+ eval_transform = transform,
133+ root = Path('./my_dataset/train')
105134 )
135+
136+ # 3-way split: train/val/test all from train_dataset
137+ train_dataset, val_dataset, test_dataset = prepare_splits(
138+ train_dataset = train_dataset,
139+ test_dataset = None , # Will split test from train_dataset
140+ train_val_split = 0.8 , # 80/20 split of remaining data
141+ test_split = 0.15 # Reserve 15% for testing
142+ )
143+ # Results in approximately: 68% train, 17% val, 15% test
144+
0 commit comments