Skip to content

Commit

Permalink
Change the cpp/dcgan
Browse files Browse the repository at this point in the history
  • Loading branch information
lancerts committed Jan 24, 2024
1 parent b88d805 commit a11fe63
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
62 changes: 38 additions & 24 deletions cpp/dcgan/dcgan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const int64_t kNoiseSize = 100;
// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;
// The default number of epochs to train.
int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data";
Expand Down Expand Up @@ -75,7 +75,39 @@ struct DCGANGeneratorImpl : nn::Module {

TORCH_MODULE(DCGANGenerator);

nn::Sequential create_discriminator() {
return nn::Sequential(
// Layer 1
nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
}

int main(int argc, const char* argv[]) {

if (argc > 1) {
std::string arg = argv[1];
if (std::all_of(arg.begin(), arg.end(), ::isdigit)) {
try {
kNumberOfEpochs = std::stoll(arg);
} catch (const std::invalid_argument& ia) {
// If unable to parse, do nothing and keep the default value
}
}
}
std::cout << "Traning with number of epochs: " << kNumberOfEpochs << std::endl;

torch::manual_seed(1);

// Create the device we pass around based on whether CUDA is available.
Expand All @@ -88,33 +120,15 @@ int main(int argc, const char* argv[]) {
DCGANGenerator generator(kNoiseSize);
generator->to(device);

nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
nn::Sequential discriminator = create_discriminator();
discriminator->to(device);

// Assume the MNIST dataset is available under `kDataFolder`;
auto dataset = torch::data::datasets::MNIST(kDataFolder)
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
const int64_t batches_per_epoch = static_cast<int64_t>(
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize)));

auto data_loader = torch::data::make_data_loader(
std::move(dataset),
Expand All @@ -136,7 +150,7 @@ int main(int argc, const char* argv[]) {
int64_t checkpoint_counter = 1;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
for (const torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data.to(device);
Expand Down
2 changes: 1 addition & 1 deletion run_cpp_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function dcgan() {
make
if [ $? -eq 0 ]; then
echo "Successfully built $EXAMPLE"
./$EXAMPLE # Run the executable
./$EXAMPLE 5 # Run the executable with kNumberOfEpochs = 5
check_run_success $EXAMPLE
else
error "Failed to build $EXAMPLE"
Expand Down

0 comments on commit a11fe63

Please sign in to comment.