diff --git a/.github/workflows/eera.yml b/.github/workflows/eera.yml index dc6a8a4e..a4ed106f 100644 --- a/.github/workflows/eera.yml +++ b/.github/workflows/eera.yml @@ -23,6 +23,11 @@ jobs: steps: - uses: actions/checkout@v2 + - uses: actions/checkout@v2 + with: + repository: ScottishCovidResponse/data_pipeline_api + path: data_pipeline_api + - name: Install Dependencies ( Ubuntu ) run: | sudo apt-get update @@ -39,6 +44,16 @@ jobs: run : brew update && brew install gsl cppcheck lcov poppler htmldoc graphviz doxygen if: matrix.os == 'macos-latest' + - name: Data pipeline Dependencies + run : | + sudo apt-get update && sudo apt-get install -y python3-setuptools python3-venv && sudo rm -rf /var/lib/apt/lists/* + python3 -m venv .venv + source .venv/bin/activate + pip install wheel + pip install -r data_pipeline_api/bindings/cpp/requirements.txt + pip install data_pipeline_api + if: matrix.os == 'ubuntu-20.04' + - name: Format Code ( Ubuntu GCC Master ) run: | git config --local user.email "action@github.com" @@ -56,18 +71,33 @@ jobs: force: false if: matrix.os == 'ubuntu-20.04' && matrix.config.compiler == 'gcc' && github.ref == 'refs/heads/master' + - name: Data pipeline API Compile + env: + CC: ${{ matrix.config.compiler }} + CXX: ${{ matrix.config.compilerpp }} + run : | + source .venv/bin/activate + cd data_pipeline_api/bindings/cpp + cmake -H. -Bbuild + cmake --build build + if: matrix.os == 'ubuntu-20.04' + - name: Compile env: CC: ${{ matrix.config.compiler }} CXX: ${{ matrix.config.compilerpp }} run: | + source .venv/bin/activate mkdir build cd build - cmake .. -DCODE_COVERAGE=ON -DCLANG_TIDY=ON + cmake .. -DCODE_COVERAGE=ON -DCLANG_TIDY=ON -DDATA_PIPELINE=$GITHUB_WORKSPACE/data_pipeline_api make 2>&1 | tee clang_tidy_build_results.log - name: Run regression tests + env: + PYTHONPATH: data_pipeline_api run: | + source .venv/bin/activate ./scripts/RunRegressionTests.sh 1 24 if [ $? 
-eq 0 ]; then echo "Regression tests completed successfully" @@ -78,8 +108,12 @@ jobs: fi - name: Run unit tests + env: + PYTHONPATH: data_pipeline_api run: | - ./build/bin/Covid19EERAModel-unit_tests + source .venv/bin/activate + cd build + ./bin/Covid19EERAModel-unit_tests if [ $? -eq 0 ]; then echo "Unit tests completed successfully" exit 0 diff --git a/.gitignore b/.gitignore index f1ea0ba9..aae9d76a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ site/* index.html src/tclap/ .DS_Store - +test/datapipeline/access-*.yaml diff --git a/CMakeLists.txt b/CMakeLists.txt index d6eb59e2..015bbbb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ set(PROJECT_NAME Covid19EERAModel) project(${PROJECT_NAME} VERSION 0.10.0 LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake_modules) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_FLAGS "-DROOT_DIR=\\\"${CMAKE_SOURCE_DIR}\\\" -DVERSION=\\\"${PROJECT_VERSION}\\\" ") diff --git a/README.md b/README.md index 94181d26..24067a14 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,6 @@ The model requires a number of input files to run, in addition to the command li | scot_age.csv | Proportion of health board populations in each age group | All | | scot_data.csv | Timeseries of observed disease cases, by health board | Inference only| | scot_deaths.csv | Timeseries of observed disease deaths, by health board | Inference only | -| scot_frail.csv | Probability of frailty, by age group | All | | waifw_home.csv | Age Mixing Matrix (Home)| All | | waifw_norm.csv | Age Mixing Matrix (All contact included)| All | | waifw_sdist.csv | Age Mixing Matrix (Social Distancing)| All | @@ -178,9 +177,6 @@ CSV file containing the proportion of people in each age group, per health board #### scot_data.csv, scot_deaths.csv CSV file containing the timeseries of cases and deaths, per health board. 
Each row corresponds to a different health board, while ach column is a day in the time series. The first column is the toal population of the health board. -#### scot_frail.csv -CSV file containing the probabilities of frailty for each age group, by health board. Each column is an age group. Each row is a health board, with the exception of the last row, which is for the whole of Scotland. - #### waifw_home.csv, waifw_norm.csv, waifw_sdist.csv CSV files containing the age mixing matrices for people (1) isolating at home, (2) behaving normally, and (3) socially distancing. @@ -193,6 +189,68 @@ Index,p_inf,p_hcw,c_hcw,d,q,p_s,rrd,intro, T_lat, juvp_s, T_inf, T_rec, T_sym, T ``` Each row in the file contains 17 entries: the first is the index of the row; the following 8 are the inferred posterior parameters; and the remaining 8 are model fixed parameters. The row selected for use in the prediction run will be that specified by the index argument on the command line (see Prediction Mode discussion below). +### Input - Data pipeline + +The intention with the data pipeline is to obtain relevant input data from a shared remote source and to return any results similarly to a shared remote destination. The workflow involves a distinct download stage of the data before running the model and also an upload step after the model has completed. + +The download is carried out using the `pipeline_download` script that is supplied with the [Data Pipeline API](git@github.com:ScottishCovidResponse/data_pipeline_api.git). See instruction in that repository for setting up. + +To action the download, a config `.yaml` file must be supplied similar to this: + +``` +pipeline_download --config /config.yaml +``` + +For this model, the following elements are expected to be available via the data pipeline download. 
An example `config.yaml` file is located in `test/datapipeline/config.yaml` in the git repository: + +| Local item | Data pipeline | +| ------------- |:-------------:| +| \[Fixed Parameters\], T\_lat | "fixed-parameters/T\_lat", "T\_lat" | +| \[Fixed Parameters\], juvp\_s | "fixed-parameters/juvp\_s", "juvp\_s" | +| \[Fixed Parameters\], T\_inf | "fixed-parameters/T\_inf", "T\_inf" | +| \[Fixed Parameters\], T\_rec | "fixed-parameters/T\_rec", "T\_rec" | +| \[Fixed Parameters\], T\_sym | "fixed-parameters/T\_sym", "T\_sym" | +| \[Fixed Parameters\], T\_hos | "fixed-parameters/T\_hos", "T\_hos" | +| \[Fixed Parameters\], K | "fixed-parameters/K", "K" | +| \[Fixed Parameters\], inf\_asym | "fixed-parameters/inf\_asym", "inf\_asym" | +| \[Fixed Parameters\], totN\_hcw | "fixed-parameters/total\_hcw", "total\_hcw" | +| \[Fixed Parameters\], day\_shut | "fixed-parameters/day\_shut", "day\_shut" | +| \[Priors Settings\], prior\_pinf\_shape1 | "prior-distributions/pinf", "pinf", "alpha" | +| \[Priors Settings\], prior\_pinf\_shape2 | "prior-distributions/pinf", "pinf", "beta" | +| \[Priors Settings\], prior\_phcw\_shape1 | "prior-distributions/phcw", "phcw", "alpha" | +| \[Priors Settings\], prior\_phcw\_shape2 | "prior-distributions/phcw", "phcw", "beta" | +| \[Priors Settings\], prior\_chcw\_mean | "prior-distributions/chcw", "chcw", "lambda" | +| \[Priors Settings\], prior\_d\_shape1 | "prior-distributions/d", "d", "alpha" | +| \[Priors Settings\], prior\_d\_shape2 | "prior-distributions/d", "d", "beta" | +| \[Priors Settings\], prior\_q\_shape1 | "prior-distributions/q", "q", "alpha" | +| \[Priors Settings\], prior\_q\_shape2 | "prior-distributions/q", "q", "beta" | +| \[Priors Settings\], prior\_lambda\_shape1 | "prior-distributions/lambda", "lambda", "a" | +| \[Priors Settings\], prior\_lambda\_shape2 | "prior-distributions/lambda", "lambda", "b" | +| \[Priors Settings\], prior\_ps\_shape1 | "prior-distributions/ps", "ps", "alpha" | +| \[Priors Settings\], 
prior\_ps\_shape2 | "prior-distributions/ps", "ps", "beta" | +| \[Priors Settings\], prior\_rrd\_shape1 | "prior-distributions/rrd", "rrd", "k" | +| \[Priors Settings\], prior\_rrd\_shape2 | "prior-distributions/rrd", "rrd", "theta" | +| scot\_data.csv | "population-data/data\_for\_scotland", "data" | +| scot\_age.csv | "population-data/data\_for\_scotland", "age" | +| scot\_deaths.csv | "population-data/data\_for\_scotland", "deaths" | +| waifw\_norm.csv | "contact-data/who\_acquired\_infection\_from\_whom", "norm" | +| waifw\_home.csv | "contact-data/who\_acquired\_infection\_from\_whom", "home" | +| waifw\_sdist.csv | "contact-data/who\_acquired\_infection\_from\_whom", "sdist" | +| cfr\_byage.csv | "prob\_hosp\_and\_cfr/data\_for\_scotland", "cfr\_byage" | +| posterior\_parameters.csv | "posterior\_parameters/data\_for\_scotland", "posterior\_parameters" | + +Once the data has been successfully downloaded the model may be run as specified above but with the addition of the `-c` option indicating to use the data pipeline for the above elements instead of local files. + +This requires visibility of the `data_pipeline_api` for Python. If it has been installed via `pip` or `conda` this will already be the case, if the API has been cloned only then `PYTHONPATH` needs amending for this: +``` +$ export PYTHONPATH=/data_pipeline_api:$PYTHONPATH +``` +The command is then: +``` +$ build/bin/Covid19EERAModel -s original -m inference -c /config.yaml +``` +Once completed, results should be uploaded, which is TBD. + ### Prediction mode The model can be run in a prediction mode, where a fixed set of parameters is supplied to the model, and the model is run for a fixed number of simulation steps. @@ -268,6 +326,17 @@ The regression test script automatically configures each run in line with the ta The default option uses local data to perform the run. The addition of a "-d" flag will switch the regression test to use the data pipeline locally stored test data instead. 
+This requires visibility of the `data_pipeline_api` for Python. If it has been installed via `pip` or `conda` this will already be the case, if the API has been cloned only then `PYTHONPATH` needs amending for this: + +``` +$ export PYTHONPATH=/data_pipeline_api:$PYTHONPATH +``` +Then run as follows: + +``` +$ ./scripts/RunRegressionTests.sh 4 9 -d +``` + **Note:** The regression tests are an aid to refactoring with confidence: they should not be considered confirmation of the code's correctness. The reference outputs are updated periodically based on changes in the core model logic. ### Unit tests diff --git a/cmake/git_watcher.cmake b/cmake/git_watcher.cmake index 92673e5a..f0ebc3e3 100644 --- a/cmake/git_watcher.cmake +++ b/cmake/git_watcher.cmake @@ -88,6 +88,7 @@ set(_state_variable_names # >>> # 1. Add the name of the additional git variable you're interested in monitoring # to this list. + GIT_REMOTE_ORIGIN_URL ) @@ -165,7 +166,10 @@ function(GetGitState _working_dir) # "execute_process()" command. Be sure to set them in # the environment using the same variable name you added # to the "_state_variable_names" list.
- + RunGitCommand(config --get remote.origin.url) + if(exit_code EQUAL 0) + set(ENV{GIT_REMOTE_ORIGIN_URL} "${output}") + endif() endfunction() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c6b5193..bc09c709 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,7 +6,7 @@ find_package(GSL REQUIRED) set(PRE_CONFIGURE_FILE "Git.cpp.in") set(POST_CONFIGURE_FILE "${CMAKE_CURRENT_BINARY_DIR}/Git.cpp") include(${CMAKE_SOURCE_DIR}/cmake/git_watcher.cmake) -add_library(git SHARED ${POST_CONFIGURE_FILE}) +add_library(git STATIC ${POST_CONFIGURE_FILE}) target_include_directories(git PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) add_dependencies(git check_git) @@ -16,7 +16,7 @@ configure_file(${PRECONFIGURE_DEPENDENCY_FILE} ${POSTCONFIGURE_DEPENDENCY_FILE} list(APPEND src_files ${POSTCONFIGURE_DEPENDENCY_FILE}) set (PROJECT_LIBS ${PROJECT_NAME}-lib) -add_library(${PROJECT_LIBS} SHARED ${src_files}) +add_library(${PROJECT_LIBS} STATIC ${src_files}) target_link_libraries(${PROJECT_LIBS} PUBLIC GSL::gsl GSL::gslcblas) target_include_directories(${PROJECT_LIBS} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) target_link_libraries(${PROJECT_LIBS} PUBLIC ${TCLAP}) diff --git a/src/Git.cpp.in b/src/Git.cpp.in index 0c448ab6..a1a3a676 100644 --- a/src/Git.cpp.in +++ b/src/Git.cpp.in @@ -44,4 +44,7 @@ std::string GitMetadata::CommitDate() { } std::string GitMetadata::Tag() { return "@GIT_TAG@"; +} +std::string GitMetadata::URL() { + return "@GIT_REMOTE_ORIGIN_URL@"; } \ No newline at end of file diff --git a/src/Git.h b/src/Git.h index e4ea000b..cd059b44 100644 --- a/src/Git.h +++ b/src/Git.h @@ -47,4 +47,6 @@ class GitMetadata { static std::string CommitDate(); // The commit tag static std::string Tag(); + // The Remote Origin URL + static std::string URL(); }; diff --git a/src/IO-datapipeline.cpp b/src/IO-datapipeline.cpp index 3b416429..4f3793d0 100644 --- a/src/IO-datapipeline.cpp +++ b/src/IO-datapipeline.cpp @@ -13,21 +13,28 @@ namespace EERAModel { namespace IO { 
-IOdatapipeline::IOdatapipeline(string params_path, string dpconfig_path) +IOdatapipeline::IOdatapipeline(string params_path, string model_config, Utilities::logging_stream::Sptr log_stream, string dpconfig_path) { ParamsPath = params_path; + ModelConfigDir = model_config; // Shouldn't need this if data pipeline active + log = log_stream; + datapipelineActive = false; + // std::cout << "ParamsPath: " << ParamsPath << "\n"; + // std::cout << "ModelConfig: " << ModelConfigDir << "\n"; + // std::cout << "ConfigPath: " << dpconfig_path << "\n"; + if (dpconfig_path != "") { - const char *uri = "https://whatever"; // I'm guessing this is the repo for the model - const char *git_sha = "git_sha"; // And this the version ID, need to find these both + std::string uri = GitMetadata::URL(); // "https://whatever"; I'm guessing this is the repo for the model + std::string git_sha = GitMetadata::CommitSHA1(); // And this the version ID, need to find these both - std::cout << "ParamsPath: " << ParamsPath << "\n"; - std::cout << "ConfigPath: " << dpconfig_path << "\n"; + // std::cout << "URI: " << uri << "\n"; + // std::cout << "SHA: " << git_sha << "\n"; // If something goes wrong in the opening of the data pipeline an exception will get thrown - dp.reset(new DataPipeline(dpconfig_path, uri, git_sha)); + dp.reset(new DataPipeline(dpconfig_path, uri.c_str(), git_sha.c_str())); datapipelineActive = true; } } @@ -38,13 +45,19 @@ CommonModelInputParameters IOdatapipeline::ReadCommonParameters() if (datapipelineActive) { + (*log) << " From data pipeline" << std::endl; + commonParameters.paramlist = ReadFixedModelParameters(); commonParameters.totN_hcw = dp->read_estimate("fixed-parameters/total_hcw", "total_hcw"); + commonParameters.day_shut = dp->read_estimate("fixed-parameters/day_shut", "day_shut"); } else { + (*log) << " From local parameters.ini" << std::endl; + commonParameters.paramlist = IO::ReadFixedModelParameters(ParamsPath); commonParameters.totN_hcw = 
ReadNumberFromFile("totN_hcw", "Fixed parameters", ParamsPath); + commonParameters.day_shut = ReadNumberFromFile("day_shut", "Fixed parameters", ParamsPath); } commonParameters.herd_id = ReadNumberFromFile("shb_id", "Settings", ParamsPath); @@ -54,7 +67,6 @@ CommonModelInputParameters IOdatapipeline::ReadCommonParameters() params IOdatapipeline::ReadFixedModelParameters() { - std::cout << "(Datapipeline): ReadFixedModelParameters\n"; params paramlist; paramlist.T_lat = dp->read_estimate("fixed-parameters/T_lat", "T_lat"); @@ -69,6 +81,240 @@ params IOdatapipeline::ReadFixedModelParameters() return paramlist; } +ObservationsForModels IOdatapipeline::ReadModelObservations() +{ + if (!datapipelineActive) + { + return IO::ReadModelObservations(ModelConfigDir, log); + } + else + { + ValidationParameters validationParameters = ImportValidationParameters(ParamsPath); + int nHealthBoards = validationParameters.nHealthBoards; + int nAgeGroups = validationParameters.nAgeGroups; + int nCfrCategories = validationParameters.nCfrCategories; + int nCasesDays = validationParameters.nCasesDays; + + ObservationsForModels observations; + + //Uploading observed disease data + //Note: first vector is the vector of time. value of -1 indicate number of pigs in the herd + //rows from 1 are indivudual health board + //last row is for all of scotland + dparray_to_csv( + "population-data/data_for_scotland", "data", &observations.cases, + nHealthBoards, nCasesDays); + + // TODO: this is indexed by herd_id, and the data file has a titles row that the data pipeline doesn't + // so need to do something about that. Fix this properly, but for now... 
+ + observations.cases.insert(observations.cases.begin(), std::vector(observations.cases[0].size())); + + //Uploading population per age group + //columns are for each individual Health Borad + //last column is for Scotland + //rows are for each age group: [0] Under20,[1] 20-29,[2] 30-39,[3] 40-49,[4] 50-59,[5] 60-69,[6] Over70,[7] HCW + dparray_to_csv( + "population-data/data_for_scotland", "age", &observations.age_pop, + nHealthBoards, nAgeGroups - 1); + + //mean number of daily contacts per age group (overall) + dparray_to_csv( + "contact-data/who_acquired_infection_from_whom", "norm", &observations.waifw_norm, + nAgeGroups, nAgeGroups); + + //mean number of daily contacts per age group (home only) + dparray_to_csv( + "contact-data/who_acquired_infection_from_whom", "home", &observations.waifw_home, + nAgeGroups, nAgeGroups); + + //mean number of daily contacts per age group (not school, not work) + dparray_to_csv( + "contact-data/who_acquired_infection_from_whom", "sdist", &observations.waifw_sdist, + nAgeGroups, nAgeGroups); + + //Upload cfr by age group + //col0: p_h: probability of hospitalisation + //col1: cfr: case fatality ratio + //col2: p_d: probability of death, given hospitalisation + //rows are for each age group: [0] Under20,[1] 20-29,[2] 30-39,[3] 40-49,[4] 50-59,[5] 60-69,[6] Over70,[7] HCW + + // TODO: (nCfrCategories - 1) is correct currently for the data pipeline vs local files, but is too arbitrary + // and really the parameters.ini should be set for the data pipeline if used. + dptable_to_csv( + "prob_hosp_and_cfr/data_for_scotland", "cfr_byage", &observations.cfr_byage, + nAgeGroups, -1 /* nCfrCategories - 1 */ ); + + // TODO: THIS IS A BODGE TO MAKE THE REGRESSION TESTS PASS. DON'T USE THIS! + // Fix third column. 
+ + for (int i = 0; i < observations.cfr_byage.size(); ++i) { + std::stringstream str; + str << std::scientific << std::setprecision(14); + str << observations.cfr_byage[i][2]; + observations.cfr_byage[i][2] = atof(str.str().c_str()); + } + + return observations; + } +} + +InferenceConfig IOdatapipeline::ReadInferenceConfig( + const CommonModelInputParameters& commonParameters, const ObservationsForModels& modelObservations) +{ + if (!datapipelineActive) + { + return IO::ReadInferenceConfig(ModelConfigDir, log, commonParameters); + } + else + { + InferenceConfig inferenceConfig; + + IO::ReadLocalInferenceConfig(ParamsPath, log, commonParameters, &inferenceConfig); + + dpdistribution( + "prior-distributions/pinf", "pinf", + "alpha", &inferenceConfig.prior_pinf_shape1, "beta", &inferenceConfig.prior_pinf_shape2); + + dpdistribution( + "prior-distributions/phcw", "phcw", + "alpha", &inferenceConfig.prior_phcw_shape1, "beta", &inferenceConfig.prior_phcw_shape2); + + dpdistribution( + "prior-distributions/chcw", "chcw", + "lambda", &inferenceConfig.prior_chcw_mean); + + dpdistribution( + "prior-distributions/d", "d", + "alpha", &inferenceConfig.prior_d_shape1, "beta", &inferenceConfig.prior_d_shape2); + + dpdistribution( + "prior-distributions/q", "q", + "alpha", &inferenceConfig.prior_q_shape1, "beta", &inferenceConfig.prior_q_shape2); + + dpdistribution( + "prior-distributions/lambda", "lambda", + "a", &inferenceConfig.prior_lambda_shape1, "b", &inferenceConfig.prior_lambda_shape2); + + dpdistribution( + "prior-distributions/ps", "ps", + "alpha", &inferenceConfig.prior_ps_shape1, "beta", &inferenceConfig.prior_ps_shape2); + + dpdistribution( + "prior-distributions/rrd", "rrd", + "k", &inferenceConfig.prior_rrd_shape1, "theta", &inferenceConfig.prior_rrd_shape2); + + inferenceConfig.observations = ReadInferenceObservations(modelObservations); + + return inferenceConfig; + } +} + +ObservationsForInference IOdatapipeline::ReadInferenceObservations(const 
ObservationsForModels& modelObservations) +{ + ValidationParameters validationParameters = ImportValidationParameters(ParamsPath); + int nHealthBoards = validationParameters.nHealthBoards; + int nCasesDays = validationParameters.nCasesDays; + + ObservationsForInference observations; + + (*log) << "Observations For Inference Config:" << std::endl; + + //Uploading observed disease data + //Note: first vector is the vector of time. value of -1 indicate number of pigs in the herd + //rows from 1 are indivudual health board + //last row is for all of scotland + (*log) << "\t- (data pipeline) copying cases from model observations" << std::endl; + observations.cases = modelObservations.cases; + + //Uploading observed death data + //Note: first vector is the vector of time. value of -1 indicate number of pigs in the herd + //rows from 1 are indivudual health board + //last row is for all of scotland + dparray_to_csv( + "population-data/data_for_scotland", "deaths", &observations.deaths, + nHealthBoards, nCasesDays); + + observations.deaths.insert(observations.deaths.begin(), std::vector(observations.deaths[0].size())); + + return observations; +} + +PredictionConfig IOdatapipeline::ReadPredictionConfig(int index, const CommonModelInputParameters& commonParameters) +{ + if (!datapipelineActive) + { + return IO::ReadPredictionConfig(ModelConfigDir, index, log, commonParameters); + } + else + { + PredictionConfig predictionConfig; + + IO::ReadLocalPredictionConfig(ParamsPath, index, log, commonParameters, &predictionConfig); + + const char *posterior_product = "posterior_parameters/data_for_scotland"; + const char *posterior_component = "posterior_parameters"; + + (*log) << "\t- (data pipeline) \"" << posterior_product << "\", \"" << posterior_component << "\"" << std::endl; + + Table posterior_table = dp->read_table(posterior_product, posterior_component); + + ImportConsistencyCheck( + posterior_product, posterior_component, posterior_table.get_column_names().size(), 17, 
"columns"); + + if (index >= posterior_table.get_column_size()) { + std::stringstream SetSelectError; + SetSelectError << "Parameter set selection out of bounds! Please select between 0-" << (posterior_table.get_column_size() - 1) << "..." << std::endl; + throw std::overflow_error(SetSelectError.str()); + } + + // The posterior parameters are columns 1 to 8 + const char *posterior_colnames[] = { "p_inf", "p_hcw", "c_hcw", "d", "q", "p_s", "rrd", "intro" }; + + predictionConfig.posterior_parameters.resize(8); + + for (int p = 0; p < 8; ++p) { + predictionConfig.posterior_parameters[p] = posterior_table.get_column(posterior_colnames[p])[index]; + } + + // The fixed parameters are columns 9 to 16 + + predictionConfig.fixedParameters.T_lat = posterior_table.get_column("T_lat")[index]; + predictionConfig.fixedParameters.juvp_s = posterior_table.get_column("juvp_s")[index]; + predictionConfig.fixedParameters.T_inf = posterior_table.get_column("T_inf")[index]; + predictionConfig.fixedParameters.T_rec = posterior_table.get_column("T_rec")[index]; + predictionConfig.fixedParameters.T_sym = posterior_table.get_column("T_sym")[index]; + predictionConfig.fixedParameters.T_hos = posterior_table.get_column("T_hos")[index]; + predictionConfig.fixedParameters.K = posterior_table.get_column("K")[index]; + predictionConfig.fixedParameters.inf_asym = posterior_table.get_column("inf_asym")[index]; + + return predictionConfig; + } +} + +void IOdatapipeline::dpdistribution( + const std::string& data_product, const std::string& component, + std::string p1, double *a, std::string p2, double *b) +{ + (*log) << "\t- (data pipeline) \"" << data_product << "\", \"" << component << "\"" << std::endl; + + Distribution input = dp->read_distribution(data_product, component); + *a = input.getParameter(p1.c_str()); + if (b) *b = input.getParameter(p2.c_str()); +} + +void IOdatapipeline::ImportConsistencyCheck( + const std::string& data_product, const std::string& component, + const unsigned int 
axisLength, const unsigned int expectedValue, const std::string& axisID) +{ + if (axisLength != expectedValue) { + std::stringstream IOMessage; + IOMessage << "Error in data pipeline : \"" << data_product << "\" \"" << component << + "\n Number of " << axisID << ": " << axisLength << + "\n Expected number of " << axisID << ": " << expectedValue << std::endl; + throw IOException(IOMessage.str()); + } +} } // namespace IO } // namespace EERAModel \ No newline at end of file diff --git a/src/IO-datapipeline.h b/src/IO-datapipeline.h index bb577d6b..290019e4 100644 --- a/src/IO-datapipeline.h +++ b/src/IO-datapipeline.h @@ -5,12 +5,39 @@ #include "Utilities.h" #include #include +#include +#include "array.hh" +#include "table.hh" #include "datapipeline.hh" namespace EERAModel { namespace IO { +namespace { + template + std::ostream& operator<<(std::ostream& os, const std::vector& v) { + std::cout << std::scientific << std::setprecision(17); + + os << "(" << v.size() << ") ["; + for (auto &ve : v) { + os << " " << ve; + } + os << " ]"; + return os; + } + + template + std::ostream& operator<<(std::ostream& os, const std::vector>& v) { + os << "(" << v.size() << ") [\n"; + for (auto &ve : v) { + os << " " << ve << "\n"; + } + os << "]"; + return os; + } +} + /** * @class IOdatapipeline * @brief IO class to manage access to the data pipeline API @@ -33,9 +60,11 @@ class IOdatapipeline * must be active and remain in scope for the durection of the use of this class. 
* * @param params_path Pathname for a parameters ".ini" file + * @param model_config Pathname for a directory containing model configuration data + * @param log_stream An output stream to use for logging diagnostic messages * @param dpconfig_path Pathname for a data pipeline configuration yaml file (empty string disables data pipeline) */ - IOdatapipeline(string params_path, string dpconfig_path = ""); + IOdatapipeline(string params_path, string model_config, Utilities::logging_stream::Sptr log_stream, string dpconfig_path = ""); ~IOdatapipeline() {} @@ -46,6 +75,34 @@ class IOdatapipeline */ CommonModelInputParameters ReadCommonParameters(); + /** + * @brief Read model observation data used by all model types + * + * @return Observation data + */ + ObservationsForModels ReadModelObservations(); + + /** + * @brief Read Inference config data + * + * @param commonParameters Common parameters, could've been read using ReadCommonParameters above + * @param modelObservations Model Observations, could've been read using ReadModelObservations above + * + * @return Inference configuration + */ + InferenceConfig ReadInferenceConfig( + const CommonModelInputParameters& commonParameters, const ObservationsForModels& modelObservations); + + /** + * @brief Read Prediction config data + * + * @param index Index of the parameter set to select from the posterior parameters file + * @param commonParameters Common parameters, could've been read using ReadCommonParameters above + * + * @return Prediction configuration + */ + PredictionConfig ReadPredictionConfig(int index, const CommonModelInputParameters& commonParameters); + private: /** * @private @@ -59,6 +116,18 @@ class IOdatapipeline */ std::string ParamsPath; + /** + * @private + * @brief The path of the model data configspecified parameter ".ini" file + */ + std::string ModelConfigDir; + + /** + * @private + * @brief A stream to use for logging diagnostic messages + */ + Utilities::logging_stream::Sptr log; + /** * @private 
* @brief Flag indicating if the data pipeline has been requested and open successfully @@ -78,6 +147,129 @@ class IOdatapipeline * @return Parameters data structure */ params ReadFixedModelParameters(); + + /** + * @brief Read the inference observations + * + * @return Inference observations data structure + */ + ObservationsForInference ReadInferenceObservations(const ObservationsForModels& modelObservations); + + /** + * @brief Perform consistency checks on imported data from filePath. + * + * @param data_product The data pipeline data product to load + * @param component The component within the data product + * @param axisLength Length of axis to be checked + * @param expectedLength Expected Length of axis to check against + * @param axisID String holding axis identifier ("rows" or "columns") + */ + void ImportConsistencyCheck( + const std::string& data_product, const std::string& component, + const unsigned int axisLength, const unsigned int expectedValue, const std::string& axisID); + + /** + * @brief Read a 2D array from the data pipeline and copy the contents into a vector of vectors + * + * @param data_product The data pipeline data product to load + * @param component The component within the data product + * @param result A pointer to the vector of vectors for the result + * @param expected_rows if >=0, the number of rows expected in the data + * @param expected_columns, if >=0, the number of columns expected in the data + */ + template + void dparray_to_csv( + const std::string& data_product, const std::string& component, std::vector> *result, + int expected_rows = -1, int expected_columns = -1) + { + (*log) << "\t- (data pipeline) \"" << data_product << "\", \"" << component << "\"" << std::endl; + + Array input = dp->read_array(data_product, component); + std::vector array_sizes = input.size(); + + if (array_sizes.size() != 2) { + // Should complain about this... and probably should check the dimensions as matching what is expected... if that matters. 
+ } + + if (expected_rows >= 0) { + ImportConsistencyCheck(data_product, component, array_sizes[1], expected_rows, "rows"); + } + + if (expected_columns >= 0) { + ImportConsistencyCheck(data_product, component, array_sizes[0], expected_columns, "columns"); + } + + result->resize(0); + + for (int j = 0; j < array_sizes[1]; ++j) { + // Construct a new element of isize size + result->emplace_back(array_sizes[0]); + auto& row = result->back(); + + // Copy the data row + for (int i = 0; i < array_sizes[0]; ++i) { + row[i] = input(i, j); + } + } + + // Some checking + // std::cout << "Original: \"" << data_product << "\" \"" << component << "\"\n"; + // std::cout << " size() = [" << array_sizes[0] << ", " << array_sizes[1] << "]\n"; + // std::cout << " (0,0)=" << input(0,0) << " (1,0)=" << input(1,0) << "\n"; + // std::cout << " (0,1)=" << input(0,1) << " (1,1)=" << input(1,1) << "\n"; + + // std::cout << "VoV:\n"; + // std::cout << " size() = [" << (*result)[0].size() << ", " << result->size() << "]\n"; + // std::cout << " (0,0)=" << (*result)[0][0] << " (1,0)=" << (*result)[0][1] << "\n"; + // std::cout << " (0,1)=" << (*result)[1][0] << " (1,1)=" << (*result)[1][1] << "\n"; + } + + /** + * @brief Read a table from the data pipeline and copy the contents into a vector of vectors + * + * @param data_product The data pipeline data product to load + * @param component The component within the data product + * @param result A pointer to the vector of vectors for the result + * @param expected_rows if >=0, the number of rows expected in the data + * @param expected_columns, if >=0, the number of columns expected in the data + */ + template + void dptable_to_csv( + const std::string& data_product, const std::string& component, std::vector> *result, + int expected_rows = -1, int expected_columns = -1) + { + (*log) << "\t- (data pipeline) \"" << data_product << "\", \"" << component << "\"" << std::endl; + + Table input = dp->read_table(data_product, component); + + // 
std::cout << input.to_string() << "\n"; + + std::vector columns = input.get_column_names(); + std::size_t col_size = input.get_column_size(); + + if (expected_rows >= 0) { + ImportConsistencyCheck(data_product, component, col_size, expected_rows, "rows"); + } + + if (expected_columns >= 0) { + ImportConsistencyCheck(data_product, component, columns.size(), expected_columns, "columns"); + } + + result->resize(0); + result->resize(col_size); + + for (const auto& column : columns) { + std::vector& column_values = input.get_column(column); + + for (std::size_t e = 0; e < col_size; ++e) { + (*result)[e].push_back(column_values[e]); + } + } + } + + void dpdistribution( + const std::string& data_product, const std::string& component, + std::string p1, double *a, std::string p2 = "", double *b = nullptr); }; diff --git a/src/IO.cpp b/src/IO.cpp index 8d222df4..4ec75841 100644 --- a/src/IO.cpp +++ b/src/IO.cpp @@ -25,10 +25,9 @@ void ImportConsistencyCheck(const std::string& filePath, const unsigned int& axi } -ValidationParameters ImportValidationParameters(const std::string& configDir) +ValidationParameters ImportValidationParameters(const std::string& filePath) { ValidationParameters parameters; - std::string filePath(configDir + "/parameters.ini"); parameters.nHealthBoards = ReadNumberFromFile("nHealthBoards", "Settings", filePath); parameters.nAgeGroups = ReadNumberFromFile("nAgeGroups", "Settings", filePath); @@ -98,7 +97,6 @@ SupplementaryInputParameters ReadSupplementaryParameters(const std::string& Para CommonModelInputParameters ReadCommonParameters(const std::string& ParamsPath) { - std::cout << "(Files): ReadFixedModelParameters\n"; CommonModelInputParameters commonParameters; commonParameters.paramlist = ReadFixedModelParameters(ParamsPath); @@ -108,18 +106,13 @@ CommonModelInputParameters ReadCommonParameters(const std::string& ParamsPath) return commonParameters; } -InferenceConfig ReadInferenceConfig(const std::string& configDir, 
Utilities::logging_stream::Sptr log) +InferenceConfig ReadInferenceConfig(const std::string& configDir, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters) { std::string ParamsPath(configDir + "/parameters.ini"); InferenceConfig inferenceConfig; - inferenceConfig.seedlist = ReadSeedSettings(ParamsPath, log); - inferenceConfig.paramlist = ReadFixedModelParameters(ParamsPath); - - inferenceConfig.herd_id = ReadNumberFromFile("shb_id", "Settings", ParamsPath); - inferenceConfig.day_shut = ReadNumberFromFile("day_shut", "Fixed parameters", ParamsPath); - inferenceConfig.tau = ReadNumberFromFile("tau", "Settings", ParamsPath); + ReadLocalInferenceConfig(ParamsPath, log, commonParameters, &inferenceConfig); inferenceConfig.prior_pinf_shape1 = ReadNumberFromFile("prior_pinf_shape1", "Priors settings", ParamsPath); inferenceConfig.prior_pinf_shape2 = ReadNumberFromFile("prior_pinf_shape2", "Priors settings", ParamsPath); @@ -138,62 +131,37 @@ InferenceConfig ReadInferenceConfig(const std::string& configDir, Utilities::log inferenceConfig.prior_rrd_shape1 = ReadNumberFromFile("prior_rrd_shape1", "Priors settings", ParamsPath); inferenceConfig.prior_rrd_shape2 = ReadNumberFromFile("prior_rrd_shape2", "Priors settings", ParamsPath); - inferenceConfig.nsteps = ReadNumberFromFile("nsteps", "Fit settings", ParamsPath); - inferenceConfig.kernelFactor = ReadNumberFromFile("kernelFactor", "Fit settings", ParamsPath); - inferenceConfig.nSim = ReadNumberFromFile("nSim", "Fit settings", ParamsPath); - inferenceConfig.nParticleLimit = ReadNumberFromFile("nParticLimit", "Fit settings", ParamsPath); - - for (int ii = 1; ii <= inferenceConfig.nsteps; ii++) { - inferenceConfig.toleranceLimit.push_back(0.0); - } - - for (int ii = 0; ii < inferenceConfig.nsteps; ii++) { - std::stringstream KeyName; - KeyName << "Key" << (ii + 1); - inferenceConfig.toleranceLimit[ii] = ReadNumberFromFile(KeyName.str(), "Tolerance settings", ParamsPath); - } - 
inferenceConfig.observations = ReadInferenceObservations(configDir, log); return inferenceConfig; } -PredictionConfig ReadPredictionConfig(const std::string& configDir, int index, Utilities::logging_stream::Sptr log) -{ - std::string filePath(configDir + "/parameters.ini"); - PredictionConfig predictionConfig; - - predictionConfig.seedlist = ReadSeedSettings(filePath, log); - - predictionConfig.day_shut = ReadNumberFromFile("day_shut", "Fixed parameters", filePath); - - std::string sectionId("Prediction Configuration"); - predictionConfig.n_sim_steps = ReadNumberFromFile("n_sim_steps", - sectionId, filePath); - - predictionConfig.index = index; - - std::string parametersFile(configDir + "/posterior_parameters.csv"); - if (!Utilities::fileExists(parametersFile)) { - std::stringstream error_message; - error_message << "Cannot locate posterior parameters file at " << parametersFile << std::endl; - throw std::runtime_error(error_message.str()); +void ReadLocalInferenceConfig( + const std::string& ParamsPath, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters, + InferenceConfig *inferenceConfig) +{ + inferenceConfig->seedlist = ReadSeedSettings(ParamsPath, log); + //inferenceConfig->paramlist = ReadFixedModelParameters(ParamsPath); + inferenceConfig->paramlist = commonParameters.paramlist; + inferenceConfig->herd_id = commonParameters.herd_id; + inferenceConfig->day_shut = commonParameters.day_shut; + inferenceConfig->tau = ReadNumberFromFile("tau", "Settings", ParamsPath); + + inferenceConfig->nsteps = ReadNumberFromFile("nsteps", "Fit settings", ParamsPath); + inferenceConfig->kernelFactor = ReadNumberFromFile("kernelFactor", "Fit settings", ParamsPath); + inferenceConfig->nSim = ReadNumberFromFile("nSim", "Fit settings", ParamsPath); + inferenceConfig->nParticleLimit = ReadNumberFromFile("nParticLimit", "Fit settings", ParamsPath); + + for (int ii = 1; ii <= inferenceConfig->nsteps; ii++) { + 
inferenceConfig->toleranceLimit.push_back(0.0); } - // The posterior parameters are columns 0 to 7; the fixed parameters are columns 8 to 15 - std::vector modelParameters = ReadPredictionParametersFromFile(parametersFile, predictionConfig.index); - predictionConfig.posterior_parameters = std::vector(modelParameters.begin(), modelParameters.begin() + 8); - predictionConfig.fixedParameters.T_lat = modelParameters[8]; - predictionConfig.fixedParameters.juvp_s = modelParameters[9]; - predictionConfig.fixedParameters.T_inf = modelParameters[10]; - predictionConfig.fixedParameters.T_rec = modelParameters[11]; - predictionConfig.fixedParameters.T_sym = modelParameters[12]; - predictionConfig.fixedParameters.T_hos = modelParameters[13]; - predictionConfig.fixedParameters.K = static_cast(modelParameters[14]); - predictionConfig.fixedParameters.inf_asym = modelParameters[15]; - - return predictionConfig; + for (int ii = 0; ii < inferenceConfig->nsteps; ii++) { + std::stringstream KeyName; + KeyName << "Key" << (ii + 1); + inferenceConfig->toleranceLimit[ii] = ReadNumberFromFile(KeyName.str(), "Tolerance settings", ParamsPath); + } } ObservationsForInference ReadInferenceObservations(const std::string& configDir, Utilities::logging_stream::Sptr log) @@ -221,7 +189,7 @@ ObservationsForInference ReadInferenceObservations(const std::string& configDir, const std::string settings_file = configDir + "/parameters.ini"; - ValidationParameters validationParams = ImportValidationParameters(configDir); + ValidationParameters validationParams = ImportValidationParameters(settings_file); int nHealthBoards = validationParams.nHealthBoards; int nCasesDays = validationParams.nCasesDays; @@ -240,22 +208,64 @@ ObservationsForInference ReadInferenceObservations(const std::string& configDir, return observations; } +PredictionConfig ReadPredictionConfig(const std::string& configDir, int index, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters) +{ + 
std::string filePath(configDir + "/parameters.ini"); + + PredictionConfig predictionConfig; + + ReadLocalPredictionConfig(filePath, index, log, commonParameters, &predictionConfig); + + std::string parametersFile(configDir + "/posterior_parameters.csv"); + if (!Utilities::fileExists(parametersFile)) { + std::stringstream error_message; + error_message << "Cannot locate posterior parameters file at " << parametersFile << std::endl; + throw std::runtime_error(error_message.str()); + } + + // The posterior parameters are columns 0 to 7; the fixed parameters are columns 8 to 15 + std::vector modelParameters = ReadPredictionParametersFromFile(parametersFile, predictionConfig.index); + predictionConfig.posterior_parameters = std::vector(modelParameters.begin(), modelParameters.begin() + 8); + predictionConfig.fixedParameters.T_lat = modelParameters[8]; + predictionConfig.fixedParameters.juvp_s = modelParameters[9]; + predictionConfig.fixedParameters.T_inf = modelParameters[10]; + predictionConfig.fixedParameters.T_rec = modelParameters[11]; + predictionConfig.fixedParameters.T_sym = modelParameters[12]; + predictionConfig.fixedParameters.T_hos = modelParameters[13]; + predictionConfig.fixedParameters.K = static_cast(modelParameters[14]); + predictionConfig.fixedParameters.inf_asym = modelParameters[15]; + + return predictionConfig; +} + +void ReadLocalPredictionConfig( + const std::string& filePath, int index, Utilities::logging_stream::Sptr log, + const CommonModelInputParameters& commonParameters, PredictionConfig *predictionConfig) +{ + predictionConfig->seedlist = ReadSeedSettings(filePath, log); + + predictionConfig->day_shut = commonParameters.day_shut; + + std::string sectionId("Prediction Configuration"); + predictionConfig->n_sim_steps = ReadNumberFromFile("n_sim_steps", + sectionId, filePath); + + predictionConfig->index = index; +} + ObservationsForModels ReadModelObservations(const std::string& configDir, Utilities::logging_stream::Sptr log) { 
ObservationsForModels observations; - (*log) << "Observations For Models:" << std::endl; - const std::string scot_data_file = configDir + "/scot_data.csv"; const std::string scot_ages_file = configDir + "/scot_age.csv"; const std::string waifw_norm_file = configDir + "/waifw_norm.csv"; const std::string waifw_home_file = configDir + "/waifw_home.csv"; const std::string waifw_sdist_file = configDir + "/waifw_sdist.csv"; const std::string cfr_byage_file = configDir + "/cfr_byage.csv"; - const std::string scot_frail_file = configDir + "/scot_frail.csv"; const std::string settings_file = configDir + "/parameters.ini"; - ValidationParameters validationParameters = ImportValidationParameters(configDir); + ValidationParameters validationParameters = ImportValidationParameters(settings_file); int nHealthBoards = validationParameters.nHealthBoards; int nAgeGroups = validationParameters.nAgeGroups; int nCfrCategories = validationParameters.nCfrCategories; @@ -332,19 +342,6 @@ ObservationsForModels ReadModelObservations(const std::string& configDir, Utilit ImportConsistencyCheck(cfr_byage_file, cfr_rows, nAgeGroups, "rows"); ImportConsistencyCheck(cfr_byage_file, cfr_cols, nCfrCategories, "columns"); - //Upload frailty probability p_f by age group - //columns are for each age group: [0] Under20,[1] 20-29,[2] 30-39,[3] 40-49,[4] 50-59,[5] 60-69,[6] Over70,[7] HCW - //rows are for each individual Health Borad - //last row is for Scotland - (*log) << "\t- " << scot_frail_file << std::endl; - observations.pf_pop = Utilities::read_csv(scot_frail_file, ','); - - unsigned int pf_pop_rows = observations.pf_pop.size(); - unsigned int pf_pop_cols = observations.pf_pop[0].size(); - - ImportConsistencyCheck(scot_frail_file, pf_pop_rows, nHealthBoards, "rows"); - ImportConsistencyCheck(scot_frail_file, pf_pop_cols, nAgeGroups, "columns"); - return observations; } diff --git a/src/IO.h b/src/IO.h index f928bdd0..aee5060f 100644 --- a/src/IO.h +++ b/src/IO.h @@ -41,16 +41,16 @@ class 
IOException: public std::exception * @param expectedLength Expected Length of axis to check against * @param axisID String holding axis identifier ("rows" or "columns") */ -void ImportConsistencyCheck(const std::string& filePath, const unsigned int& axisLength, const unsigned int& expectedValue); +void ImportConsistencyCheck(const std::string& filePath, const unsigned int& axisLength, const unsigned int& expectedValue, const std::string& axisID); /** * @brief Import parameters used for validating observation data * - * @param configDir Directory containing INI file + * @param filePath File path to parameters INI file * * @return Validation parameters */ -ValidationParameters ImportValidationParameters(const std::string& configDir); +ValidationParameters ImportValidationParameters(const std::string& filePath); /** * @brief Read supplementary parameters used for main.cpp. @@ -77,10 +77,23 @@ CommonModelInputParameters ReadCommonParameters(const std::string& ParamsPath); * * @param configDir Directory containing the configuration and data files * @param log Logger + * @param commonParameters CommonModelInputParameters object that has already been read * * @return Inference parameters */ -InferenceConfig ReadInferenceConfig(const std::string& configDir, Utilities::logging_stream::Sptr log); +InferenceConfig ReadInferenceConfig(const std::string& configDir, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters); + +/** + * @brief Read local parameters for the inference mode + * + * @param ParamsPath Location of the parameters .ini file + * @param log Logger + * @param commonParameters CommonModelInputParameters object that has already been read + * @param inferenceConfig Destination to place local inference config + */ +void ReadLocalInferenceConfig( + const std::string& ParamsPath, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters, + InferenceConfig *inferenceConfig); /** * @brief Read prediction 
framework configuration from input files @@ -88,10 +101,24 @@ InferenceConfig ReadInferenceConfig(const std::string& configDir, Utilities::log * @param configDir Directory containing the configuration and data files * @param index Index of the parameter set to select from the posterior parameters file * @param log Logger + * @param commonParameters CommonModelInputParameters object that has already been read * * @return Prediction configuration */ -PredictionConfig ReadPredictionConfig(const std::string& configDir, int index, Utilities::logging_stream::Sptr log); +PredictionConfig ReadPredictionConfig(const std::string& configDir, int index, Utilities::logging_stream::Sptr log, const CommonModelInputParameters& commonParameters); + +/** + * @brief Read local parameters for the prediction mode + * + * @param filePath Location of the parameters .ini file + * @param index Index of the parameter set to select from the posterior parameters file + * @param log Logger + * @param commonParameters CommonModelInputParameters object that has already been read + * @param predictionConfig Destination to place local prediction config + */ +void ReadLocalPredictionConfig( + const std::string& filePath, int index, Utilities::logging_stream::Sptr log, + const CommonModelInputParameters& commonParameters, PredictionConfig *predictionConfig); /** * @brief Read model posterior parameters from a CSV file diff --git a/src/IrishModel.cpp b/src/IrishModel.cpp index 06d31f25..1f6a60ea 100644 --- a/src/IrishModel.cpp +++ b/src/IrishModel.cpp @@ -15,8 +15,7 @@ IrishModel::IrishModel(const CommonModelInputParameters& commonParameters, observations.waifw_norm, observations.waifw_home, observations.waifw_sdist, - observations.cfr_byage, - observations.pf_pop[commonParameters.herd_id - 1] + observations.cfr_byage }; int regionalPopulation = GetPopulationOfRegion( diff --git a/src/ModelCommon.cpp b/src/ModelCommon.cpp index 6bd2eac1..b40fde5d 100644 --- a/src/ModelCommon.cpp +++ 
b/src/ModelCommon.cpp @@ -41,7 +41,13 @@ int GetPopulationOfRegion(const ObservationsForModels& obs, int region_id) int ComputeNumberOfHCWInRegion(int regionalPopulation, int totalHCW, const ObservationsForModels& obs) { int scotlandPopulation = 0; - for (unsigned int region = 0; region < obs.cases.size() - 1; ++region) { + + // TODO: In the csv files, obs.cases[0] contains column headings so probably shouldn't + // be used in this sum. The value of obs.cases[0][0] will be -1, so unlikely to show + // much? + + // for (unsigned int region = 0; region < obs.cases.size() - 1; ++region) { + for (unsigned int region = 1; region < obs.cases.size() - 1; ++region) { scotlandPopulation += obs.cases[region][0]; } double regionalProportion = static_cast(regionalPopulation) / scotlandPopulation; diff --git a/src/ModelTypes.h b/src/ModelTypes.h index f8424f9f..19ccab6c 100644 --- a/src/ModelTypes.h +++ b/src/ModelTypes.h @@ -111,7 +111,6 @@ struct AgeGroupData std::vector> waifw_home; /*!< mean number of daily contacts between age groups (home only). */ std::vector> waifw_sdist; /*!< mean number of daily contacts between age groups (not school, not work). */ std::vector> cfr_byage; /*!< Case fatality ratio by age. */ - std::vector pf_byage; /*!< Frailty Probability by age. 
*/ }; /** @@ -198,6 +197,7 @@ struct CommonModelInputParameters params paramlist; int herd_id; int totN_hcw; + int day_shut; }; /** diff --git a/src/OriginalModel.cpp b/src/OriginalModel.cpp index 3481989d..94482d8e 100644 --- a/src/OriginalModel.cpp +++ b/src/OriginalModel.cpp @@ -15,8 +15,7 @@ OriginalModel::OriginalModel(const CommonModelInputParameters& commonParameters, observations.waifw_norm, observations.waifw_home, observations.waifw_sdist, - observations.cfr_byage, - observations.pf_pop[commonParameters.herd_id - 1] + observations.cfr_byage }; int regionalPopulation = GetPopulationOfRegion( diff --git a/src/TempModel.cpp b/src/TempModel.cpp index ffc9f424..75936c74 100644 --- a/src/TempModel.cpp +++ b/src/TempModel.cpp @@ -15,8 +15,7 @@ TempModel::TempModel(const CommonModelInputParameters& commonParameters, observations.waifw_norm, observations.waifw_home, observations.waifw_sdist, - observations.cfr_byage, - observations.pf_pop[commonParameters.herd_id - 1] + observations.cfr_byage }; int regionalPopulation = GetPopulationOfRegion( diff --git a/src/main.cpp b/src/main.cpp index b08b9853..b69f366a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,10 +25,15 @@ int main(int argc, char** argv) IO::LogDependencyVersionInfo(logger); arg_parser.logArguments(logger); + // Parameters ini file location const std::string params_addr = std::string(ROOT_DIR)+"/data/parameters.ini"; + // Model Observational data location + std::string modelConfigDir(std::string(ROOT_DIR) + "/data"); + pybind11::scoped_interpreter guard{}; // start the interpreter and keep it alive - IO::IOdatapipeline datapipeline{params_addr, arg_parser.getArgs().datapipeline_path}; + IO::IOdatapipeline datapipeline{ + params_addr, modelConfigDir, logger, arg_parser.getArgs().datapipeline_path}; SupplementaryInputParameters supplementaryParameters = IO::ReadSupplementaryParameters(params_addr, logger); arg_parser.AppendOptions(supplementaryParameters); @@ -45,14 +50,17 @@ int main(int argc, char** 
argv) Random::RNG::Sptr rng = std::make_shared(randomiser_seed); IO::LogRandomiserSettings(supplementaryParameters, randomiser_seed, logger); - // Import common parameters for all models + (*logger) << "[Common Parameters]:" << std::endl; + // Import common parameters for all models CommonModelInputParameters commonParameters = datapipeline.ReadCommonParameters(); // CommonModelInputParameters commonParameters = IO::ReadCommonParameters(params_addr); + (*logger) << "[Model Observations]:" << std::endl; + // Import model observational data - std::string modelConfigDir(std::string(ROOT_DIR) + "/data"); - ObservationsForModels modelObservations = IO::ReadModelObservations(modelConfigDir, logger); + ObservationsForModels modelObservations = datapipeline.ReadModelObservations(); + // ObservationsForModels modelObservations = IO::ReadModelObservations(modelConfigDir, logger); // Log the disease seed settings IO::LogSeedSettings(supplementaryParameters.seedlist, logger); @@ -75,10 +83,10 @@ int main(int argc, char** argv) // Select the mode to run in - prediction or inference if (ModelModeId::PREDICTION == supplementaryParameters.run_type) { - std::string configDir(std::string(ROOT_DIR) + "/data"); + // std::string configDir(std::string(ROOT_DIR) + "/data"); int index = arg_parser.parameterSetIndex(); - PredictionConfig predictionConfig = IO::ReadPredictionConfig(configDir, index, logger); + PredictionConfig predictionConfig = datapipeline.ReadPredictionConfig(index, commonParameters); IO::LogPredictionConfig(predictionConfig, logger); // Update the model with the fixed parameters from the prediction configuration @@ -91,8 +99,8 @@ int main(int argc, char** argv) } else { - std::string configDir(std::string(ROOT_DIR) + "/data"); - InferenceConfig inferenceConfig = IO::ReadInferenceConfig(configDir, logger); + // std::string configDir(std::string(ROOT_DIR) + "/data"); + InferenceConfig inferenceConfig = datapipeline.ReadInferenceConfig(commonParameters, modelObservations); 
IO::LogFixedParameters(commonParameters.paramlist, logger); diff --git a/test/datapipeline/config.yaml b/test/datapipeline/config.yaml index 451ec842..f8e8526d 100644 --- a/test/datapipeline/config.yaml +++ b/test/datapipeline/config.yaml @@ -1,41 +1,50 @@ data_directory: data access_log: access-{run_id}.yaml fail_on_hash_mismatch: True -namespace: EERA +run_metadata: + default_input_namespace: EERA read: - where: data_product: fixed-parameters/T_lat - use: - version: 0.1.0 - where: data_product: fixed-parameters/juvp_s - use: - version: 0.1.0 - where: data_product: fixed-parameters/T_inf - use: - version: 0.1.0 - where: data_product: fixed-parameters/T_rec - use: - version: 0.1.0 - where: data_product: fixed-parameters/T_sym - use: - version: 0.1.0 - where: data_product: fixed-parameters/T_hos - use: - version: 0.1.0 - where: data_product: fixed-parameters/K - use: - version: 0.1.0 - where: data_product: fixed-parameters/inf_asym - use: - version: 0.1.0 - where: data_product: fixed-parameters/total_hcw - use: - version: 0.1.0 + - where: + data_product: fixed-parameters/day_shut + - where: + data_product: prob_hosp_and_cfr/data_for_scotland + - where: + data_product: population-data/data_for_scotland + - where: + data_product: posterior_parameters/data_for_scotland + - where: + data_product: contact-data/who_acquired_infection_from_whom + - where: + data_product: prior-distributions/pinf + - where: + data_product: prior-distributions/ps + - where: + data_product: prior-distributions/q + - where: + data_product: prior-distributions/lambda + - where: + data_product: prior-distributions/chcw + - where: + data_product: prior-distributions/d + - where: + data_product: prior-distributions/phcw + - where: + data_product: prior-distributions/rrd diff --git a/test/datapipeline/data/contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 b/test/datapipeline/data/contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 new file mode 100644 
index 00000000..9342b8a5 Binary files /dev/null and b/test/datapipeline/data/contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 differ diff --git a/test/datapipeline/data/fixed-parameters/day_shut/0.1.0/0.1.0.toml b/test/datapipeline/data/fixed-parameters/day_shut/0.1.0/0.1.0.toml new file mode 100644 index 00000000..78dc2e1a --- /dev/null +++ b/test/datapipeline/data/fixed-parameters/day_shut/0.1.0/0.1.0.toml @@ -0,0 +1,3 @@ +[day_shut] +type = "point-estimate" +value = 19 diff --git a/test/datapipeline/data/metadata.yaml b/test/datapipeline/data/metadata.yaml index 84cbe3bb..13f49493 100644 --- a/test/datapipeline/data/metadata.yaml +++ b/test/datapipeline/data/metadata.yaml @@ -1,3 +1,139 @@ +- accessibility: 0 + component: rrd + data_product: prior-distributions/rrd + extension: toml + filename: prior-distributions/rrd/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 19a4746b9c198ac3cce5948091daea0dff4bd0dd + version: 0.1.0 +- accessibility: 0 + component: phcw + data_product: prior-distributions/phcw + extension: toml + filename: prior-distributions/phcw/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 9225fd7f5cb1597f7eefa3b611a180752d61431a + version: 0.1.0 +- accessibility: 0 + component: d + data_product: prior-distributions/d + extension: toml + filename: prior-distributions/d/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 32acb95aebc2a5438c6f0b1c997aafe4ac544a14 + version: 0.1.0 +- accessibility: 0 + component: chcw + data_product: prior-distributions/chcw + extension: toml + filename: prior-distributions/chcw/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: b5c9a01e8e5913bcdfd052c63d13d9a6e39b7ab6 + version: 0.1.0 +- accessibility: 0 + component: lambda + data_product: prior-distributions/lambda + extension: toml + filename: prior-distributions/lambda/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 32b4a4b682fc9fbe8b44eafd90749a1fddf2671a + version: 0.1.0 +- accessibility: 0 + component: q + data_product: 
prior-distributions/q + extension: toml + filename: prior-distributions/q/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 74bb2a797958417edc38ac8f2bdedcc641d4efeb + version: 0.1.0 +- accessibility: 0 + component: ps + data_product: prior-distributions/ps + extension: toml + filename: prior-distributions/ps/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: e89e99e00e1a542c4ee7f3aa877f3979ec8d70f5 + version: 0.1.0 +- accessibility: 0 + component: pinf + data_product: prior-distributions/pinf + extension: toml + filename: prior-distributions/pinf/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: 85b63f8befd9b54ce2b1f42a3a0a9072d2d839db + version: 0.1.0 +- accessibility: 0 + component: home + data_product: contact-data/who_acquired_infection_from_whom + extension: h5 + filename: contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 + namespace: EERA + verified_hash: ddb1d89f3fbf0bada841ec8ce7bb18f2c168179a + version: 0.20200729.0 +- accessibility: 0 + component: norm + data_product: contact-data/who_acquired_infection_from_whom + extension: h5 + filename: contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 + namespace: EERA + verified_hash: ddb1d89f3fbf0bada841ec8ce7bb18f2c168179a + version: 0.20200729.0 +- accessibility: 0 + component: sdist + data_product: contact-data/who_acquired_infection_from_whom + extension: h5 + filename: contact-data/who_acquired_infection_from_whom/0.20200729.0/0.20200729.0.h5 + namespace: EERA + verified_hash: ddb1d89f3fbf0bada841ec8ce7bb18f2c168179a + version: 0.20200729.0 +- accessibility: 0 + component: posterior_parameters + data_product: posterior_parameters/data_for_scotland + extension: h5 + filename: posterior_parameters/data_for_scotland/0.20200813.0/0.20200813.0.h5 + namespace: EERA + verified_hash: adc99c294ebb864deb5dacf338e230e639e69b40 + version: 0.20200813.0 +- accessibility: 0 + component: age + data_product: population-data/data_for_scotland + extension: h5 + filename: 
population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 + namespace: EERA + verified_hash: 035a0bde31e3de24519b6537b08a20b8a3d5c67b + version: 0.20200728.0 +- accessibility: 0 + component: data + data_product: population-data/data_for_scotland + extension: h5 + filename: population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 + namespace: EERA + verified_hash: 035a0bde31e3de24519b6537b08a20b8a3d5c67b + version: 0.20200728.0 +- accessibility: 0 + component: deaths + data_product: population-data/data_for_scotland + extension: h5 + filename: population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 + namespace: EERA + verified_hash: 035a0bde31e3de24519b6537b08a20b8a3d5c67b + version: 0.20200728.0 +- accessibility: 0 + component: cfr_byage + data_product: prob_hosp_and_cfr/data_for_scotland + extension: h5 + filename: prob_hosp_and_cfr/data_for_scotland/0.20200813.0/0.20200813.0.h5 + namespace: EERA + verified_hash: 6777323dc4c7e8eb70c2b985a15742946216e4c0 + version: 0.20200813.0 +- accessibility: 0 + component: day_shut + data_product: fixed-parameters/day_shut + extension: toml + filename: fixed-parameters/day_shut/0.1.0/0.1.0.toml + namespace: EERA + verified_hash: fcac799da35e8df982aa0a4804a38b6b46ddaed0 + version: 0.1.0 - accessibility: 0 component: total_hcw data_product: fixed-parameters/total_hcw diff --git a/test/datapipeline/data/population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 b/test/datapipeline/data/population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 new file mode 100644 index 00000000..c7da4ae7 Binary files /dev/null and b/test/datapipeline/data/population-data/data_for_scotland/0.20200728.0/0.20200728.0.h5 differ diff --git a/test/datapipeline/data/posterior_parameters/data_for_scotland/0.20200813.0/0.20200813.0.h5 b/test/datapipeline/data/posterior_parameters/data_for_scotland/0.20200813.0/0.20200813.0.h5 new file mode 100644 index 00000000..d092fe6f Binary files /dev/null and 
b/test/datapipeline/data/posterior_parameters/data_for_scotland/0.20200813.0/0.20200813.0.h5 differ diff --git a/test/datapipeline/data/prior-distributions/chcw/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/chcw/0.1.0/0.1.0.toml new file mode 100644 index 00000000..6ee2686f --- /dev/null +++ b/test/datapipeline/data/prior-distributions/chcw/0.1.0/0.1.0.toml @@ -0,0 +1,5 @@ +[chcw] +type = "distribution" +distribution = "poisson" +lambda = 42 + diff --git a/test/datapipeline/data/prior-distributions/d/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/d/0.1.0/0.1.0.toml new file mode 100644 index 00000000..b07405f0 --- /dev/null +++ b/test/datapipeline/data/prior-distributions/d/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[d] +type = "distribution" +distribution = "beta" +alpha = 3 +beta = 3 + diff --git a/test/datapipeline/data/prior-distributions/lambda/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/lambda/0.1.0/0.1.0.toml new file mode 100644 index 00000000..60b86f32 --- /dev/null +++ b/test/datapipeline/data/prior-distributions/lambda/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[lambda] +type = "distribution" +distribution = "uniform" +a = 1e-09 +b = 1e-06 + diff --git a/test/datapipeline/data/prior-distributions/phcw/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/phcw/0.1.0/0.1.0.toml new file mode 100644 index 00000000..88c6b0ff --- /dev/null +++ b/test/datapipeline/data/prior-distributions/phcw/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[phcw] +type = "distribution" +distribution = "beta" +alpha = 3 +beta = 3 + diff --git a/test/datapipeline/data/prior-distributions/pinf/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/pinf/0.1.0/0.1.0.toml new file mode 100644 index 00000000..18a9d6e1 --- /dev/null +++ b/test/datapipeline/data/prior-distributions/pinf/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[pinf] +type = "distribution" +distribution = "beta" +alpha = 3 +beta = 9 + diff --git 
a/test/datapipeline/data/prior-distributions/ps/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/ps/0.1.0/0.1.0.toml new file mode 100644 index 00000000..911f859f --- /dev/null +++ b/test/datapipeline/data/prior-distributions/ps/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[ps] +type = "distribution" +distribution = "beta" +alpha = 9 +beta = 3 + diff --git a/test/datapipeline/data/prior-distributions/q/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/q/0.1.0/0.1.0.toml new file mode 100644 index 00000000..18002b6a --- /dev/null +++ b/test/datapipeline/data/prior-distributions/q/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[q] +type = "distribution" +distribution = "beta" +alpha = 3 +beta = 3 + diff --git a/test/datapipeline/data/prior-distributions/rrd/0.1.0/0.1.0.toml b/test/datapipeline/data/prior-distributions/rrd/0.1.0/0.1.0.toml new file mode 100644 index 00000000..3ef4bed8 --- /dev/null +++ b/test/datapipeline/data/prior-distributions/rrd/0.1.0/0.1.0.toml @@ -0,0 +1,6 @@ +[rrd] +type = "distribution" +distribution = "gamma" +k = 1 +theta = 1 + diff --git a/test/datapipeline/data/prob_hosp_and_cfr/data_for_scotland/0.20200813.0/0.20200813.0.h5 b/test/datapipeline/data/prob_hosp_and_cfr/data_for_scotland/0.20200813.0/0.20200813.0.h5 new file mode 100644 index 00000000..f07491b0 Binary files /dev/null and b/test/datapipeline/data/prob_hosp_and_cfr/data_for_scotland/0.20200813.0/0.20200813.0.h5 differ diff --git a/test/datapipeline/parameters.ini b/test/datapipeline/parameters.ini index 5321a56c..b39b007b 100755 --- a/test/datapipeline/parameters.ini +++ b/test/datapipeline/parameters.ini @@ -1,6 +1,10 @@ [Settings] shb_id=15 tau=1 +nHealthBoards=15 +nAgeGroups=8 +nCfrCategories=4 +nCasesDays=60 [Seed settings] seedmethod=background diff --git a/test/unit/ArgParserTest.cpp b/test/unit/ArgParserTest.cpp index b2183c5d..7871ffac 100644 --- a/test/unit/ArgParserTest.cpp +++ b/test/unit/ArgParserTest.cpp @@ -103,3 +103,28 @@ TEST(AnArgumentParser, 
RecognisesParameterSetIndex) EXPECT_EQ(parse.getArgs().parameter_set_index, 2); } + +TEST(AnArgumentParser, CheckAbsenseOfDataPipeline) +{ + const char* _test_args[] = {"exe", "-m", "inference", "-s", "original"}; + + ArgumentParser parse(5, _test_args); + + EXPECT_EQ(parse.getArgs().datapipeline_path, ""); +} + +TEST(AnArgumentParser, RecognisesRequestForDataPipeline) +{ + const char* _test_args[] = {"exe", "-m", "inference", "-s", "original", "-c", "my_config.yaml"}; + + ArgumentParser parse(7, _test_args); + + EXPECT_EQ(parse.getArgs().datapipeline_path, "my_config.yaml"); +} + +TEST(AnArgumentParser, TerminatesIfNoDataPipelinePath) +{ + const char* _test_args_default[] = {"exe", "-c"}; + + EXPECT_EXIT( ArgumentParser(2, _test_args_default), ExitedWithCode(1), ""); +} diff --git a/test/unit/DataPipelineTest.cpp b/test/unit/DataPipelineTest.cpp index db45a690..aabf8461 100644 --- a/test/unit/DataPipelineTest.cpp +++ b/test/unit/DataPipelineTest.cpp @@ -1,6 +1,7 @@ #include "gtest/gtest.h" #include "ModelTypes.h" #include "IO-datapipeline.h" +#include "IO.h" #include @@ -15,12 +16,18 @@ namespace { TEST(TestIODatapipeline, ExpectThrowForBadPath) { - EXPECT_ANY_THROW(IO::IOdatapipeline idp2("../test/datapipeline/parameters.ini", "NoValidPath.yaml");); + const std::string out_dir = std::string(ROOT_DIR)+"/outputs"; + Utilities::logging_stream::Sptr logger = std::make_shared(out_dir); + + EXPECT_ANY_THROW(IO::IOdatapipeline idp2("../test/datapipeline/parameters.ini", "", logger, "NoValidPath.yaml");); } TEST(TestIODatapipeline, CanReadFixedParameters) { - IO::IOdatapipeline idp("../test/datapipeline/parameters.ini", "../test/datapipeline/config.yaml"); + const std::string out_dir = std::string(ROOT_DIR)+"/outputs"; + Utilities::logging_stream::Sptr logger = std::make_shared(out_dir); + + IO::IOdatapipeline idp("../test/datapipeline/parameters.ini", "", logger, "../test/datapipeline/config.yaml"); CommonModelInputParameters params = idp.ReadCommonParameters(); 
EXPECT_EQ(params.paramlist.T_lat, 4); @@ -32,4 +39,158 @@ TEST(TestIODatapipeline, CanReadFixedParameters) EXPECT_EQ(params.paramlist.K, 2000); EXPECT_EQ(params.paramlist.inf_asym, 1.0); EXPECT_EQ(params.totN_hcw, 112974); + EXPECT_EQ(params.day_shut, 19); } + +namespace { + template + std::ostream& operator<<(std::ostream& os, const std::vector& v) + { + os << "(" << v.size() << ") ["; + for (auto &ve : v) { + os << " " << ve; + } + os << " ]"; + return os; + } + + template + std::ostream& operator<<(std::ostream& os, const std::vector>& v) + { + os << "(" << v.size() << ") [\n"; + for (auto &ve : v) { + os << " " << ve << "\n"; + } + os << "]"; + return os; + } + + void compare_eq(int a, int b) + { + EXPECT_EQ(a, b); + } + + void compare_eq(double a, double b) + { + EXPECT_DOUBLE_EQ(a, b); + } + + template + void compare_eq( + const std::vector>& a, const std::vector>& b, + int firstj = -1, int endi = -1) + { + EXPECT_EQ(a.size(), b.size()); + std::size_t sz1 = std::min(a.size(), b.size()); + + if (firstj < 0) firstj = 0; + + for (int j = firstj; j < sz1; ++j) { + std::size_t sz2 = std::min(a[j].size(), b[j].size()); + + if (endi < 0) { + EXPECT_EQ(a[j].size(), b[j].size()); + } else { + EXPECT_EQ(sz2, endi); + } + + for (int i = 0; i < sz2; ++i) { + compare_eq(a[j][i], b[j][i]); + } + } + } +} + +TEST(TestIODatapipeline, CanReadModelData) +{ + const std::string out_dir = std::string(ROOT_DIR)+"/outputs"; + Utilities::logging_stream::Sptr logger = std::make_shared(out_dir); + + std::string paramsFile = std::string(ROOT_DIR)+"/test/datapipeline/parameters.ini"; + std::string configDir = std::string(ROOT_DIR)+"/test/regression/run1/data"; + std::string datapipelineConfig = std::string(ROOT_DIR)+"/test/datapipeline/config.yaml"; + + // Load data from data pipeline store + IO::IOdatapipeline idp(paramsFile, configDir, logger, datapipelineConfig); + ObservationsForModels dp_params = idp.ReadModelObservations(); + + // Load data from regression test 1 + 
ObservationsForModels rg_params = IO::ReadModelObservations(configDir, logger); + + compare_eq(dp_params.cases, rg_params.cases, 1); + compare_eq(dp_params.age_pop, rg_params.age_pop); + compare_eq(dp_params.waifw_norm, rg_params.waifw_norm); + compare_eq(dp_params.waifw_home, rg_params.waifw_home); + compare_eq(dp_params.waifw_sdist, rg_params.waifw_sdist); + compare_eq(dp_params.cfr_byage, rg_params.cfr_byage, -1, 3); +} + +TEST(TestIODatapipeline, CanReadInferenceConfig) +{ + const std::string out_dir = std::string(ROOT_DIR)+"/outputs"; + Utilities::logging_stream::Sptr logger = std::make_shared(out_dir); + + std::string paramsFile = std::string(ROOT_DIR)+"/test/datapipeline/parameters.ini"; + std::string configDir = std::string(ROOT_DIR)+"/test/regression/run1/data"; + std::string datapipelineConfig = std::string(ROOT_DIR)+"/test/datapipeline/config.yaml"; + + // Load data from data pipeline store + IO::IOdatapipeline idp(paramsFile, configDir, logger, datapipelineConfig); + CommonModelInputParameters common_params = idp.ReadCommonParameters(); + ObservationsForModels model_obs = idp.ReadModelObservations(); + InferenceConfig dp_infconfig = idp.ReadInferenceConfig(common_params, model_obs); + + EXPECT_EQ(dp_infconfig.prior_pinf_shape1, 3.0); + EXPECT_EQ(dp_infconfig.prior_pinf_shape2, 9.0); + EXPECT_EQ(dp_infconfig.prior_phcw_shape1, 3.0); + EXPECT_EQ(dp_infconfig.prior_phcw_shape2, 3.0); + EXPECT_EQ(dp_infconfig.prior_chcw_mean, 42); + EXPECT_EQ(dp_infconfig.prior_d_shape1, 3.0); + EXPECT_EQ(dp_infconfig.prior_d_shape2, 3.0); + EXPECT_EQ(dp_infconfig.prior_q_shape1, 3.0); + EXPECT_EQ(dp_infconfig.prior_q_shape2, 3.0); + EXPECT_EQ(dp_infconfig.prior_ps_shape1, 9.0); + EXPECT_EQ(dp_infconfig.prior_ps_shape2, 3.0); + EXPECT_EQ(dp_infconfig.prior_rrd_shape1, 1.0); + EXPECT_EQ(dp_infconfig.prior_rrd_shape2, 1.0); + EXPECT_EQ(dp_infconfig.prior_lambda_shape1, 1e-9); + EXPECT_EQ(dp_infconfig.prior_lambda_shape2, 1e-6); + + // Load from local data + InferenceConfig 
rg_infconfig = IO::ReadInferenceConfig(configDir, logger, common_params); + + compare_eq(dp_infconfig.observations.cases, rg_infconfig.observations.cases, 1); + compare_eq(dp_infconfig.observations.deaths, rg_infconfig.observations.deaths, 1); +} + +TEST(TestIODatapipeline, CanReadPredictionConfig) +{ + const std::string out_dir = std::string(ROOT_DIR)+"/outputs"; + Utilities::logging_stream::Sptr logger = std::make_shared(out_dir); + + std::string paramsFile = std::string(ROOT_DIR)+"/test/datapipeline/parameters.ini"; + std::string configDir = std::string(ROOT_DIR)+"/test/regression/run1/data"; + std::string datapipelineConfig = std::string(ROOT_DIR)+"/test/datapipeline/config.yaml"; + + // Load data from data pipeline store + IO::IOdatapipeline idp(paramsFile, configDir, logger, datapipelineConfig); + CommonModelInputParameters common_params = idp.ReadCommonParameters(); + PredictionConfig dp_predconfig = idp.ReadPredictionConfig(0, common_params); + + EXPECT_EQ(dp_predconfig.posterior_parameters[0], 0.153532); + EXPECT_EQ(dp_predconfig.posterior_parameters[1], 0.60916); + EXPECT_EQ(dp_predconfig.posterior_parameters[2], 37.9059); + EXPECT_EQ(dp_predconfig.posterior_parameters[3], 0.525139); + EXPECT_EQ(dp_predconfig.posterior_parameters[4], 0.313957); + EXPECT_EQ(dp_predconfig.posterior_parameters[5], 0.787278); + EXPECT_EQ(dp_predconfig.posterior_parameters[6], 0.516736); + EXPECT_EQ(dp_predconfig.posterior_parameters[7], 8.50135E-07); + EXPECT_EQ(dp_predconfig.fixedParameters.T_lat, 4); + EXPECT_EQ(dp_predconfig.fixedParameters.juvp_s, 0.1); + EXPECT_EQ(dp_predconfig.fixedParameters.T_inf, 1.5); + EXPECT_EQ(dp_predconfig.fixedParameters.T_rec, 11); + EXPECT_EQ(dp_predconfig.fixedParameters.T_sym, 7); + EXPECT_EQ(dp_predconfig.fixedParameters.T_hos, 5); + EXPECT_EQ(dp_predconfig.fixedParameters.K, 2000); + EXPECT_EQ(dp_predconfig.fixedParameters.inf_asym, 1); +} \ No newline at end of file