@@ -10,8 +10,8 @@ const int64_t kNoiseSize = 100;
 // The batch size for training.
 const int64_t kBatchSize = 64;
 
-// The number of epochs to train.
-const int64_t kNumberOfEpochs = 30;
+// The default number of epochs to train.
+int64_t kNumberOfEpochs = 30;
 
 // Where to find the MNIST dataset.
 const char* kDataFolder = "./data";
@@ -75,7 +75,39 @@ struct DCGANGeneratorImpl : nn::Module {
 
 TORCH_MODULE(DCGANGenerator);
 
+nn::Sequential create_discriminator() {
+  return nn::Sequential(
+      // Layer 1
+      nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 2
+      nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
+      nn::BatchNorm2d(128),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 3
+      nn::Conv2d(
+          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
+      nn::BatchNorm2d(256),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 4
+      nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
+      nn::Sigmoid());
+}
+
 int main(int argc, const char* argv[]) {
+
+  if (argc > 1) {
+    std::string arg = argv[1];
+    if (std::all_of(arg.begin(), arg.end(), ::isdigit)) {
+      try {
+        kNumberOfEpochs = std::stoll(arg);
+      } catch (const std::exception&) {
+        // If the value cannot be parsed (empty or out of range), keep the default.
+      }
+    }
+  }
+  std::cout << "Training with number of epochs: " << kNumberOfEpochs << std::endl;
+
   torch::manual_seed(1);
 
   // Create the device we pass around based on whether CUDA is available.
@@ -88,33 +120,15 @@ int main(int argc, const char* argv[]) {
   DCGANGenerator generator(kNoiseSize);
   generator->to(device);
 
-  nn::Sequential discriminator(
-      // Layer 1
-      nn::Conv2d(
-          nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 2
-      nn::Conv2d(
-          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
-      nn::BatchNorm2d(128),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 3
-      nn::Conv2d(
-          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
-      nn::BatchNorm2d(256),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 4
-      nn::Conv2d(
-          nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
-      nn::Sigmoid());
+  nn::Sequential discriminator = create_discriminator();
   discriminator->to(device);
 
   // Assume the MNIST dataset is available under `kDataFolder`;
   auto dataset = torch::data::datasets::MNIST(kDataFolder)
                      .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                      .map(torch::data::transforms::Stack<>());
-  const int64_t batches_per_epoch =
-      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
+  const int64_t batches_per_epoch = static_cast<int64_t>(
+      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize)));
 
   auto data_loader = torch::data::make_data_loader(
       std::move(dataset),
@@ -136,7 +150,7 @@ int main(int argc, const char* argv[]) {
   int64_t checkpoint_counter = 1;
   for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
     int64_t batch_index = 0;
-    for (torch::data::Example<>& batch : *data_loader) {
+    for (const torch::data::Example<>& batch : *data_loader) {
       // Train discriminator with real images.
       discriminator->zero_grad();
       torch::Tensor real_images = batch.data.to(device);
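Taken together, the patch makes kNumberOfEpochs a mutable global that an optional first command-line argument overrides. Below is a minimal standalone sketch of that override scheme, not part of the patch itself; the file name is illustrative, and it swaps the plain ::isdigit call for a lambda over unsigned char, which avoids undefined behavior for negative char values.

// epochs_arg.cpp - standalone sketch of the epoch-override scheme above.
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <string>

int64_t kNumberOfEpochs = 30;  // default, mirroring the patched constant

int main(int argc, const char* argv[]) {
  if (argc > 1) {
    std::string arg = argv[1];
    // Accept only non-empty, purely numeric arguments.
    if (!arg.empty() &&
        std::all_of(arg.begin(), arg.end(),
                    [](unsigned char c) { return std::isdigit(c) != 0; })) {
      try {
        kNumberOfEpochs = std::stoll(arg);
      } catch (const std::exception&) {
        // Out-of-range input falls back to the default.
      }
    }
  }
  std::cout << "Training with number of epochs: " << kNumberOfEpochs << std::endl;
}

Invoked as, say, ./epochs_arg 50, this trains for 50 epochs; with no argument or a non-numeric one, the default of 30 is kept.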