Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SaferIntegers = "88634af6-177f-5301-88b8-7819386cfa38"
Expand Down
2 changes: 1 addition & 1 deletion src/enums.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@enum DetectionStatus::UInt8 Undetected UnderObservation TestPending Detected #2 bits

@enum ContactKind::UInt8 NoContact=0 HouseholdContact HospitalContact AgeCouplingContact ConstantKernelContact OutsideContact # 3 bits
@enum ContactKind::UInt8 NoContact=0 HouseholdContact HospitalContact AgeCouplingContact ConstantKernelContact OutsideContact SplitAgeContact # 3 bits

@enum DetectionKind::UInt8 NoDetection=0 OutsideQuarantineDetection=1 FromQuarantineDetection FromTracingDetection

Expand Down
45 changes: 45 additions & 0 deletions src/infection_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,48 @@ function enqueue_transmissions!(state::SimState, ::Val{AgeCouplingContact}, sour
end
nothing
end

function enqueue_transmissions!(state::SimState, ::Val{SplitAgeContact}, source_id::Integer, params::SimParams)
if params.age_coupling_params === nothing
return
end

progression = progressionof(state, source_id)

start_time = progression.incubation_time
end_time = if !ismissing(progression.mild_symptoms_time); progression.mild_symptoms_time
elseif !ismissing(progression.severe_symptoms_time); progression.severe_symptoms_time
elseif !ismissing(progression.recovery_time); progression.recovery_time
else error("no recovery nor symptoms time defined")
end

strain = strainof(state, source_id)

total_infection_rate = (end_time - start_time) * rawinfectivity(params, strain)
total_infection_rate *= spreading(params, source_id)

num_infections = rand(state.rng, Poisson(total_infection_rate))

if num_infections == 0
return
end
@assert start_time != end_time "pathologicaly short time for infections there shouldn't be any infections but there are $num_infections, progression=$progression"

time_dist = Uniform(state.time, end_time - start_time + state.time) # in global time reference frame

for _ in 1:num_infections
subject_id = sample(state.rng, params.age_coupling_params.coupling, source_id)

if subject_id == source_id # self-infection difficult to avoid at sampling
break
end

if Healthy == health(state, subject_id) && rand(state.rng) < condinfectivity(params, immunityof(state, subject_id), strain)
infection_time::TimePoint = rand(state.rng, time_dist) |> TimePoint
@assert state.time <= infection_time <= (end_time-start_time + state.time)
push!(state.queue, Event(Val(TransmissionEvent), infection_time, subject_id, source_id, AgeCouplingContact, strain))
end
end
nothing
end

32 changes: 32 additions & 0 deletions src/params/splitage_coupling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

struct SplitAgeCouplingParams
coupling::Prod2CouplingSampler{UInt32, UInt8, PersonIdx}
end

function SplitAgeCouplingParams(
ages::AbstractVector{T} where T<:Real,
genders::Union{Nothing, AbstractVector{Bool}},
age_thresholds::AbstractVector{T} where T<:Real,
age_coupling_weights::AbstractMatrix{Float64},
split_group_ids::AbstractVector{T} where T<:Integer,
split_coupling_weights::AbstractMatrix{Float64})

@assert age_thresholds[1] == 0
num_groups = length(age_thresholds)
@assert (num_groups,num_groups) == size(coupling_weights)

age_group_ids = agegroup.((age_thresholds,), ages)
@assert minimum(age_group_ids) > 0
if genders !== nothing
@assert length(ages) == length(genders)
age_group_ids .= age_group_ids .*2 .+ genders .- 1
end
@assert maximum(age_group_ids) <= num_groups <= typemax(GroupIdx) "minmax: $(extrema(age_group_ids)), num=$num_groups, typemax=$(typemax(GroupIdx))"

coupling_sampler = Prod2CouplingSampler(
split_group_ids, split_coupling_weights,
age_group_ids, age_coupling_weights,
UInt32, PersonIdx)

SplitAgeCouplingParams(coupling_sampler)
end
24 changes: 21 additions & 3 deletions src/simparams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ include("params/phonetracing.jl")
include("params/spreading.jl")
include("params/outside_cases.jl")
include("params/screening.jl")
include("params/splitage_coupling.jl")

struct SimParams <: AbstractSimParams
household_ptrs::Vector{Tuple{PersonIdx,PersonIdx}} # (i1,i2) where i1 and i2 are the indices of first and last member of the household
Expand All @@ -35,8 +36,9 @@ struct SimParams <: AbstractSimParams

constant_kernel_param::Float32
household_kernel_param::Float32
age_coupling_params::Union{Nothing, AgeCouplingParams} # nothing if kernel not active
hospital_kernel_params::Union{Nothing, HospitalInfectionParams} # nothing if hospital kernel not active
age_coupling_params::Union{Nothing, AgeCouplingParams} # nothing if kernel not active
splitage_coupling_params::Union{Nothing, SplitAgeCouplingParams} # nothing if kernel not active
hospital_kernel_params::Union{Nothing, HospitalInfectionParams} # nothing if hospital kernel not active

hospital_detections::Bool
mild_detection_prob::Float64
Expand Down Expand Up @@ -157,10 +159,14 @@ function make_params(
quarantine_length::Float64=14.0,

age_coupling_param::Union{Nothing, Real}=nothing,
age_coupling_thresholds::Union{Nothing, AbstractArray{T} where T<:Real}=nothing,
age_coupling_thresholds::Union{Nothing, AbstractVector{T} where T<:Real}=nothing,
age_coupling_weights::Union{Nothing, AbstractMatrix{T} where T<:Real}=nothing,
age_coupling_use_genders::Bool=false,

split_group_ids::Union{Nothing, AbstractVector{T} where T<:Integer}=nothing,
split_coupling_weights::Union{Nothing, AbstractMatrix{T} where T<:Real}=nothing,
splitage_kernel_param::Union{Nothing, Real}=nothing,

screening_params::Union{Nothing,ScreeningParams}=nothing,

spreading_alpha::Union{Nothing,Real}=nothing,
Expand Down Expand Up @@ -201,6 +207,17 @@ function make_params(
else error("age couplig params not fully given")
end

splitage_coupling_kernel_params =
if nothing === age_coupling_weights || nothing === split_coupling_weights || splitage_couplig_param; nothing
else
SplitAgeCouplingParams(
individuals_df.age,
age_coupling_use_genders === nothing ? individuals_df.gender : nothing,
age_coupling_thresholds, age_coupling_weights,
split_group_ids, split_coupling_weights
)
end

hospital_kernel_params =
if 0 == hospital_kernel_param; nothing
elseif 0.0 < hospital_kernel_param;
Expand Down Expand Up @@ -242,6 +259,7 @@ function make_params(
constant_kernel_param,
household_kernel_param,
age_coupling_kernel_params,
splitage_coupling_kernel_params,
hospital_kernel_params,

hospital_detections,
Expand Down
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ include("utils/alias_sampling.jl")
include("utils/matrix_sampler.jl")
include("utils/population_grouping.jl")
include("utils/coupling_sampler.jl")
include("utils/prod2coupling_sampler.jl")

function countuniquesorted(arr::AbstractVector{T}) where T
d = Dict{T,Int}()
Expand Down
16 changes: 11 additions & 5 deletions src/utils/coupling_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@ struct CouplingSampler{Tprob<:Real, PersonIdx<:Integer, GroupIdx<:Integer}
matrix_sampler::MatrixAliasSampler{GroupIdx, Float64}
end

function CouplingSampler(group_ids::AbstractVector{GroupIdx}, coupling_weights::Matrix{T} where T<:Real, PersonIdx::DataType=UInt32) where GroupIdx<:Integer
function couplingsizecheck(group_ids::AbstractVector{GroupIdx}, coupling_weights::AbstractMatrix{T} where T<:Real) where GroupIdx<:Integer
num_groups, M = size(coupling_weights)
@assert num_groups == M

@assert maximum(group_ids) <= num_groups <= typemax(GroupIdx)
@assert minimum(group_ids) > 0
ming, maxg = extrema(group_ids)
@assert ming > 0
@assert maxg <= num_groups <= typemax(GroupIdx)
minp, maxp = extrema(coupling_weights)
@assert minp >= 0
@assert maxp < Inf
num_groups
end

@assert maximum(coupling_weights) < Inf
@assert minimum(coupling_weights) >= 0
function CouplingSampler(group_ids::AbstractVector{GroupIdx}, coupling_weights::AbstractMatrix{T} where T<:Real, PersonIdx::DataType=UInt32) where GroupIdx<:Integer
num_groups = couplingsizecheck(group_ids, coupling_weights)

grouping = PopulationGrouping(group_ids, num_groups, PersonIdx)
matrix_sampler = MatrixAliasSampler(coupling_weights, GroupIdx)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/matrix_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function MatrixAliasSampler(probs::AbstractMatrix{T}, IdxType::Type=Int) where T
setup_alias_sampler!(weights, acceptances, aliases, smalls, larges, sum(weights))
end
MatrixAliasSampler{IdxType, T}(acceptances_mat, aliases_mat)
end
end

function sample(rng::AbstractRNG, m::MatrixAliasSampler, source::Integer)
N, M = m.aliases_mat |> size
Expand Down
1 change: 1 addition & 0 deletions src/utils/population_grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ PopulationGrouping(group_ids, num_groups::Integer, PersonIdx::DataType=UInt32) =

numgroups(g::PopulationGrouping) = length(g.group_ptrs) - 1
getgroup(g::PopulationGrouping, group_id::Integer) = @view g.person_ids[g.group_ptrs[group_id]:(g.group_ptrs[group_id+1]-1)]
groupsize(g::PopulationGrouping, group_id::Integer) = g.group_ptrs[group_id+1] - g.group_ptrs[group_id]
groupsizes!(sizes::AbstractVector{T} where T<:Integer, g::PopulationGrouping) = (sizes .= (@view g.group_ptrs[2:end]) .- (@view g.group_ptrs[1:end-1]))
groupsizes(g::PopulationGrouping) = groupsizes!(Vector{PersonIdx}(undef, numgroups(g)), g)
66 changes: 66 additions & 0 deletions src/utils/prod2coupling_sampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
struct Prod2CouplingSampler{GroupIdx1<:Integer, GroupIdx2<:Integer, PersonIdx<:Integer}
group_ids1::Vector{GroupIdx1}
matrix_sampler1::MatrixAliasSampler{GroupIdx1, Float64}

group_ids2::Vector{GroupIdx2}
weights_matrix2::Matrix{Float64}

grouping::PopulationGrouping{PersonIdx}
end

jointgrouping(g1::Integer, num_g1::Integer, g2::Integer) = 1 + (g1-1) + (g2-1) * num_g1

function Prod2CouplingSampler(
group_ids1::AbstractVector{GroupIdx1}, coupling_weights1::AbstractMatrix{T} where T<:Real,
group_ids2::AbstractVector{GroupIdx2}, coupling_weights2::AbstractMatrix{T} where T<:Real,
JointGroupIdx::DataType=UInt32, PersonIdx::DataType=UInt32) where {GroupIdx1<:Integer, GroupIdx2<:Integer}

num_groups1 = couplingsizecheck(group_ids1, coupling_weights1)
num_groups2 = couplingsizecheck(group_ids2, coupling_weights2)

# grouping2 is the faster changing index
joint_ids = jointgrouping.(group_ids2, num_groups2, group_ids1) .|> JointGroupIdx
grouping = PopulationGrouping(joint_ids, num_groups1 * num_groups2, PersonIdx)

group_sizes = reshape(groupsizes(grouping), num_groups2, num_groups1)
group1_sizes = vec(sum(group_sizes, dims=1))
matrix_sampler1 = MatrixAliasSampler(coupling_weights1 .* group1_sizes, GroupIdx1)

Prod2CouplingSampler(group_ids1, matrix_sampler1, group_ids2, coupling_weights2, grouping)
end

function sample(rng::AbstractRNG, coupling::Prod2CouplingSampler, source_group1_id::Integer, source_group2_id::Integer)
# first sample the group1
target_group1_id = sample(rng, coupling.matrix_sampler1, source_group1_id)

# extract the relevant column for sampling from group2
target_prob_column2 = view(coupling.weights_matrix2, :, source_group2_id)

num_groups2 = length(target_prob_column2)
# compute the sum of weights
weight_sum = 0.0
@simd for i in 1:num_groups2
joint_group_id = jointgrouping(i, num_groups2, target_group1_id)
weight_sum += target_prob_column2[i] * groupsize(coupling.grouping, joint_group_id)
end

target_weight = weight_sum * rand(rng)

# direct sample the target group 2
accumulated_weight = 0.0
target_group2_id = num_groups2
for i in 1:num_groups2
joint_group_id = jointgrouping(i, num_groups2, target_group1_id)
accumulated_weight += target_prob_column2[i] * groupsize(coupling.grouping, joint_group_id)
if accumulated_weight >= target_weight
target_group2_id = i
break
end
end

target_joint_group_id = jointgrouping(target_group2_id, num_groups2, target_group1_id)
target_group = getgroup(coupling.grouping, target_joint_group_id)

# sample an individual from the selected group
rand(rng, target_group)
end
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14 changes: 9 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
using MocosSim
using Test
using DataFrames
using Random
using LinearAlgebra
using StatsBase

import MocosSim: time, numdead, numdetected

tests = [
"age_coupling",
"matrix_alias_sampler",
"population_grouping",
"infection_modulations",
"household_grouping",
# "age_coupling",
# "matrix_alias_sampler",
# "population_grouping",
# "infection_modulations",
# "household_grouping",
"prod2_coupling",
]

if length(ARGS) > 0
Expand Down
1 change: 1 addition & 0 deletions test/test_population_grouping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
@test length(sizes) == num_groups
for i in 1:MocosSim.numgroups(grouping)
@test length(MocosSim.getgroup(grouping, i)) == sizes[i]
@test MocosSim.groupsize(grouping, i) == sizes[i]
end
end
end
Expand Down
70 changes: 70 additions & 0 deletions test/test_prod2_coupling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
@testset "Prod2Coupling" begin
@testset "jointgrouping is indexing in the right order" begin
num_groups1 = 7
num_groups2 = 13

joint_ids = Int[]
for group1_id in 1:num_groups1
for group2_id in 1:num_groups2
joint_group_id = MocosSim.jointgrouping(group2_id, num_groups2, group1_id)
push!(joint_ids, joint_group_id)
end
end
@test issorted(joint_ids)
@test length(unique(joint_ids)) == num_groups1 * num_groups2
end

@testset "sampling with right probability" begin
function empiricalprobs(rng::AbstractRNG, coupling::MocosSim.Prod2CouplingSampler,
source_g1::Integer, num_g1::Integer, source_g2::Integer, num_g2::Integer,
group1_ids::AbstractVector{T} where T<:Integer, group2_ids::AbstractVector{T} where T<:Integer;
N::Int=10^6)
num_individuals = length(group1_ids)
sampled_people = UInt32[]
for _ in 1:N
person_id = MocosSim.sample(rng, coupling, source_g1, source_g2)
push!(sampled_people, person_id)
end
hist = fit(Histogram, sampled_people, 0:num_individuals, closed=:right)

df = DataFrame(
g1 = group1_ids,
g2 = group2_ids,
g = MocosSim.jointgrouping.(group2_ids, num_g2, group1_ids),
results = hist.weights / N
)
df = combine(groupby(df, :g), :results => sum => :prob)
df = leftjoin(DataFrame(g=1:(num_g1*num_g2)), df, on=:g)
df = sort(df, :g)
replace!(df.prob, missing => 0.0) .|> Float64
end

@testset "all coupling equal" begin
rng = MersenneTwister(13)

num_groups1 = 3
num_groups2 = 5

group1_ids = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3]
group2_ids = [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 1, 2, 3, 5, 4]

weights1 = ones(num_groups1, num_groups1)
weights2 = ones(num_groups2, num_groups2)

coupling = MocosSim.Prod2CouplingSampler(group1_ids, weights1, group2_ids, weights2)

prod_weights = kron(weights1, weights2);
prod_weights .*= MocosSim.groupsizes(coupling.grouping)
prod_weights ./= sum(prod_weights, dims=1);

for g1 in 1:num_groups1
for g2 in 1:num_groups2
empirical = empiricalprobs(rng, coupling, g1, num_groups1, g2, num_groups2, group1_ids, group2_ids)
theoretical = prod_weights[:, MocosSim.jointgrouping.(g2, num_groups2, g1)]
@test norm(empirical .- theoretical, Inf) < 0.001

end
end
end
end
end