From a11fe63885ae32cc22c60af8526e5baf7d23712b Mon Sep 17 00:00:00 2001 From: lancer Date: Wed, 24 Jan 2024 11:32:14 -0800 Subject: [PATCH] Change the cpp/dcgan --- cpp/dcgan/dcgan.cpp | 62 +++++++++++++++++++++++++++------------------ run_cpp_examples.sh | 2 +- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/cpp/dcgan/dcgan.cpp b/cpp/dcgan/dcgan.cpp index acd1bca781..88b2b4dca2 100644 --- a/cpp/dcgan/dcgan.cpp +++ b/cpp/dcgan/dcgan.cpp @@ -10,8 +10,8 @@ const int64_t kNoiseSize = 100; // The batch size for training. const int64_t kBatchSize = 64; -// The number of epochs to train. -const int64_t kNumberOfEpochs = 30; +// The default number of epochs to train. +int64_t kNumberOfEpochs = 30; // Where to find the MNIST dataset. const char* kDataFolder = "./data"; @@ -75,7 +75,39 @@ struct DCGANGeneratorImpl : nn::Module { TORCH_MODULE(DCGANGenerator); +nn::Sequential create_discriminator() { + return nn::Sequential( + // Layer 1 + nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), + // Layer 2 + nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)), + nn::BatchNorm2d(128), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), + // Layer 3 + nn::Conv2d( + nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)), + nn::BatchNorm2d(256), + nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), + // Layer 4 + nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)), + nn::Sigmoid()); +} + int main(int argc, const char* argv[]) { + + if (argc > 1) { + std::string arg = argv[1]; + if (std::all_of(arg.begin(), arg.end(), ::isdigit)) { + try { + kNumberOfEpochs = std::stoll(arg); + } catch (const std::invalid_argument& ia) { + // If unable to parse, do nothing and keep the default value + } + } + } + std::cout << "Traning with number of epochs: " << kNumberOfEpochs << std::endl; + torch::manual_seed(1); // Create the device we pass around based on whether CUDA is available. @@ -88,33 +120,15 @@ int main(int argc, const char* argv[]) { DCGANGenerator generator(kNoiseSize); generator->to(device); - nn::Sequential discriminator( - // Layer 1 - nn::Conv2d( - nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)), - nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), - // Layer 2 - nn::Conv2d( - nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)), - nn::BatchNorm2d(128), - nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), - // Layer 3 - nn::Conv2d( - nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)), - nn::BatchNorm2d(256), - nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)), - // Layer 4 - nn::Conv2d( - nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)), - nn::Sigmoid()); + nn::Sequential discriminator = create_discriminator(); discriminator->to(device); // Assume the MNIST dataset is available under `kDataFolder`; auto dataset = torch::data::datasets::MNIST(kDataFolder) .map(torch::data::transforms::Normalize<>(0.5, 0.5)) .map(torch::data::transforms::Stack<>()); - const int64_t batches_per_epoch = - std::ceil(dataset.size().value() / static_cast(kBatchSize)); + const int64_t batches_per_epoch = static_cast( + std::ceil(dataset.size().value() / static_cast(kBatchSize))); auto data_loader = torch::data::make_data_loader( std::move(dataset), @@ -136,7 +150,7 @@ int main(int argc, const char* argv[]) { int64_t checkpoint_counter = 1; for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) { int64_t batch_index = 0; - for (torch::data::Example<>& batch : *data_loader) { + for (const torch::data::Example<>& batch : *data_loader) { // Train discriminator with real images. discriminator->zero_grad(); torch::Tensor real_images = batch.data.to(device); diff --git a/run_cpp_examples.sh b/run_cpp_examples.sh index 5dfb07343a..57da57aa01 100644 --- a/run_cpp_examples.sh +++ b/run_cpp_examples.sh @@ -102,7 +102,7 @@ function dcgan() { make if [ $? -eq 0 ]; then echo "Successfully built $EXAMPLE" - ./$EXAMPLE # Run the executable + ./$EXAMPLE 5 # Run the executable with kNumberOfEpochs = 5 check_run_success $EXAMPLE else error "Failed to build $EXAMPLE"