
Commit e6d6718

Switched to new dataloader and added more hyperparameters to optimization
1 parent 242e0f3 commit e6d6718

File tree

3 files changed: +130, -222 lines changed


data/pytorch/cnn_optimization.db

0 Bytes
Binary file not shown.

notebooks/04-optimized-CNN.ipynb

Lines changed: 56 additions & 190 deletions
Large diffs are not rendered by default.

src/cifar10_tools/pytorch/hyperparameter_optimization.py

Lines changed: 74 additions & 32 deletions
@@ -12,13 +12,17 @@
 import torch.optim as optim
 from torch.utils.data import DataLoader

+from cifar10_tools.pytorch.data import make_data_loaders
+

 def create_cnn(
     n_conv_blocks: int,
     initial_filters: int,
-    fc_units_1: int,
-    fc_units_2: int,
-    dropout_rate: float,
+    n_fc_layers: int,
+    base_kernel_size: int,
+    conv_dropout_rate: float,
+    fc_dropout_rate: float,
+    pooling_strategy: str,
     use_batch_norm: bool,
     num_classes: int = 10,
     in_channels: int = 3,
@@ -29,9 +33,11 @@ def create_cnn(
     Args:
         n_conv_blocks: Number of convolutional blocks (1-5)
         initial_filters: Number of filters in first conv layer (doubles each block)
-        fc_units_1: Number of units in first fully connected layer
-        fc_units_2: Number of units in second fully connected layer
-        dropout_rate: Dropout probability
+        n_fc_layers: Number of fully connected layers (1-8)
+        base_kernel_size: Base kernel size (decreases by 2 per block, min 3)
+        conv_dropout_rate: Dropout probability after convolutional blocks
+        fc_dropout_rate: Dropout probability in fully connected layers
+        pooling_strategy: Pooling type ('max' or 'avg')
         use_batch_norm: Whether to use batch normalization
         num_classes: Number of output classes (default: 10 for CIFAR-10)
         in_channels: Number of input channels (default: 3 for RGB)
@@ -44,28 +50,35 @@ def create_cnn(
     current_channels = in_channels
     current_size = input_size

+    # Convolutional blocks
     for block_idx in range(n_conv_blocks):
         out_channels = initial_filters * (2 ** block_idx)
+        kernel_size = max(3, base_kernel_size - 2 * block_idx)
+        padding = kernel_size // 2

         # First conv in block
-        layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1))
+        layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=kernel_size, padding=padding))

         if use_batch_norm:
             layers.append(nn.BatchNorm2d(out_channels))

         layers.append(nn.ReLU())

         # Second conv in block
-        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
+        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding))

         if use_batch_norm:
             layers.append(nn.BatchNorm2d(out_channels))

         layers.append(nn.ReLU())

-        # Pooling and dropout
-        layers.append(nn.MaxPool2d(2, 2))
-        layers.append(nn.Dropout(dropout_rate))
+        # Pooling
+        if pooling_strategy == 'max':
+            layers.append(nn.MaxPool2d(2, 2))
+        else:  # avg
+            layers.append(nn.AvgPool2d(2, 2))
+
+        layers.append(nn.Dropout(conv_dropout_rate))

         current_channels = out_channels
         current_size //= 2
@@ -74,15 +87,26 @@ def create_cnn(
     final_channels = initial_filters * (2 ** (n_conv_blocks - 1))
     flattened_size = final_channels * current_size * current_size

-    # Classifier (3 fully connected layers)
+    # Classifier - dynamic FC layers with halving pattern
     layers.append(nn.Flatten())
-    layers.append(nn.Linear(flattened_size, fc_units_1))
-    layers.append(nn.ReLU())
-    layers.append(nn.Dropout(dropout_rate))
-    layers.append(nn.Linear(fc_units_1, fc_units_2))
-    layers.append(nn.ReLU())
-    layers.append(nn.Dropout(dropout_rate))
-    layers.append(nn.Linear(fc_units_2, num_classes))
+
+    # Generate FC layer sizes by halving from flattened_size
+    fc_sizes = []
+    current_fc_size = flattened_size // 2
+    for _ in range(n_fc_layers):
+        fc_sizes.append(max(10, current_fc_size))  # Minimum 10 units
+        current_fc_size //= 2
+
+    # Add FC layers
+    in_features = flattened_size
+    for fc_size in fc_sizes:
+        layers.append(nn.Linear(in_features, fc_size))
+        layers.append(nn.ReLU())
+        layers.append(nn.Dropout(fc_dropout_rate))
+        in_features = fc_size
+
+    # Output layer
+    layers.append(nn.Linear(in_features, num_classes))

     return nn.Sequential(*layers)

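To see how the new architecture hyperparameters interact, here is a minimal standalone sketch that mirrors the formulas in the diff above (kernel sizes shrinking by 2 per block, FC widths halving from the flattened size). The sample values are hypothetical, not tuned results from this commit:

# Standalone sketch; values below are illustrative assumptions, the formulas mirror the diff above.
n_conv_blocks = 3
initial_filters = 32
base_kernel_size = 7
n_fc_layers = 4
input_size = 32  # CIFAR-10 images are 32x32

current_size = input_size
for block_idx in range(n_conv_blocks):
    out_channels = initial_filters * (2 ** block_idx)               # 32, 64, 128
    kernel_size = max(3, base_kernel_size - 2 * block_idx)          # 7, 5, 3
    current_size //= 2                                               # 2x2 pooling: 16, 8, 4
    print(f"block {block_idx}: {out_channels} filters, kernel {kernel_size}, "
          f"feature map {current_size}x{current_size}")

flattened_size = initial_filters * (2 ** (n_conv_blocks - 1)) * current_size * current_size  # 128 * 4 * 4 = 2048
fc_sizes = []
current_fc_size = flattened_size // 2
for _ in range(n_fc_layers):
    fc_sizes.append(max(10, current_fc_size))
    current_fc_size //= 2
print(fc_sizes)  # [1024, 512, 256, 128]; the output layer then maps 128 -> num_classes

With these sample values the classifier head becomes 2048 -> 1024 -> 512 -> 256 -> 128 -> 10.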
@@ -113,6 +137,7 @@ def train_trial(
     best_val_accuracy = 0.0

     for epoch in range(n_epochs):
+
         # Training phase
         model.train()

@@ -149,21 +174,23 @@ def train_trial(


 def create_objective(
-    train_loader: DataLoader,
-    val_loader: DataLoader,
+    data_dir,
+    train_transform,
+    eval_transform,
     n_epochs: int,
     device: torch.device,
     num_classes: int = 10,
     in_channels: int = 3
 ) -> Callable[[optuna.Trial], float]:
     '''Create an Optuna objective function for CNN hyperparameter optimization.

-    This factory function creates a closure that captures the data loaders and
-    training configuration, returning an objective function suitable for Optuna.
+    This factory function creates a closure that captures the data loading parameters
+    and training configuration, returning an objective function suitable for Optuna.

     Args:
-        train_loader: DataLoader for training data
-        val_loader: DataLoader for validation data
+        data_dir: Directory containing CIFAR-10 data
+        train_transform: Transform to apply to training data
+        eval_transform: Transform to apply to validation data
         n_epochs: Number of epochs per trial
        device: Device to train on (cuda or cpu)
        num_classes: Number of output classes (default: 10)
@@ -173,7 +200,7 @@ def create_objective(
         Objective function for optuna.Study.optimize()

     Example:
-        >>> objective = create_objective(train_loader, val_loader, n_epochs=50, device=device)
+        >>> objective = create_objective(data_dir, transform, transform, n_epochs=50, device=device)
         >>> study = optuna.create_study(direction='maximize')
         >>> study.optimize(objective, n_trials=100)
     '''
@@ -182,22 +209,37 @@ def objective(trial: optuna.Trial) -> float:
         '''Optuna objective function for CNN hyperparameter optimization.'''

         # Suggest hyperparameters
+        batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512, 1024])
         n_conv_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
         initial_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
-        fc_units_1 = trial.suggest_categorical('fc_units_1', [128, 256, 512, 1024, 2048])
-        fc_units_2 = trial.suggest_categorical('fc_units_2', [32, 64, 128, 256, 512])
-        dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.75)
+        n_fc_layers = trial.suggest_int('n_fc_layers', 1, 8)
+        base_kernel_size = trial.suggest_int('base_kernel_size', 3, 7)
+        conv_dropout_rate = trial.suggest_float('conv_dropout_rate', 0.0, 0.5)
+        fc_dropout_rate = trial.suggest_float('fc_dropout_rate', 0.2, 0.75)
+        pooling_strategy = trial.suggest_categorical('pooling_strategy', ['max', 'avg'])
         use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
         learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
         optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])

+        # Create data loaders with suggested batch size
+        train_loader, val_loader, _ = make_data_loaders(
+            data_dir=data_dir,
+            batch_size=batch_size,
+            train_transform=train_transform,
+            eval_transform=eval_transform,
+            device=device,
+            download=False
+        )
+
         # Create model
         model = create_cnn(
             n_conv_blocks=n_conv_blocks,
             initial_filters=initial_filters,
-            fc_units_1=fc_units_1,
-            fc_units_2=fc_units_2,
-            dropout_rate=dropout_rate,
+            n_fc_layers=n_fc_layers,
+            base_kernel_size=base_kernel_size,
+            conv_dropout_rate=conv_dropout_rate,
+            fc_dropout_rate=fc_dropout_rate,
+            pooling_strategy=pooling_strategy,
             use_batch_norm=use_batch_norm,
             num_classes=num_classes,
             in_channels=in_channels
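Downstream use of the new create_objective signature would look roughly like the sketch below. The torchvision transforms, the data directory, and the SQLite storage URL (suggested by the data/pytorch/cnn_optimization.db file added in this commit) are illustrative assumptions, not code taken from the notebook:

import torch
import optuna
from torchvision import transforms

from cifar10_tools.pytorch.hyperparameter_optimization import create_objective

# Illustrative transforms; the actual notebook may use different augmentation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
eval_transform = transforms.ToTensor()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Each trial now builds its own DataLoaders via make_data_loaders, so only the
# data directory and transforms are passed in; batch_size is tuned per trial.
objective = create_objective(
    data_dir='data/pytorch',          # assumed CIFAR-10 data location
    train_transform=train_transform,
    eval_transform=eval_transform,
    n_epochs=20,
    device=device,
)

# The committed cnn_optimization.db is consistent with an SQLite-backed study;
# the exact URL is an assumption.
study = optuna.create_study(
    study_name='cnn_optimization',
    storage='sqlite:///data/pytorch/cnn_optimization.db',
    direction='maximize',
    load_if_exists=True,
)
study.optimize(objective, n_trials=50)
print(study.best_params)

Because the data loaders are created inside each trial, batch_size can be optimized alongside the architecture and optimizer settings, which is the main point of this commit.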

0 commit comments
