From 36dff147785f6bd63fbc888e73c9376cccb22719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20O=C5=BCa=C5=84ski?= Date: Mon, 13 Dec 2021 01:53:35 +0100 Subject: [PATCH] Adds kernel that can be a product of age and split --- Project.toml | 1 + src/enums.jl | 2 +- src/infection_kernels.jl | 45 +++++++++++++++++++ src/params/splitage_coupling.jl | 32 ++++++++++++++ src/simparams.jl | 24 ++++++++-- src/utils.jl | 1 + src/utils/coupling_sampler.jl | 16 ++++--- src/utils/matrix_sampler.jl | 2 +- src/utils/population_grouping.jl | 1 + src/utils/prod2coupling_sampler.jl | 66 ++++++++++++++++++++++++++++ test/Project.toml | 3 ++ test/runtests.jl | 14 +++--- test/test_population_grouping.jl | 1 + test/test_prod2_coupling.jl | 70 ++++++++++++++++++++++++++++++ 14 files changed, 263 insertions(+), 15 deletions(-) create mode 100644 src/params/splitage_coupling.jl create mode 100644 src/utils/prod2coupling_sampler.jl create mode 100644 test/test_prod2_coupling.jl diff --git a/Project.toml b/Project.toml index fc76265..f187735 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SaferIntegers = "88634af6-177f-5301-88b8-7819386cfa38" diff --git a/src/enums.jl b/src/enums.jl index 4c69260..16638e8 100644 --- a/src/enums.jl +++ b/src/enums.jl @@ -6,7 +6,7 @@ @enum DetectionStatus::UInt8 Undetected UnderObservation TestPending Detected #2 bits -@enum ContactKind::UInt8 NoContact=0 HouseholdContact HospitalContact AgeCouplingContact ConstantKernelContact OutsideContact # 3 bits +@enum ContactKind::UInt8 NoContact=0 HouseholdContact HospitalContact AgeCouplingContact ConstantKernelContact OutsideContact SplitAgeContact # 3 bits @enum DetectionKind::UInt8 NoDetection=0 OutsideQuarantineDetection=1 FromQuarantineDetection FromTracingDetection diff --git a/src/infection_kernels.jl b/src/infection_kernels.jl index fe8585d..ca974e7 100644 --- a/src/infection_kernels.jl +++ b/src/infection_kernels.jl @@ -170,3 +170,48 @@ function enqueue_transmissions!(state::SimState, ::Val{AgeCouplingContact}, sour end nothing end + +function enqueue_transmissions!(state::SimState, ::Val{SplitAgeContact}, source_id::Integer, params::SimParams) + if params.age_coupling_params === nothing + return + end + + progression = progressionof(state, source_id) + + start_time = progression.incubation_time + end_time = if !ismissing(progression.mild_symptoms_time); progression.mild_symptoms_time + elseif !ismissing(progression.severe_symptoms_time); progression.severe_symptoms_time + elseif !ismissing(progression.recovery_time); progression.recovery_time + else error("no recovery nor symptoms time defined") + end + + strain = strainof(state, source_id) + + total_infection_rate = (end_time - start_time) * rawinfectivity(params, strain) + total_infection_rate *= spreading(params, source_id) + + num_infections = rand(state.rng, Poisson(total_infection_rate)) + + if num_infections == 0 + return + end + @assert start_time != end_time "pathologicaly short time for infections there shouldn't be any infections but there are $num_infections, progression=$progression" + + time_dist = Uniform(state.time, end_time - start_time + state.time) # in global time reference frame + + for _ in 1:num_infections + subject_id = sample(state.rng, params.age_coupling_params.coupling, source_id) + + if subject_id == source_id # self-infection difficult to avoid at sampling + break + end + + if Healthy == health(state, subject_id) && rand(state.rng) < condinfectivity(params, immunityof(state, subject_id), strain) + infection_time::TimePoint = rand(state.rng, time_dist) |> TimePoint + @assert state.time <= infection_time <= (end_time-start_time + state.time) + push!(state.queue, Event(Val(TransmissionEvent), infection_time, subject_id, source_id, AgeCouplingContact, strain)) + end + end + nothing +end + diff --git a/src/params/splitage_coupling.jl b/src/params/splitage_coupling.jl new file mode 100644 index 0000000..f0add32 --- /dev/null +++ b/src/params/splitage_coupling.jl @@ -0,0 +1,32 @@ + +struct SplitAgeCouplingParams + coupling::Prod2CouplingSampler{UInt32, UInt8, PersonIdx} +end + +function SplitAgeCouplingParams( + ages::AbstractVector{T} where T<:Real, + genders::Union{Nothing, AbstractVector{Bool}}, + age_thresholds::AbstractVector{T} where T<:Real, + age_coupling_weights::AbstractMatrix{Float64}, + split_group_ids::AbstractVector{T} where T<:Integer, + split_coupling_weights::AbstractMatrix{Float64}) + + @assert age_thresholds[1] == 0 + num_groups = length(age_thresholds) + @assert (num_groups,num_groups) == size(coupling_weights) + + age_group_ids = agegroup.((age_thresholds,), ages) + @assert minimum(age_group_ids) > 0 + if genders !== nothing + @assert length(ages) == length(genders) + age_group_ids .= age_group_ids .*2 .+ genders .- 1 + end + @assert maximum(age_group_ids) <= num_groups <= typemax(GroupIdx) "minmax: $(extrema(age_group_ids)), num=$num_groups, typemax=$(typemax(GroupIdx))" + + coupling_sampler = Prod2CouplingSampler( + split_group_ids, split_coupling_weights, + age_group_ids, age_coupling_weights, + UInt32, PersonIdx) + + SplitAgeCouplingParams(coupling_sampler) +end diff --git a/src/simparams.jl b/src/simparams.jl index 0446e98..ff3c3a0 100644 --- a/src/simparams.jl +++ b/src/simparams.jl @@ -22,6 +22,7 @@ include("params/phonetracing.jl") include("params/spreading.jl") include("params/outside_cases.jl") include("params/screening.jl") +include("params/splitage_coupling.jl") struct SimParams <: AbstractSimParams household_ptrs::Vector{Tuple{PersonIdx,PersonIdx}} # (i1,i2) where i1 and i2 are the indices of first and last member of the household @@ -35,8 +36,9 @@ struct SimParams <: AbstractSimParams constant_kernel_param::Float32 household_kernel_param::Float32 - age_coupling_params::Union{Nothing, AgeCouplingParams} # nothing if kernel not active - hospital_kernel_params::Union{Nothing, HospitalInfectionParams} # nothing if hospital kernel not active + age_coupling_params::Union{Nothing, AgeCouplingParams} # nothing if kernel not active + splitage_coupling_params::Union{Nothing, SplitAgeCouplingParams} # nothing if kernel not active + hospital_kernel_params::Union{Nothing, HospitalInfectionParams} # nothing if hospital kernel not active hospital_detections::Bool mild_detection_prob::Float64 @@ -157,10 +159,14 @@ function make_params( quarantine_length::Float64=14.0, age_coupling_param::Union{Nothing, Real}=nothing, - age_coupling_thresholds::Union{Nothing, AbstractArray{T} where T<:Real}=nothing, + age_coupling_thresholds::Union{Nothing, AbstractVector{T} where T<:Real}=nothing, age_coupling_weights::Union{Nothing, AbstractMatrix{T} where T<:Real}=nothing, age_coupling_use_genders::Bool=false, + split_group_ids::Union{Nothing, AbstractVector{T} where T<:Integer}=nothing, + split_coupling_weights::Union{Nothing, AbstractMatrix{T} where T<:Real}=nothing, + splitage_kernel_param::Union{Nothing, Real}=nothing, + screening_params::Union{Nothing,ScreeningParams}=nothing, spreading_alpha::Union{Nothing,Real}=nothing, @@ -201,6 +207,17 @@ function make_params( else error("age couplig params not fully given") end + splitage_coupling_kernel_params = + if nothing === age_coupling_weights || nothing === split_coupling_weights || splitage_couplig_param; nothing + else + SplitAgeCouplingParams( + individuals_df.age, + age_coupling_use_genders === nothing ? individuals_df.gender : nothing, + age_coupling_thresholds, age_coupling_weights, + split_group_ids, split_coupling_weights + ) + end + hospital_kernel_params = if 0 == hospital_kernel_param; nothing elseif 0.0 < hospital_kernel_param; @@ -242,6 +259,7 @@ function make_params( constant_kernel_param, household_kernel_param, age_coupling_kernel_params, + splitage_coupling_kernel_params, hospital_kernel_params, hospital_detections, diff --git a/src/utils.jl b/src/utils.jl index 83637db..524284c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,7 @@ include("utils/alias_sampling.jl") include("utils/matrix_sampler.jl") include("utils/population_grouping.jl") include("utils/coupling_sampler.jl") +include("utils/prod2coupling_sampler.jl") function countuniquesorted(arr::AbstractVector{T}) where T d = Dict{T,Int}() diff --git a/src/utils/coupling_sampler.jl b/src/utils/coupling_sampler.jl index cfd7441..cabf0b8 100644 --- a/src/utils/coupling_sampler.jl +++ b/src/utils/coupling_sampler.jl @@ -4,15 +4,21 @@ struct CouplingSampler{Tprob<:Real, PersonIdx<:Integer, GroupIdx<:Integer} matrix_sampler::MatrixAliasSampler{GroupIdx, Float64} end -function CouplingSampler(group_ids::AbstractVector{GroupIdx}, coupling_weights::Matrix{T} where T<:Real, PersonIdx::DataType=UInt32) where GroupIdx<:Integer +function couplingsizecheck(group_ids::AbstractVector{GroupIdx}, coupling_weights::AbstractMatrix{T} where T<:Real) where GroupIdx<:Integer num_groups, M = size(coupling_weights) @assert num_groups == M - @assert maximum(group_ids) <= num_groups <= typemax(GroupIdx) - @assert minimum(group_ids) > 0 + ming, maxg = extrema(group_ids) + @assert ming > 0 + @assert maxg <= num_groups <= typemax(GroupIdx) + minp, maxp = extrema(coupling_weights) + @assert minp >= 0 + @assert maxp < Inf + num_groups +end - @assert maximum(coupling_weights) < Inf - @assert minimum(coupling_weights) >= 0 +function CouplingSampler(group_ids::AbstractVector{GroupIdx}, coupling_weights::AbstractMatrix{T} where T<:Real, PersonIdx::DataType=UInt32) where GroupIdx<:Integer + num_groups = couplingsizecheck(group_ids, coupling_weights) grouping = PopulationGrouping(group_ids, num_groups, PersonIdx) matrix_sampler = MatrixAliasSampler(coupling_weights, GroupIdx) diff --git a/src/utils/matrix_sampler.jl b/src/utils/matrix_sampler.jl index 7635609..1ce4c26 100644 --- a/src/utils/matrix_sampler.jl +++ b/src/utils/matrix_sampler.jl @@ -23,7 +23,7 @@ function MatrixAliasSampler(probs::AbstractMatrix{T}, IdxType::Type=Int) where T setup_alias_sampler!(weights, acceptances, aliases, smalls, larges, sum(weights)) end MatrixAliasSampler{IdxType, T}(acceptances_mat, aliases_mat) - end +end function sample(rng::AbstractRNG, m::MatrixAliasSampler, source::Integer) N, M = m.aliases_mat |> size diff --git a/src/utils/population_grouping.jl b/src/utils/population_grouping.jl index 0a9afde..4685a05 100644 --- a/src/utils/population_grouping.jl +++ b/src/utils/population_grouping.jl @@ -37,5 +37,6 @@ PopulationGrouping(group_ids, num_groups::Integer, PersonIdx::DataType=UInt32) = numgroups(g::PopulationGrouping) = length(g.group_ptrs) - 1 getgroup(g::PopulationGrouping, group_id::Integer) = @view g.person_ids[g.group_ptrs[group_id]:(g.group_ptrs[group_id+1]-1)] +groupsize(g::PopulationGrouping, group_id::Integer) = g.group_ptrs[group_id+1] - g.group_ptrs[group_id] groupsizes!(sizes::AbstractVector{T} where T<:Integer, g::PopulationGrouping) = (sizes .= (@view g.group_ptrs[2:end]) .- (@view g.group_ptrs[1:end-1])) groupsizes(g::PopulationGrouping) = groupsizes!(Vector{PersonIdx}(undef, numgroups(g)), g) diff --git a/src/utils/prod2coupling_sampler.jl b/src/utils/prod2coupling_sampler.jl new file mode 100644 index 0000000..b31d3e7 --- /dev/null +++ b/src/utils/prod2coupling_sampler.jl @@ -0,0 +1,66 @@ +struct Prod2CouplingSampler{GroupIdx1<:Integer, GroupIdx2<:Integer, PersonIdx<:Integer} + group_ids1::Vector{GroupIdx1} + matrix_sampler1::MatrixAliasSampler{GroupIdx1, Float64} + + group_ids2::Vector{GroupIdx2} + weights_matrix2::Matrix{Float64} + + grouping::PopulationGrouping{PersonIdx} +end + +jointgrouping(g1::Integer, num_g1::Integer, g2::Integer) = 1 + (g1-1) + (g2-1) * num_g1 + +function Prod2CouplingSampler( + group_ids1::AbstractVector{GroupIdx1}, coupling_weights1::AbstractMatrix{T} where T<:Real, + group_ids2::AbstractVector{GroupIdx2}, coupling_weights2::AbstractMatrix{T} where T<:Real, + JointGroupIdx::DataType=UInt32, PersonIdx::DataType=UInt32) where {GroupIdx1<:Integer, GroupIdx2<:Integer} + + num_groups1 = couplingsizecheck(group_ids1, coupling_weights1) + num_groups2 = couplingsizecheck(group_ids2, coupling_weights2) + + # grouping2 is the faster changing index + joint_ids = jointgrouping.(group_ids2, num_groups2, group_ids1) .|> JointGroupIdx + grouping = PopulationGrouping(joint_ids, num_groups1 * num_groups2, PersonIdx) + + group_sizes = reshape(groupsizes(grouping), num_groups2, num_groups1) + group1_sizes = vec(sum(group_sizes, dims=1)) + matrix_sampler1 = MatrixAliasSampler(coupling_weights1 .* group1_sizes, GroupIdx1) + + Prod2CouplingSampler(group_ids1, matrix_sampler1, group_ids2, coupling_weights2, grouping) +end + +function sample(rng::AbstractRNG, coupling::Prod2CouplingSampler, source_group1_id::Integer, source_group2_id::Integer) + # first sample the group1 + target_group1_id = sample(rng, coupling.matrix_sampler1, source_group1_id) + + # extract the relevant column for sampling from group2 + target_prob_column2 = view(coupling.weights_matrix2, :, source_group2_id) + + num_groups2 = length(target_prob_column2) + # compute the sum of weights + weight_sum = 0.0 + @simd for i in 1:num_groups2 + joint_group_id = jointgrouping(i, num_groups2, target_group1_id) + weight_sum += target_prob_column2[i] * groupsize(coupling.grouping, joint_group_id) + end + + target_weight = weight_sum * rand(rng) + + # direct sample the target group 2 + accumulated_weight = 0.0 + target_group2_id = num_groups2 + for i in 1:num_groups2 + joint_group_id = jointgrouping(i, num_groups2, target_group1_id) + accumulated_weight += target_prob_column2[i] * groupsize(coupling.grouping, joint_group_id) + if accumulated_weight >= target_weight + target_group2_id = i + break + end + end + + target_joint_group_id = jointgrouping(target_group2_id, num_groups2, target_group1_id) + target_group = getgroup(coupling.grouping, target_joint_group_id) + + # sample an individual from the selected group + rand(rng, target_group) +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 7a21f89..4ebde1e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,3 +1,6 @@ [deps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 28a2e1d..212ba6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,15 +1,19 @@ using MocosSim using Test +using DataFrames using Random +using LinearAlgebra +using StatsBase import MocosSim: time, numdead, numdetected tests = [ - "age_coupling", - "matrix_alias_sampler", - "population_grouping", - "infection_modulations", - "household_grouping", +# "age_coupling", +# "matrix_alias_sampler", +# "population_grouping", +# "infection_modulations", +# "household_grouping", + "prod2_coupling", ] if length(ARGS) > 0 diff --git a/test/test_population_grouping.jl b/test/test_population_grouping.jl index bb89e2c..097e99b 100644 --- a/test/test_population_grouping.jl +++ b/test/test_population_grouping.jl @@ -34,6 +34,7 @@ @test length(sizes) == num_groups for i in 1:MocosSim.numgroups(grouping) @test length(MocosSim.getgroup(grouping, i)) == sizes[i] + @test MocosSim.groupsize(grouping, i) == sizes[i] end end end diff --git a/test/test_prod2_coupling.jl b/test/test_prod2_coupling.jl new file mode 100644 index 0000000..3e5e41f --- /dev/null +++ b/test/test_prod2_coupling.jl @@ -0,0 +1,70 @@ +@testset "Prod2Coupling" begin + @testset "jointgrouping is indexing in the right order" begin + num_groups1 = 7 + num_groups2 = 13 + + joint_ids = Int[] + for group1_id in 1:num_groups1 + for group2_id in 1:num_groups2 + joint_group_id = MocosSim.jointgrouping(group2_id, num_groups2, group1_id) + push!(joint_ids, joint_group_id) + end + end + @test issorted(joint_ids) + @test length(unique(joint_ids)) == num_groups1 * num_groups2 + end + + @testset "sampling with right probability" begin + function empiricalprobs(rng::AbstractRNG, coupling::MocosSim.Prod2CouplingSampler, + source_g1::Integer, num_g1::Integer, source_g2::Integer, num_g2::Integer, + group1_ids::AbstractVector{T} where T<:Integer, group2_ids::AbstractVector{T} where T<:Integer; + N::Int=10^6) + num_individuals = length(group1_ids) + sampled_people = UInt32[] + for _ in 1:N + person_id = MocosSim.sample(rng, coupling, source_g1, source_g2) + push!(sampled_people, person_id) + end + hist = fit(Histogram, sampled_people, 0:num_individuals, closed=:right) + + df = DataFrame( + g1 = group1_ids, + g2 = group2_ids, + g = MocosSim.jointgrouping.(group2_ids, num_g2, group1_ids), + results = hist.weights / N + ) + df = combine(groupby(df, :g), :results => sum => :prob) + df = leftjoin(DataFrame(g=1:(num_g1*num_g2)), df, on=:g) + df = sort(df, :g) + replace!(df.prob, missing => 0.0) .|> Float64 + end + + @testset "all coupling equal" begin + rng = MersenneTwister(13) + + num_groups1 = 3 + num_groups2 = 5 + + group1_ids = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3] + group2_ids = [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 1, 2, 3, 5, 4] + + weights1 = ones(num_groups1, num_groups1) + weights2 = ones(num_groups2, num_groups2) + + coupling = MocosSim.Prod2CouplingSampler(group1_ids, weights1, group2_ids, weights2) + + prod_weights = kron(weights1, weights2); + prod_weights .*= MocosSim.groupsizes(coupling.grouping) + prod_weights ./= sum(prod_weights, dims=1); + + for g1 in 1:num_groups1 + for g2 in 1:num_groups2 + empirical = empiricalprobs(rng, coupling, g1, num_groups1, g2, num_groups2, group1_ids, group2_ids) + theoretical = prod_weights[:, MocosSim.jointgrouping.(g2, num_groups2, g1)] + @test norm(empirical .- theoretical, Inf) < 0.001 + + end + end + end + end +end \ No newline at end of file