Skip to content

Commit 95789dd

Browse files
committed
convert to alias sampler
1 parent 8a87701 commit 95789dd

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/params/progression.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct ProgressionParams
55
dist_hospitalization_time::Gamma{Float64}
66
dist_mild_recovery_time::Uniform{Float64}
77
dist_death_time::LogNormal{Float64}
8-
dist_severity_by_age::Vector{Vector{Distributions.Categorical{P, Ps} where {P<:Real, Ps<:AbstractVector{P}}}}
8+
dist_severity_by_age::Matrix{AliasSampler}
99
effectiveness_table::Matrix{Float64}
1010
hospitalization_time_ratio::Float64
1111
end
@@ -40,11 +40,11 @@ const hospitalization_time_sampler = AliasSampler(Int, hospitalization_time_prob
4040
const age_hospitalization_thresholds = Int[0, 40, 50, 60, 70, 80]
4141

4242

43-
function sample_severity(rng::AbstractRNG, age::Real, gender::Bool, immunity::ImmunityState, effectiveness_table::Matrix{Float64}, dist_severity_by_age::Vector{Vector{Distributions.Categorical{P, Ps} where {P<:Real, Ps<:AbstractVector{P}}}})
43+
function sample_severity(rng::AbstractRNG, age::Real, gender::Bool, immunity::ImmunityState, effectiveness_table::Matrix{Float64}, dist_severity_by_age::Matrix{AliasSampler})
4444
@assert age >= 0 "age should be non-negative"
45-
idx = gender + 1 |> Int32
46-
dist = dist_severity_by_age[idx][age < max_age_hosp ? age + 1 : max_age_hosp]
47-
severity = rand(rng, dist) |> Severity
45+
gender_int = gender + 1 |> UInt8
46+
dist = dist_severity_by_age[min(age + 1, max_age_hosp), gender_int]
47+
severity = asample(dist) |> Severity
4848
severity_int = severity |> UInt8
4949
#reduction severity for immunited subject with some probability
5050
if size(effectiveness_table,1) > 1
@@ -197,11 +197,16 @@ end
197197

198198
function make_dist_severity_by_age(hospitalization_men_probs::Vector{T}, hospitalization_women_probs::Vector{T}, hospitalization_multiplier::Float64, death_multiplier::Float64) where T<:Real
199199
n = length(hospitalization_women_probs)
200-
dist_severity_by_age = [Vector{Categorical}(undef, n),Vector{Categorical}(undef, n)]
200+
hosp_man = hospitalization_men_probs * hospitalization_multiplier
201+
hosp_woman = hospitalization_women_probs * hospitalization_multiplier
202+
men_sampler = AliasSampler[]
203+
women_sampler = AliasSampler[]
201204
for idx in 1:n
202205
group_ids = agegroup(age_hospitalization_thresholds, idx-1)
203-
dist_severity_by_age[1][idx] = Categorical([0, 1-(hospitalization_men_probs[idx]*hospitalization_multiplier), hospitalization_multiplier*hospitalization_men_probs[idx]*(1-critical_probs[group_ids]*death_multiplier), hospitalization_multiplier*hospitalization_men_probs[idx]*critical_probs[group_ids]*death_multiplier])
204-
dist_severity_by_age[2][idx] = Categorical([0, 1-(hospitalization_women_probs[idx]*hospitalization_multiplier), hospitalization_multiplier*hospitalization_women_probs[idx]*(1-critical_probs[group_ids]*death_multiplier), hospitalization_multiplier*hospitalization_women_probs[idx]*critical_probs[group_ids]*death_multiplier])
206+
hosp_prob = [0, 1-hosp_man[idx], hosp_man[idx]*(1-critical_probs[group_ids]*death_multiplier), hosp_man[idx]critical_probs[group_ids]*death_multiplier]
207+
push!(men_sampler, AliasSampler(UInt8,hosp_prob))
208+
hosp_prob = [0, 1-hosp_woman[idx], hosp_woman[idx]*(1-critical_probs[group_ids]*death_multiplier), hosp_woman[idx]critical_probs[group_ids]*death_multiplier]
209+
push!(women_sampler, AliasSampler(UInt8,hosp_prob))
205210
end
206-
dist_severity_by_age
211+
hcat(men_sampler,women_sampler)
207212
end

0 commit comments

Comments
 (0)