From 246c2706c1c58fb8c5147a344486e0f4edc1d51d Mon Sep 17 00:00:00 2001 From: lancer Date: Wed, 24 Jan 2024 18:14:17 -0800 Subject: [PATCH] Use an open source argparse --- .github/workflows/main_cpp.yml | 9 ++++++++- cpp/dcgan/dcgan.cpp | 27 +++++++++++++++------------ run_cpp_examples.sh | 2 +- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/.github/workflows/main_cpp.yml b/.github/workflows/main_cpp.yml index 75cf01be11..97a8def958 100644 --- a/.github/workflows/main_cpp.yml +++ b/.github/workflows/main_cpp.yml @@ -31,7 +31,14 @@ jobs: run: | sudo apt -y install libtbb-dev sudo apt install libopencv-dev - + - name: Install argparse + run: | + git clone https://github.com/p-ranav/argparse + cd argparse + mkdir build + cd build + cmake -DARGPARSE_BUILD_SAMPLES=off -DARGPARSE_BUILD_TESTS=off .. + sudo make install # Alternatively, you can install OpenCV from source # - name: Install OpenCV from source # run: | diff --git a/cpp/dcgan/dcgan.cpp b/cpp/dcgan/dcgan.cpp index 88b2b4dca2..a1232d2a9f 100644 --- a/cpp/dcgan/dcgan.cpp +++ b/cpp/dcgan/dcgan.cpp @@ -1,5 +1,5 @@ #include - +#include #include #include #include @@ -95,18 +95,21 @@ nn::Sequential create_discriminator() { } int main(int argc, const char* argv[]) { - - if (argc > 1) { - std::string arg = argv[1]; - if (std::all_of(arg.begin(), arg.end(), ::isdigit)) { - try { - kNumberOfEpochs = std::stoll(arg); - } catch (const std::invalid_argument& ia) { - // If unable to parse, do nothing and keep the default value - } - } + argparse::ArgumentParser parser("cpp/dcgan example"); + parser.add_argument("--epochs") + .help("Number of epochs to train") + .default_value(kNumberOfEpochs) + .scan<'i', int64_t>(); + try { + parser.parse_args(argc, argv); + } catch (const std::exception& err) { + std::cout << err.what() << std::endl; + std::cout << parser; + std::exit(1); } - std::cout << "Traning with number of epochs: " << kNumberOfEpochs << std::endl; + kNumberOfEpochs = parser.get("--epochs"); + std::cout << "Traning with number of epochs: " << kNumberOfEpochs + << std::endl; torch::manual_seed(1); diff --git a/run_cpp_examples.sh b/run_cpp_examples.sh index 57da57aa01..e1a912ff0d 100644 --- a/run_cpp_examples.sh +++ b/run_cpp_examples.sh @@ -102,7 +102,7 @@ function dcgan() { make if [ $? -eq 0 ]; then echo "Successfully built $EXAMPLE" - ./$EXAMPLE 5 # Run the executable with kNumberOfEpochs = 5 + ./$EXAMPLE --epochs 5 # Run the executable with kNumberOfEpochs = 5 check_run_success $EXAMPLE else error "Failed to build $EXAMPLE"