Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
.Rproj
.Rproj.user
.Rhistory
.RData
.Ruserdata
*.DS_Store
*SS/*
r/rough/
outputs/setting_specific/
all_mats_*
116 changes: 116 additions & 0 deletions r/analyses/run_calculate_setting_specific_contact_matrices.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
## Name: fit_neg_binom
## Description:
## Input file: clean_participants.rds, clean_contacts.rds
## Functions:
## Output file:

# the following code loops through the weeks to date and constructs a contact matrix using get_matrix().
# There is also a calculation of R assuming probability of infection on contact on 0.1


# Packages ----------------------------------------------------------------
library(data.table)
library(ggplot2)
library(viridis)
library(doParallel)

# Source user written scripts ---------------------------------------------

source('r/functions/get_minimal_data.R')
source('r/functions/functions.R')
# source('r/functions/get_react_data.R')
source('r/functions/calc_cm.R')
source('r/functions/compare_Rs.R')

# Set up parallel computing environment -------------------------------------
ncores = detectCores() - 1
registerDoParallel(cores = ncores)

# Input data ----------------------------------------------------------------

# extract data with useful columns
data = get_minimal_data()

# decant data into relevant containers
contacts = data[[1]]
parts = data[[2]]

start_date = lubridate::ymd('20200323')
end_date = lubridate::ymd('20210623')

parts = parts[between(date, start_date, end_date)]
contacts = contacts[between(date, start_date, end_date)]


# set breaks for age categories and get population proportions
breaks = c(0,5,12,18,30,40,50,60,70,Inf)
#breaks = c(0,18,65,Inf)
max_ = 50 # upper limit for censoring/truncation
popdata_totals = get_popvec(breaks, year_ = 2020)
weeks_in_parts = sort(unique(parts$survey_round))
week_range = c(53,54,57:63) #c(1,11,19,24,34,37,39,42,51) #c(min(weeks_in_parts):max(weeks_in_parts))
#week_range = 34:51
nwk = rep(2,length(week_range)) #c(10,8,5,10,3,2,3,9,6)
samples_ = 1000
fit_with_ = 'bs'
trunc_flag_ = F # flag for whether or not to use truncation rather than uncorrected censoring
zi_ = T # flag for fitting zero-inflated negative binomial vs negative binomial

# Filter data -------------------------------------------------------------
unique_wave_pid <- unique(parts$part_wave_uid)
contacts <- contacts[part_wave_uid %in% unique_wave_pid]

parts[,part_id := paste(as.character(part_id), survey_round, sep = '_')]
contacts[,part_id := paste(as.character(part_id), survey_round, sep = '_')]

countries = list(c("uk"))
country_names = c("uk")
regions = list(c("North East", "Yorkshire and The Humber"), c("North West"), c("East Midlands", "West Midlands"), c("East of England"), c("South West"), c("South East"), c("Greater London"))
nations = list(unlist(regions), c('Scotland'), c('Wales'), c('Northern Ireland'))
nation_names = c("England", "Scotland", "Wales")[1]

settings = c("home","school","work","other")

weights = get_contact_age_weights()

# for (i in 1:length(country_names)){
#for (i in 1:length(regions)){
for (i in 1:length(nation_names)){
for (j in 1:length(settings)){
# for (k in 1:length(week_range)){
lcms = foreach(k=1:length(week_range)) %dopar% {
print(k)

# contacts_nation <- contacts[country %in% countries[[i]] & eval(parse(text=paste0("cnt_",settings[j])))]
# unique_wave_pid <- unique(contacts_nation$part_wave_uid)
#
# parts_nation <- parts[part_wave_uid %in% unique_wave_pid]

# parts_nation <- parts[area_3_name %in% regions[[i]]]
parts_nation <- parts[area_3_name %in% nations[[i]]]
# parts_nation <- parts[country %in% countries[[i]]]

unique_wave_pid <- unique(parts_nation$part_wave_uid)
contacts_nation <- contacts[part_wave_uid %in% unique_wave_pid]

print(nations[[i]])
# print(countries[[i]])


# calculate contact matrices-------------------------------------------------------------


outfolder=paste0('outputs/setting_specific/', nation_names[i], '/')
# outfolder=paste0('outputs/setting_specific/', country_names[i], '/')
if(!dir.exists(outfolder)){
dir.create(outfolder, recursive = TRUE)
}

# cms_max50 = calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j])
calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j])

}

}

}
44 changes: 28 additions & 16 deletions r/functions/calc_cm.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## calculate the contact matrices


calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL){
calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL, trunc_flag = F, zi = F, setting = ""){
print(nwks)
if(!dir.exists(paste0(outfolder, 'contact_matrices/'))){
dir.create(paste0(outfolder, 'contact_matrices/'), recursive = TRUE)
Expand Down Expand Up @@ -32,34 +32,41 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals

if (nwks == 'ALL'){
weeks_range = list(weeks_range)
print(length(weeks_range))}
print(length(weeks_range))
}


for(week in weeks_range){
for(i in 1:length(weeks_range)){
week <- weeks_range[i]
# for(week in weeks_range){
if (nwks != 'ALL'){
i = week
# i = week
#print(i)
weeks <- week:(week + nwks - 1)
}
else{
weeks = week:(week + nwks - 1)
} else{
weeks = week
week = weeks[1]
}
filename_primer = paste0(outfolder, 'contact_matrices/', fitwith, samples, '_ngrps', length(breaks) - 1, '_cap', max_, '_nwks', length(weeks),'_sr', week, '_')

if(i %in% c(1:6, 17,18)) weeks <- c(weeks, 700)
}
filename_primer = paste0(outfolder, 'contact_matrices/', fitwith, samples, '_ngrps', length(breaks) - 1, '_cap', max_, '_nwks', length(weeks),'_sr', week, '_', setting, '_')
if (zi){
filename_primer = paste0(filename_primer,"zi_")
}
if (trunc_flag){
filename_primer = paste0(filename_primer,"trunc_")
}

if(week %in% c(1:6, 17,18)) weeks <- c(weeks, 700)

# Replace weekend contacts by week contacts if there were no surveys done at the weekend in that survey round
if(length(conts_weekend[survey_round %in% weeks]$part_id) == 0) {
conts_weekend = conts_weekday
parts_weekend = parts_weekday
}

ct_ac_weekend = get_age_table(conts_weekend, parts_weekend, weeks, breaks, weights = weights)
ct_ac_weekend = get_age_table(conts_weekend, parts_weekend, weeks, breaks, weights = weights, setting = setting)
cont_per_age_per_part_weekend = ct_ac_weekend[[1]]
all_conts_weekend = ct_ac_weekend[[2]]

ct_ac_weekday = get_age_table(conts_weekday, parts_weekday, weeks, breaks, weights = weights)
ct_ac_weekday = get_age_table(conts_weekday, parts_weekday, weeks, breaks, weights = weights, setting = setting)
cont_per_age_per_part_weekday = ct_ac_weekday[[1]]
all_conts_weekday = ct_ac_weekday[[2]]

Expand All @@ -81,8 +88,13 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals
}

if (fitwith == 'bs'){
outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, bs=samples)
outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, bs=samples)
if (zi){
param = c("p","mu")
} else {
param = "mu"
}
outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, param=param, bs=samples, trunc_flag=trunc_flag, zi=zi, setting=setting)
outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, param=param, bs=samples, trunc_flag=trunc_flag, zi=zi, setting=setting)
}

mus = (outs_weekend[[2]] * 2./7) + (outs_weekday[[2]] * 5./7)
Expand Down
47 changes: 44 additions & 3 deletions r/functions/fitting_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ library(data.table)
# Input data ----------------------------------------------------------------


# this function calculates log of the complement of the sum of a list of liklihoods.
# It is used to find the log liklihood of a tail of a right censored distribution
# this function calculates log of the complement of (i.e. one minus) the sum of a list of likelihoods.
# It is used to find the log likelihood of a tail of a right censored distribution
complementary_logprob <- function(x) {
tryCatch(log1p(-sum(exp(x))), error=function(e) -Inf)
}


# This funtion calculates the likelihood of a negarive binomial disrtibution given set of partameters 'par' and data 'x'.
# This function calculates the log-likelihood of a negative binomial distribution given set of parameters 'par' and data 'x'.
nb_loglik <- function(x, par, n) {
k <- par[["k"]]
mean <- par[["mu"]]
Expand All @@ -33,6 +33,47 @@ nb_loglik <- function(x, par, n) {
return(-sum(ll))
}

# This function calculates the log-likelihood of a Poisson distribution given mean 'par' and data 'x'.
poiss_loglik <- function(x, par, n){
ll <- rep(NA_real_, length(x))
ll[x < n] <- x[x < n] * log(par) - par - log(factorial(x[x < n])) #dpois(x[x < n], par, log = TRUE)
ll[x >= n] <- n * log(par) - par - log(factorial(n)) #dpois(n, par, log = TRUE)
return(-sum(ll))
}

# This function calculates the log-likelihood of a zero-inflated negative binomial distribution given set of parameters 'par' and data 'x'.
zinb_loglik <- function(x, par, n){
p <- par[["p"]]
k <- par[["k"]]
mu <- par[["mu"]]

ll <- rep(NA_real_, length(x))
ll[x == 0] <- log(p + (1-p)*dnbinom(0, mu = mu, size = 1/k))
ll[x > 0 & x < n] <- log(1-p) + dnbinom(x[x > 0 & x < n], mu = mu, size = 1/k, log = T)
ll[x >= n] <- log(1-p) + dnbinom(n, mu = mu, size = 1/k, log = T)
return(-sum(ll))
}

# This function calculates the log-likelihood of a truncated negative binomial distribution given set of parameters 'par', data 'x', and upper truncation limit 'n'.
trunc_nb_loglik <- function(x, par, n) {
k <- par[["k"]]
mean <- par[["mu"]]
ll <- rep(NA_real_, length(x))
ll[x <= n] <- dnbinom(x[x <= n], mu = mean, size = 1/k, log = TRUE) - pnbinom(n, mu = mean, size = 1/k, log.p = TRUE)
ll[x > n] <- 0
return(-sum(ll))
}

# This function calculates the log-likelihood of a truncated Poisson distribution given mean 'par', data 'x', and upper truncation limit 'n'.
trunc_poiss_loglik <- function(x, par, n){
ll <- rep(NA_real_, length(x))
ll[x <= n] <- x[x <= n] * log(par) - par - log(factorial(x[x <= n])) - ppois(n, par, log.p = TRUE)
ll[x > n] <- 0
return(-sum(ll))
}



# This function optimises negative binomial parameters mu and k for contacts reported by age group i in age group j
nbinom_optim_ <- function(i, j, param, n, count_frame) {

Expand Down
Loading