 import torch.optim as optim
 from torch.utils.data import DataLoader
 
+from cifar10_tools.pytorch.data import make_data_loaders
+
 
 def create_cnn(
     n_conv_blocks: int,
     initial_filters: int,
-    fc_units_1: int,
-    fc_units_2: int,
-    dropout_rate: float,
+    n_fc_layers: int,
+    base_kernel_size: int,
+    conv_dropout_rate: float,
+    fc_dropout_rate: float,
+    pooling_strategy: str,
     use_batch_norm: bool,
     num_classes: int = 10,
     in_channels: int = 3,
@@ -29,9 +33,11 @@ def create_cnn(
     Args:
         n_conv_blocks: Number of convolutional blocks (1-5)
         initial_filters: Number of filters in first conv layer (doubles each block)
-        fc_units_1: Number of units in first fully connected layer
-        fc_units_2: Number of units in second fully connected layer
-        dropout_rate: Dropout probability
+        n_fc_layers: Number of fully connected layers (1-8)
+        base_kernel_size: Base kernel size, odd (decreases by 2 per block, min 3)
+        conv_dropout_rate: Dropout probability after convolutional blocks
+        fc_dropout_rate: Dropout probability in fully connected layers
+        pooling_strategy: Pooling type ('max' or 'avg')
         use_batch_norm: Whether to use batch normalization
         num_classes: Number of output classes (default: 10 for CIFAR-10)
         in_channels: Number of input channels (default: 3 for RGB)
@@ -44,28 +50,35 @@ def create_cnn(
     current_channels = in_channels
     current_size = input_size
 
+    # Convolutional blocks
     for block_idx in range(n_conv_blocks):
         out_channels = initial_filters * (2 ** block_idx)
+        kernel_size = max(3, base_kernel_size - 2 * block_idx)
+        padding = kernel_size // 2
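+        # kernel_size // 2 is "same" padding for odd kernels, so the spatial
+        # size only changes at pooling and the current_size bookkeeping holds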
 
         # First conv in block
-        layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1))
+        layers.append(nn.Conv2d(current_channels, out_channels, kernel_size=kernel_size, padding=padding))
 
         if use_batch_norm:
             layers.append(nn.BatchNorm2d(out_channels))
 
         layers.append(nn.ReLU())
 
         # Second conv in block
-        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
+        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding))
 
         if use_batch_norm:
             layers.append(nn.BatchNorm2d(out_channels))
 
         layers.append(nn.ReLU())
 
-        # Pooling and dropout
-        layers.append(nn.MaxPool2d(2, 2))
-        layers.append(nn.Dropout(dropout_rate))
+        # Pooling
+        if pooling_strategy == 'max':
+            layers.append(nn.MaxPool2d(2, 2))
+        else:  # avg
+            layers.append(nn.AvgPool2d(2, 2))
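+        # both poolings use kernel 2, stride 2, so either choice halves the
+        # spatial size and keeps current_size //= 2 below exact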
+
+        layers.append(nn.Dropout(conv_dropout_rate))
 
         current_channels = out_channels
         current_size //= 2
@@ -74,15 +87,26 @@ def create_cnn(
     final_channels = initial_filters * (2 ** (n_conv_blocks - 1))
     flattened_size = final_channels * current_size * current_size
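+    # for 32x32 CIFAR-10 inputs, current_size here is input_size // 2**n_conv_blocks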
 
-    # Classifier (3 fully connected layers)
+    # Classifier - dynamic FC layers with halving pattern
     layers.append(nn.Flatten())
-    layers.append(nn.Linear(flattened_size, fc_units_1))
-    layers.append(nn.ReLU())
-    layers.append(nn.Dropout(dropout_rate))
-    layers.append(nn.Linear(fc_units_1, fc_units_2))
-    layers.append(nn.ReLU())
-    layers.append(nn.Dropout(dropout_rate))
-    layers.append(nn.Linear(fc_units_2, num_classes))
+
+    # Generate FC layer sizes by halving from flattened_size
+    fc_sizes = []
+    current_fc_size = flattened_size // 2
+    for _ in range(n_fc_layers):
+        fc_sizes.append(max(10, current_fc_size))  # Minimum 10 units
+        current_fc_size //= 2
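+    # once current_fc_size drops below 10 it keeps halving unclamped, so every
+    # remaining layer in fc_sizes is pinned at 10 units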
+
+    # Add FC layers
+    in_features = flattened_size
+    for fc_size in fc_sizes:
+        layers.append(nn.Linear(in_features, fc_size))
+        layers.append(nn.ReLU())
+        layers.append(nn.Dropout(fc_dropout_rate))
+        in_features = fc_size
+
+    # Output layer
+    layers.append(nn.Linear(in_features, num_classes))
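+    # no softmax on the head; this assumes the training loss (e.g.
+    # nn.CrossEntropyLoss) takes raw logits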
 
     return nn.Sequential(*layers)
 
@@ -113,6 +137,7 @@ def train_trial(
     best_val_accuracy = 0.0
 
     for epoch in range(n_epochs):
+
         # Training phase
         model.train()
 
@@ -149,21 +174,23 @@ def train_trial(
 
 
 def create_objective(
-    train_loader: DataLoader,
-    val_loader: DataLoader,
+    data_dir,
+    train_transform,
+    eval_transform,
     n_epochs: int,
     device: torch.device,
     num_classes: int = 10,
     in_channels: int = 3
 ) -> Callable[[optuna.Trial], float]:
     '''Create an Optuna objective function for CNN hyperparameter optimization.
 
-    This factory function creates a closure that captures the data loaders and
-    training configuration, returning an objective function suitable for Optuna.
+    This factory function creates a closure that captures the data loading parameters
+    and training configuration, returning an objective function suitable for Optuna.
 
     Args:
-        train_loader: DataLoader for training data
-        val_loader: DataLoader for validation data
+        data_dir: Directory containing CIFAR-10 data
+        train_transform: Transform to apply to training data
+        eval_transform: Transform to apply to validation data
         n_epochs: Number of epochs per trial
         device: Device to train on (cuda or cpu)
         num_classes: Number of output classes (default: 10)
@@ -173,7 +200,7 @@ def create_objective(
         Objective function for optuna.Study.optimize()
 
     Example:
-        >>> objective = create_objective(train_loader, val_loader, n_epochs=50, device=device)
+        >>> objective = create_objective(data_dir, transform, transform, n_epochs=50, device=device)
         >>> study = optuna.create_study(direction='maximize')
         >>> study.optimize(objective, n_trials=100)
     '''
@@ -182,22 +209,37 @@ def objective(trial: optuna.Trial) -> float:
         '''Optuna objective function for CNN hyperparameter optimization.'''
 
         # Suggest hyperparameters
+        batch_size = trial.suggest_categorical('batch_size', [64, 128, 256, 512, 1024])
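+        # batch size is now a tuned hyperparameter, so the loaders must be
+        # rebuilt inside each trial rather than passed in from outside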
         n_conv_blocks = trial.suggest_int('n_conv_blocks', 1, 5)
         initial_filters = trial.suggest_categorical('initial_filters', [8, 16, 32, 64, 128])
-        fc_units_1 = trial.suggest_categorical('fc_units_1', [128, 256, 512, 1024, 2048])
-        fc_units_2 = trial.suggest_categorical('fc_units_2', [32, 64, 128, 256, 512])
-        dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.75)
+        n_fc_layers = trial.suggest_int('n_fc_layers', 1, 8)
+        # step=2 keeps kernels odd; even kernels with kernel_size // 2 padding
+        # grow the feature map and break create_cnn's flattened-size bookkeeping
+        base_kernel_size = trial.suggest_int('base_kernel_size', 3, 7, step=2)
+        conv_dropout_rate = trial.suggest_float('conv_dropout_rate', 0.0, 0.5)
+        fc_dropout_rate = trial.suggest_float('fc_dropout_rate', 0.2, 0.75)
+        pooling_strategy = trial.suggest_categorical('pooling_strategy', ['max', 'avg'])
         use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
         learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
         optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])
 
+        # Create data loaders with suggested batch size
+        train_loader, val_loader, _ = make_data_loaders(
+            data_dir=data_dir,
+            batch_size=batch_size,
+            train_transform=train_transform,
+            eval_transform=eval_transform,
+            device=device,
+            download=False
+        )
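+        # download=False assumes the CIFAR-10 files already exist under
+        # data_dir, presumably fetched once in an earlier setup step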
+
         # Create model
         model = create_cnn(
             n_conv_blocks=n_conv_blocks,
             initial_filters=initial_filters,
-            fc_units_1=fc_units_1,
-            fc_units_2=fc_units_2,
-            dropout_rate=dropout_rate,
+            n_fc_layers=n_fc_layers,
+            base_kernel_size=base_kernel_size,
+            conv_dropout_rate=conv_dropout_rate,
+            fc_dropout_rate=fc_dropout_rate,
+            pooling_strategy=pooling_strategy,
             use_batch_norm=use_batch_norm,
             num_classes=num_classes,
             in_channels=in_channels