Skip to content

Commit a11fe63

Browse files
committed
Change the cpp/dcgan
1 parent b88d805 commit a11fe63

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

cpp/dcgan/dcgan.cpp

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ const int64_t kNoiseSize = 100;
1010
// The batch size for training.
1111
const int64_t kBatchSize = 64;
1212

13-
// The number of epochs to train.
14-
const int64_t kNumberOfEpochs = 30;
13+
// The default number of epochs to train.
14+
int64_t kNumberOfEpochs = 30;
1515

1616
// Where to find the MNIST dataset.
1717
const char* kDataFolder = "./data";
@@ -75,7 +75,39 @@ struct DCGANGeneratorImpl : nn::Module {
7575

7676
TORCH_MODULE(DCGANGenerator);
7777

78+
nn::Sequential create_discriminator() {
79+
return nn::Sequential(
80+
// Layer 1
81+
nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
82+
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
83+
// Layer 2
84+
nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
85+
nn::BatchNorm2d(128),
86+
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
87+
// Layer 3
88+
nn::Conv2d(
89+
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
90+
nn::BatchNorm2d(256),
91+
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
92+
// Layer 4
93+
nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
94+
nn::Sigmoid());
95+
}
96+
7897
int main(int argc, const char* argv[]) {
98+
99+
if (argc > 1) {
100+
std::string arg = argv[1];
101+
if (std::all_of(arg.begin(), arg.end(), ::isdigit)) {
102+
try {
103+
kNumberOfEpochs = std::stoll(arg);
104+
} catch (const std::invalid_argument& ia) {
105+
// If unable to parse, do nothing and keep the default value
106+
}
107+
}
108+
}
109+
std::cout << "Traning with number of epochs: " << kNumberOfEpochs << std::endl;
110+
79111
torch::manual_seed(1);
80112

81113
// Create the device we pass around based on whether CUDA is available.
@@ -88,33 +120,15 @@ int main(int argc, const char* argv[]) {
88120
DCGANGenerator generator(kNoiseSize);
89121
generator->to(device);
90122

91-
nn::Sequential discriminator(
92-
// Layer 1
93-
nn::Conv2d(
94-
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
95-
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
96-
// Layer 2
97-
nn::Conv2d(
98-
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
99-
nn::BatchNorm2d(128),
100-
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
101-
// Layer 3
102-
nn::Conv2d(
103-
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
104-
nn::BatchNorm2d(256),
105-
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
106-
// Layer 4
107-
nn::Conv2d(
108-
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
109-
nn::Sigmoid());
123+
nn::Sequential discriminator = create_discriminator();
110124
discriminator->to(device);
111125

112126
// Assume the MNIST dataset is available under `kDataFolder`;
113127
auto dataset = torch::data::datasets::MNIST(kDataFolder)
114128
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
115129
.map(torch::data::transforms::Stack<>());
116-
const int64_t batches_per_epoch =
117-
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
130+
const int64_t batches_per_epoch = static_cast<int64_t>(
131+
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize)));
118132

119133
auto data_loader = torch::data::make_data_loader(
120134
std::move(dataset),
@@ -136,7 +150,7 @@ int main(int argc, const char* argv[]) {
136150
int64_t checkpoint_counter = 1;
137151
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
138152
int64_t batch_index = 0;
139-
for (torch::data::Example<>& batch : *data_loader) {
153+
for (const torch::data::Example<>& batch : *data_loader) {
140154
// Train discriminator with real images.
141155
discriminator->zero_grad();
142156
torch::Tensor real_images = batch.data.to(device);

run_cpp_examples.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function dcgan() {
102102
make
103103
if [ $? -eq 0 ]; then
104104
echo "Successfully built $EXAMPLE"
105-
./$EXAMPLE # Run the executable
105+
./$EXAMPLE 5 # Run the executable with kNumberOfEpochs = 5
106106
check_run_success $EXAMPLE
107107
else
108108
error "Failed to build $EXAMPLE"

0 commit comments

Comments
 (0)