Initial Issue
The Mixture node's out marginalisation rule was failing with a cryptic error:
MethodError: no method matching getlogscale(::Nothing)
I was a bit surprised by that, and I haven't found any tests for the Mixture node out rules (maybe I am missing something here), so I wrote them on my own. On the current ReactiveMP they fail. Could it be that I am misunderstanding how this node is supposed to be used?
@testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case 1: Equal weights
(
input = (
m_switch = Categorical([0.5, 0.5]),
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
], [0.5, 0.5])
),
# Test case 2: Unequal weights
(
input = (
m_switch = Categorical([0.8, 0.2]),
m_inputs = ManyOf(
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
], [0.8, 0.2])
)
]
end
So I decided to rewrite this rule myself.
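In words, what I implement below is: the outgoing message keeps the input messages as mixture components, weighted by the switch probabilities, while the scale factor accumulates the incoming logscales:

$$\mu_{\text{out}} = \sum_{k=1}^{K} \pi_k \,\mu_k, \qquad \ell_{\text{out}} = \log\sum_{k=1}^{K} \exp\big(\ell_k + \ell_{\text{switch}}\big)$$

where $\pi = \operatorname{probvec}(m_{\text{switch}})$, $\mu_k$ are the input messages and $\ell_k$, $\ell_{\text{switch}}$ their logscales (taken as 0 when a message carries none).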
I came up with the following implementation:
@rule Mixture(:out, Marginalisation) (m_switch::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
# Get logscales, defaulting to 0.0 if Nothing
logscales_inputs = map(msg -> getlogscale(getdata(msg)) === nothing ? 0.0 : getlogscale(getdata(msg)), messages[2])
logscale_switch = getlogscale(getdata(messages[1])) === nothing ? 0.0 : getlogscale(getdata(messages[1]))
# compute logscales of individual components
logscales = logscales_inputs .+ logscale_switch
@logscale logsumexp(logscales)
# Use probabilities directly from m_switch
w = probvec(m_switch)
T = promote_type(eltype(w), map(x -> eltype(mean(x)), m_inputs)...)
# Convert inputs to the promoted type
typed_inputs = map(x -> convert_paramfloattype(T, x), m_inputs)
# return mixture with type-preserved components
return MixtureDistribution(collect(typed_inputs), collect(w))
end
Also, to make the rule work I needed to write quite a few helper methods (getlogscale for the incoming distributions, plus paramfloattype and isapprox for MixtureDistribution), so to run the tests now you need the following code:
@testitem "rules:Mixture:out" begin
using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions
import ReactiveMP: @test_rules
import ReactiveMP: getlogscale
import BayesBase: paramfloattype
import Base: isapprox
using ExponentialFamily: NormalMeanVariance
function getlogscale(d::NormalMeanVariance{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::Categorical{T}) where {T}
# Categorical distribution is already normalized
return 0.0
end
# Add paramfloattype for MixtureDistribution
function paramfloattype(d::MixtureDistribution{D, T}) where {D, T}
# The float type should be the promoted type of both the component distributions and weights
return promote_type(T, paramfloattype(first(d.components)))
end
# Add isapprox for MixtureDistribution
function isapprox(x::MixtureDistribution, y::MixtureDistribution; kwargs...)
# Check if components and weights match approximately
return length(x.components) == length(y.components) &&
all(isapprox.(x.components, y.components; kwargs...)) &&
isapprox(x.weights, y.weights; kwargs...)
end
@testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case 1: Equal weights
(
input = (
m_switch = Categorical([0.5, 0.5]),
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
], [0.5, 0.5])
),
# Test case 2: Unequal weights
(
input = (
m_switch = Categorical([0.8, 0.2]),
m_inputs = ManyOf(
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
)
),
output = MixtureDistribution([
NormalMeanVariance(1.0, 1.0),
NormalMeanVariance(5.0, 2.0)
], [0.8, 0.2])
)
]
end
@testset "Marginalisation: (m_inputs::ManyOf, q_switch::PointMass)" begin
@test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
# Test case for one-hot encoded switch
(
input = (
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
),
q_switch = PointMass([1.0, 0.0])
),
output = NormalMeanVariance(0.0, 1.0)
),
(
input = (
m_inputs = ManyOf(
NormalMeanVariance(0.0, 1.0),
NormalMeanVariance(2.0, 1.0)
),
q_switch = PointMass([0.0, 1.0])
),
output = NormalMeanVariance(2.0, 1.0)
)
]
end
end
Initially my goal was to run the following model:
@model function mixture_model(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
obs_precision ~ InverseGamma(2.0, 1.0)
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
comp1[i] ~ NormalMeanVariance(m[1], w[1]) # Use variance parameterization
comp2[i] ~ NormalMeanVariance(m[2], w[2])
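# The Mixture node selects between comp1[i] and comp2[i] according to the switch z[i]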
μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
y[i] ~ NormalMeanVariance(μ[i], obs_precision)
end
end
So the end-to-end script to run is the following one:
using RxInfer, Distributions
using Random
using Plots
using StatsPlots
# Add getlogscale methods for the message distribution types used below
import ReactiveMP: getlogscale
function getlogscale(d::NormalMeanPrecision{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::NormalWeightedMeanPrecision{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::NormalMeanVariance{T}) where {T}
μ, τ = mean_precision(d)
# Log of normalization constant for Normal(μ, 1/√τ)
return -0.5 * (log(2π) - log(τ))
end
function getlogscale(d::Categorical{T}) where {T}
# Categorical distribution is already normalized
return 0.0
end
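# Fallback for messages that carry no logscale (this is the call that raised the original MethodError)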
function getlogscale(_::Nothing)
return 0.0
end
# include("mixture_entropy.jl") # include this to use free energy
@model function mixture_model(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ InverseGamma(2.0, 1.0) # InverseGamma prior on the component variance
obs_precision ~ InverseGamma(2.0, 1.0)
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
comp1[i] ~ NormalMeanVariance(m[1], w[1]) # Use variance parameterization
comp2[i] ~ NormalMeanVariance(m[2], w[2])
μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
y[i] ~ NormalMeanVariance(μ[i], obs_precision)
end
end
# Update constraints
@constraints function mixture_constraints()
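# Mean-field factorisation across the variable groups, then element-wise within each group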
q(z, m, w, μ, comp1, comp2, obs_precision) = q(z)q(m)q(w)q(μ)q(comp1)q(comp2)q(obs_precision)
q(m) = q(m[1])q(m[2])
q(w) = q(w[1])q(w[2])
q(z) = q(z[1]) .. q(z[end])
q(μ) = q(μ[1]) .. q(μ[end])
q(comp1) = q(comp1[1]) .. q(comp1[end])
q(comp2) = q(comp2[1]) .. q(comp2[end])
end
# Generate synthetic data with unequal proportions
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3] # Unequal proportions as in test
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end
# Update initialization with better starting points
init = @initialization begin
# Initialize means further apart
q(m[1]) = NormalMeanVariance(-30.0, 10.0) # More uncertainty in initial means
q(m[2]) = NormalMeanVariance(30.0, 10.0)
# Initialize variances with more informative priors
q(w[1]) = InverseGamma(3.0, 2.0) # Mode around 1.0
q(w[2]) = InverseGamma(3.0, 2.0)
q(obs_precision) = InverseGamma(3.0, 2.0)
for i in 1:N
# Initialize assignments closer to true proportions
q(z[i]) = Categorical([0.4, 0.6])
# Initialize components with wider separation
q(comp1[i]) = NormalMeanVariance(-20.0, 5.0)
q(comp2[i]) = NormalMeanVariance(20.0, 5.0)
q(μ[i]) = NormalMeanVariance(0.0, 100.0) # Very uncertain about mixture means
end
end
result = infer(
model = mixture_model(),
constraints = mixture_constraints(),
initialization = init,
data = (y = data,),
iterations = 10,
allow_node_contraction = true,
options = (limit_stack_depth = 100,),
# free_energy = true
)
# Create a range for plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)
# Get the final parameters
m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = mean(result.posteriors[:w][1][end]) # This is variance now
v2 = mean(result.posteriors[:w][2][end])
# Create the plot
p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range,
x -> pdf(Normal(m1, sqrt(v1)), x),
label="Component 1", linestyle=:dash)
plot!(x_range,
x -> pdf(Normal(m2, sqrt(v2)), x),
label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit")
xlabel!("x")
ylabel!("Density")
# Save the plot
savefig(p, "mixture_fit.png")
# Print the fitted parameters
println("\nFitted Parameters:")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))
println("Observation variance: ", round(mean(result.posteriors[:obs_precision][end]), digits=3))
# plot(1:10, result.free_energy)
Interestingly, it shows a different behavior (and, I would say, a more interesting one) compared with NormalMixture.
The NormalMixture model (at least the following one) shows collapsing behavior:
using RxInfer, Distributions
using Random
using Plots
using StatsPlots
using ReactiveMP: getlogscale
@model function mixture_model_normal_mixture(y)
# Use fixed mixing proportions
s = [1/3, 2/3]
# Parameters for the two Gaussian components
m[1] ~ Normal(mean = -20.0, variance = 1e2)
w[1] ~ GammaShapeRate(2.0, 1.0)
m[2] ~ Normal(mean = 20.0, variance = 1e2)
w[2] ~ GammaShapeRate(2.0, 1.0)
# Generate mixture assignments and observations
for i in eachindex(y)
z[i] ~ Categorical(s)
# Using p for precision interface as shown in the tests
y[i] ~ NormalMixture(
switch = z[i],
m = (m[1], m[2]),
p = (w[1], w[2]) # Changed from v to p to match test code
)
end
end
@constraints function mixture_constraints()
q(z, m, w) = q(z)q(m)q(w)
q(m) = q(m[1])q(m[2])
q(w) = q(w[1])q(w[2])
q(z) = q(z[1]) .. q(z[end])
end
# Use same data generation
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3]
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end
# Use same initialization strategy
init = @initialization begin
q(m[1]) = NormalMeanVariance(-20.0, 5.0)
q(m[2]) = NormalMeanVariance(20.0, 5.0)
q(w[1]) = GammaShapeRate(3.0, 2.0)
q(w[2]) = GammaShapeRate(3.0, 2.0)
for i in 1:N
q(z[i]) = Categorical([0.4, 0.6])
end
end
result = infer(
model = mixture_model_normal_mixture(),
constraints = mixture_constraints(),
initialization = init,
data = (y = data,),
iterations = 10,
allow_node_contraction = true,
options = (limit_stack_depth = 100,),
free_energy = true
)
# Plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)
m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = mean(result.posteriors[:w][1][end])
v2 = mean(result.posteriors[:w][2][end])
p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range,
x -> pdf(Normal(m1, sqrt(v1)), x),
label="Component 1", linestyle=:dash)
plot!(x_range,
x -> pdf(Normal(m2, sqrt(v2)), x),
label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit (NormalMixture)")
xlabel!("x")
ylabel!("Density")
savefig(p, "normal_mixture_fit.png")
println("\nFitted Parameters (NormalMixture):")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))
@show result.free_energy[end]
plot(1:10, result.free_energy)