Commit 6a4311c

Updated documentation to reflect new data loading workflow
1 parent a740db6 commit 6a4311c

4 files changed: +203 −75 lines changed

docs/source/api/data.rst

Lines changed: 83 additions & 44 deletions

@@ -9,66 +9,82 @@ Data loading
 Functions
 ---------
 
-.. autofunction:: image_classification_tools.pytorch.data.make_data_loaders
+.. autofunction:: image_classification_tools.pytorch.data.load_datasets
+
+.. autofunction:: image_classification_tools.pytorch.data.prepare_splits
+
+.. autofunction:: image_classification_tools.pytorch.data.create_dataloaders
+
+.. autofunction:: image_classification_tools.pytorch.data.generate_augmented_data
 
 Overview
 --------
 
-The data module provides flexible data loading capabilities with support for:
+The data module provides a flexible three-step data loading workflow:
+
+1. **Load datasets**: Load train/test datasets from PyTorch dataset classes or directories
+2. **Prepare splits**: Split data into train/val(/test) with configurable ratios
+3. **Create dataloaders**: Create DataLoaders with optional memory preloading strategies
 
-* torchvision datasets (CIFAR-10, MNIST, ImageFolder, etc.)
-* Custom train/eval transforms
-* Configurable batch sizes
-* Optional GPU preloading
-* Automatic train/validation splitting (default 80/20)
+Key features:
+
+* Support for torchvision datasets (CIFAR-10, MNIST, etc.) and custom ImageFolder datasets
+* Separate train and evaluation transforms
+* Flexible splitting: 2-way (train/val) or 3-way (train/val/test)
+* Three memory strategies: lazy loading, CPU preloading, or GPU preloading
+* Data augmentation with chunking for large datasets
+* Configurable batch sizes and workers
 
 Example usage
 -------------
 
-MNIST dataset:
+Basic workflow (CIFAR-10 with GPU preloading):
 
 .. code-block:: python
 
     from pathlib import Path
+    import torch
     from torchvision import datasets, transforms
-    from image_classification_tools.pytorch.data import make_data_loaders
+    from image_classification_tools.pytorch.data import (
+        load_datasets, prepare_splits, create_dataloaders
+    )
 
+    # Define transforms
     transform = transforms.Compose([
         transforms.ToTensor(),
-        transforms.Normalize((0.5,), (0.5,))
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./data'),
-        dataset_class=datasets.MNIST,
-        batch_size=128,
+    # Step 1: Load datasets
+    train_dataset, test_dataset = load_datasets(
+        data_source=datasets.CIFAR10,
         train_transform=transform,
-        eval_transform=transform
+        eval_transform=transform,
+        download=True,
+        root=Path('./data/cifar10')
     )
 
-CIFAR-10 dataset:
-
-.. code-block:: python
-
-    from torchvision import datasets
-
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
+    # Step 2: Prepare splits (2-way: train/val from train_dataset)
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        train_val_split=0.8  # 80% train, 20% val
+    )
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./data'),
-        dataset_class=datasets.CIFAR10,
+    # Step 3: Create dataloaders with GPU preloading
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
         batch_size=128,
-        train_transform=transform,
-        eval_transform=transform
+        preload_to_memory=True,
+        device=device
     )
 
-With data augmentation:
+With data augmentation (lazy loading):
 
 .. code-block:: python
 
+    # Define separate transforms for training and evaluation
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(15),
@@ -81,25 +97,48 @@ With data augmentation:
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./data'),
-        dataset_class=datasets.CIFAR10,
-        batch_size=128,
+    # Load with different transforms
+    train_dataset, test_dataset = load_datasets(
+        data_source=datasets.CIFAR10,
         train_transform=train_transform,
         eval_transform=eval_transform,
-        device=None  # Keep on CPU for on-the-fly augmentation
+        root=Path('./data/cifar10')
     )
 
-Custom dataset with ImageFolder:
+    # Prepare splits
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        train_val_split=0.8
+    )
 
-.. code-block:: python
+    # Create dataloaders with lazy loading (no preloading)
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
+        batch_size=128,
+        preload_to_memory=False,  # Lazy loading for augmentation
+        num_workers=4,
+        pin_memory=True
    )
 
-    from torchvision.datasets import ImageFolder
+3-way split (no separate test set):
+
+.. code-block:: python
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./my_dataset'),
-        dataset_class=ImageFolder,
-        batch_size=64,
+    # Load only training data (no test set available)
+    train_dataset, _ = load_datasets(
+        data_source=datasets.ImageFolder,
         train_transform=transform,
-        eval_transform=transform
+        eval_transform=transform,
+        root=Path('./my_dataset/train')
     )
+
+    # 3-way split: train/val/test all from train_dataset
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset=train_dataset,
+        test_dataset=None,     # Will split test from train_dataset
+        train_val_split=0.8,   # 80/20 split of remaining data
+        test_split=0.15        # Reserve 15% for testing
+    )
+    # Results in approximately: 68% train, 17% val, 15% test
+
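The percentages in the final comment follow directly from the two ratios: reserving 15% for test leaves 85%, and the 80/20 ``train_val_split`` is then applied to that remainder. A minimal sketch of the arithmetic, assuming ``prepare_splits`` applies ``test_split`` before ``train_val_split``, as the comments above indicate:

.. code-block:: python

    # Check the 3-way split arithmetic from the example above.
    test_split = 0.15
    train_val_split = 0.8

    remainder = 1.0 - test_split                   # 0.85 left after the test cut
    train_frac = remainder * train_val_split       # 0.68 -> ~68% train
    val_frac = remainder * (1 - train_val_split)   # 0.17 -> ~17% val

    print(f"train={train_frac:.0%}, val={val_frac:.0%}, test={test_split:.0%}")
    # train=68%, val=17%, test=15%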

docs/source/index.rst

Lines changed: 22 additions & 6 deletions

@@ -38,8 +38,11 @@ Minimal example classifying MNIST digits:
 .. code-block:: python
 
     import torch
+    from pathlib import Path
     from torchvision import datasets, transforms
-    from image_classification_tools.pytorch.data import make_data_loaders
+    from image_classification_tools.pytorch.data import (
+        load_datasets, prepare_splits, create_dataloaders
+    )
     from image_classification_tools.pytorch.training import train_model
 
     # Load data
@@ -48,12 +51,25 @@ Minimal example classifying MNIST digits:
         transforms.Normalize((0.5,), (0.5,))
     ])
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir='./data',
-        dataset_class=datasets.MNIST,
-        batch_size=64,
+    # Load, split, and create dataloaders
+    train_dataset, test_dataset = load_datasets(
+        data_source=datasets.MNIST,
         train_transform=transform,
-        eval_transform=transform
+        eval_transform=transform,
+        download=True,
+        root=Path('./data/mnist')
+    )
+
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset, test_dataset, train_val_split=0.8
+    )
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
+        batch_size=64,
+        preload_to_memory=True,
+        device=device
     )
 
     # Define model
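With ``preload_to_memory=True`` on a CUDA machine, every batch is already resident on the GPU before training begins, so the loop inside ``train_model`` needn't move data. A quick sanity check that can be appended to the example above (assuming, as the examples suggest, that the loaders yield ``(inputs, targets)`` pairs; the diff doesn't state this explicitly):

.. code-block:: python

    # Inspect one batch to confirm where the preloaded data lives.
    inputs, targets = next(iter(train_loader))
    print(inputs.device, targets.device)  # expect cuda:0 on a CUDA machine, else cpu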

docs/source/quickstart.rst

Lines changed: 72 additions & 18 deletions

@@ -26,22 +26,39 @@ This example shows the complete workflow using the MNIST dataset.
     from pathlib import Path
     import torch
     from torchvision import datasets, transforms
-    from image_classification_tools.pytorch.data import make_data_loaders
+    from image_classification_tools.pytorch.data import (
+        load_datasets, prepare_splits, create_dataloaders
+    )
 
     # Define preprocessing
     transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))
     ])
 
-    # Create data loaders
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./data'),
-        dataset_class=datasets.MNIST,
-        batch_size=128,
+    # Step 1: Load datasets
+    train_dataset, test_dataset = load_datasets(
+        data_source=datasets.MNIST,
         train_transform=transform,
         eval_transform=transform,
-        device='cuda' if torch.cuda.is_available() else 'cpu'
+        download=True,
+        root=Path('./data/mnist')
+    )
+
+    # Step 2: Prepare splits
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        train_val_split=0.8
+    )
+
+    # Step 3: Create dataloaders
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
+        batch_size=128,
+        preload_to_memory=True,
+        device=device
     )
 
 2. Define model
@@ -83,7 +100,7 @@ This example shows the complete workflow using the MNIST dataset.
         criterion=criterion,
         optimizer=optimizer,
         device=device,
-        lazy_loading=False,  # Data already on device from make_data_loaders
+        lazy_loading=False,  # Set to False when using preload_to_memory=True
         epochs=20,
         print_every=5
     )
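The updated comment encodes a coupling between the two APIs: ``preload_to_memory`` (set when building the loaders) and ``lazy_loading`` (set when training) should agree. A hedged sketch of what a training loop presumably does per batch in each mode; ``run_epoch`` is a hypothetical stand-in, since ``train_model``'s internals are not part of this diff:

.. code-block:: python

    def run_epoch(loader, model, criterion, optimizer, device, lazy_loading):
        """Hypothetical per-batch handling showing why the two flags must agree."""
        model.train()
        for inputs, targets in loader:
            if lazy_loading:
                # preload_to_memory=False: batches arrive on the CPU and are
                # moved to the device one at a time.
                inputs, targets = inputs.to(device), targets.to(device)
            # With preload_to_memory=True the batches already live on `device`,
            # so lazy_loading=False skips the redundant transfer.
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()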
@@ -124,16 +141,40 @@ For datasets in ImageFolder format:
 
 .. code-block:: python
 
+    from pathlib import Path
     from torchvision.datasets import ImageFolder
 
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=Path('./my_dataset'),
-        dataset_class=ImageFolder,
-        batch_size=64,
+    # Define transform
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    # Load datasets from directory structure
+    train_dataset, test_dataset = load_datasets(
+        data_source=Path('./my_dataset'),
         train_transform=transform,
         eval_transform=transform
     )
 
+    # If no test directory exists, use 3-way split
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,  # Will be None if no test/ directory
+        train_val_split=0.8,
+        test_split=0.1  # Only used if test_dataset is None
+    )
+
+    # Create dataloaders
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
+        batch_size=64,
+        preload_to_memory=False,  # Lazy loading for large datasets
+        num_workers=4
+    )
+
 Your directory structure should be:
 
 .. code-block:: text
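The tree itself sits outside this hunk, but given ``data_source=Path('./my_dataset')`` and the ``test/``-directory comment above, the expected layout presumably follows the standard torchvision ``ImageFolder`` convention of one subdirectory per class, along these lines:

.. code-block:: text

    my_dataset/
    ├── train/
    │   ├── class_a/
    │   │   ├── img001.jpg
    │   │   └── img002.jpg
    │   └── class_b/
    │       └── ...
    └── test/              # optional; omit it and set test_split instead
        ├── class_a/
        └── class_b/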
@@ -195,13 +236,26 @@ Improve generalization with data augmentation:
         transforms.Normalize((0.5,), (0.5,))
     ])
 
-    # Use different transforms for training and evaluation
-    train_loader, val_loader, test_loader = make_data_loaders(
-        data_dir=data_dir,
-        dataset_class=datasets.MNIST,
-        batch_size=128,
+    # Load with separate transforms
+    train_dataset, test_dataset = load_datasets(
+        data_source=datasets.MNIST,
         train_transform=train_transform,
-        eval_transform=eval_transform
+        eval_transform=eval_transform,
+        root=Path('./data/mnist')
+    )
+
+    # Prepare splits
+    train_dataset, val_dataset, test_dataset = prepare_splits(
+        train_dataset, test_dataset, train_val_split=0.8
+    )
+
+    # Create dataloaders with lazy loading (important for augmentation)
+    train_loader, val_loader, test_loader = create_dataloaders(
+        train_dataset, val_dataset, test_dataset,
+        batch_size=128,
+        preload_to_memory=False,  # Use lazy loading for on-the-fly augmentation
+        num_workers=4,
+        pin_memory=True
     )
 
 Hyperparameter optimization
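Lazy loading matters here because random transforms are re-applied every time a sample is fetched, so each epoch sees fresh variants; preloading to memory would instead freeze a single augmented copy. A standalone illustration with plain torchvision, independent of this package (assuming torchvision ≥ 0.8, where transforms accept tensor images):

.. code-block:: python

    import torch
    from torchvision import transforms

    augment = transforms.RandomRotation(15)  # same transform as the example above

    img = torch.rand(1, 28, 28)  # dummy single-channel image
    a, b = augment(img), augment(img)
    print(torch.allclose(a, b))  # almost always False: a fresh random angle per call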
