
Improve documentation and error handling mixture node #444

@Nimrais

Description


Initial Issue

The Mixture node's out marginalisation rule was failing with a cryptic error:

MethodError: no method matching getlogscale(::Nothing)

I was a bit surprised by that, and I could not find any tests for the Mixture node's out rules, so maybe I am missing something here. I wrote them on my own; on the current ReactiveMP they fail. Could it be that I am misunderstanding how this node should be used?

@testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
    @test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
        # Test case 1: Equal weights
        (
            input = (
                m_switch = Categorical([0.5, 0.5]),
                m_inputs = ManyOf(
                    NormalMeanVariance(0.0, 1.0),
                    NormalMeanVariance(2.0, 1.0)
                )
            ),
            output = MixtureDistribution([
                NormalMeanVariance(0.0, 1.0),
                NormalMeanVariance(2.0, 1.0)
            ], [0.5, 0.5])
        ),
        # Test case 2: Unequal weights
        (
            input = (
                m_switch = Categorical([0.8, 0.2]),
                m_inputs = ManyOf(
                    NormalMeanVariance(1.0, 1.0),
                    NormalMeanVariance(5.0, 2.0)
                )
            ),
            output = MixtureDistribution([
                NormalMeanVariance(1.0, 1.0),
                NormalMeanVariance(5.0, 2.0)
            ], [0.8, 0.2])
        )
    ]
end
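
For reference, my understanding of the intended semantics (consistent with the expected outputs above) is that the outgoing message should be the switch-weighted mixture of the incoming component messages:

$$
m_{\mathrm{out}}(y) \propto \sum_{k=1}^{K} \pi_k \, m_k(y), \qquad \pi = \mathrm{probvec}(m_{\mathrm{switch}}),
$$

i.e. exactly the MixtureDistribution(components, weights) that the tests expect.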

So I decided to rewrite this rule myself.

I came up with the following implementation:

@rule Mixture(:out, Marginalisation) (m_switch::Any, m_inputs::ManyOf{N, Any}) where {N} = begin
    # Get logscales, defaulting to 0.0 if Nothing
    logscales_inputs = map(msg -> getlogscale(getdata(msg)) === nothing ? 0.0 : getlogscale(getdata(msg)), messages[2])
    logscale_switch = getlogscale(getdata(messages[1])) === nothing ? 0.0 : getlogscale(getdata(messages[1]))

    # compute logscales of individual components
    logscales = logscales_inputs .+ logscale_switch

    @logscale logsumexp(logscales)

    # Use probabilities directly from m_switch
    w = probvec(m_switch)
    T = promote_type(eltype(w), map(x -> eltype(mean(x)), m_inputs)...)

    # Convert inputs to the promoted type
    typed_inputs = map(x -> convert_paramfloattype(T, x), m_inputs)

    # return mixture with type-preserved components
    return MixtureDistribution(collect(typed_inputs), collect(w))
end
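
To sanity-check the @logscale part, here is a small standalone sketch (plain Julia, no ReactiveMP) of what this computation gives for the equal-weights test case above, assuming each Gaussian message contributes its log-normalisation constant -0.5 * (log(2π) - log(τ)) and the Categorical message contributes 0:

# Standalone sketch of the assumed log-scale arithmetic (helper names are hypothetical)
lognorm_gaussian(τ) = -0.5 * (log(2π) - log(τ))                    # log-normalisation constant of N(μ, 1/τ)
logscales_inputs = (lognorm_gaussian(1.0), lognorm_gaussian(1.0))  # both test components have unit variance
logscale_switch = 0.0                                              # Categorical message is already normalised
logscales = logscales_inputs .+ logscale_switch
logsumexp2(a, b) = max(a, b) + log(exp(a - max(a, b)) + exp(b - max(a, b)))  # naive two-term logsumexp
total = logsumexp2(logscales...)                                   # ≈ -0.919 + log(2) ≈ -0.226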

Also, to make the rule work I had to write quite a few helper methods, so to run the tests now you need to use the following code:

@testitem "rules:Mixture:out" begin
    using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions

    import ReactiveMP: @test_rules
    import ReactiveMP: getlogscale
    import BayesBase: paramfloattype
    import Base: isapprox
    using ExponentialFamily: NormalMeanVariance
    
    function getlogscale(d::NormalMeanVariance{T}) where {T}
        μ, τ = mean_precision(d)
        # Log of normalization constant for Normal(μ, 1/√τ)
        return -0.5 * (log(2π) - log(τ))
    end

    function getlogscale(d::Categorical{T}) where {T}
        # Categorical distribution is already normalized
        return 0.0
    end

    # Add paramfloattype for MixtureDistribution
    function paramfloattype(d::MixtureDistribution{D, T}) where {D, T}
        # The float type should be the promoted type of both the component distributions and weights
        return promote_type(T, paramfloattype(first(d.components)))
    end

    # Add isapprox for MixtureDistribution
    function isapprox(x::MixtureDistribution, y::MixtureDistribution; kwargs...)
        # Check if components and weights match approximately
        return length(x.components) == length(y.components) &&
               all(isapprox.(x.components, y.components; kwargs...)) &&
               isapprox(x.weights, y.weights; kwargs...)
    end

    @testset "Marginalisation: (m_switch::Categorical, m_inputs::ManyOf)" begin
        @test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
            # Test case 1: Equal weights
            (
                input = (
                    m_switch = Categorical([0.5, 0.5]),
                    m_inputs = ManyOf(
                        NormalMeanVariance(0.0, 1.0),
                        NormalMeanVariance(2.0, 1.0)
                    )
                ),
                output = MixtureDistribution([
                    NormalMeanVariance(0.0, 1.0),
                    NormalMeanVariance(2.0, 1.0)
                ], [0.5, 0.5])
            ),
            # Test case 2: Unequal weights
            (
                input = (
                    m_switch = Categorical([0.8, 0.2]),
                    m_inputs = ManyOf(
                        NormalMeanVariance(1.0, 1.0),
                        NormalMeanVariance(5.0, 2.0)
                    )
                ),
                output = MixtureDistribution([
                    NormalMeanVariance(1.0, 1.0),
                    NormalMeanVariance(5.0, 2.0)
                ], [0.8, 0.2])
            )
        ]
    end

    @testset "Marginalisation: (m_inputs::ManyOf, q_switch::PointMass)" begin
        @test_rules [check_type_promotion = false] Mixture(:out, Marginalisation) [
            # Test case for one-hot encoded switch
            (
                input = (
                    m_inputs = ManyOf(
                        NormalMeanVariance(0.0, 1.0),
                        NormalMeanVariance(2.0, 1.0)
                    ),
                    q_switch = PointMass([1.0, 0.0])
                ),
                output = NormalMeanVariance(0.0, 1.0)
            ),
            (
                input = (
                    m_inputs = ManyOf(
                        NormalMeanVariance(0.0, 1.0),
                        NormalMeanVariance(2.0, 1.0)
                    ),
                    q_switch = PointMass([0.0, 1.0])
                ),
                output = NormalMeanVariance(2.0, 1.0)
            )
        ]
    end
end

Initially, my goal was to run the following model:

@model function mixture_model(y)
    # Use fixed mixing proportions
    s = [1/3, 2/3]
    
    # Parameters for the two Gaussian components
    m[1] ~ Normal(mean = -20.0, variance = 1e2)
    w[1] ~ InverseGamma(2.0, 1.0)  # InverseGamma prior, used as the component variance below
    
    m[2] ~ Normal(mean = 20.0, variance = 1e2)
    w[2] ~ InverseGamma(2.0, 1.0)  # InverseGamma prior, used as the component variance below
    
    obs_precision ~ InverseGamma(2.0, 1.0)
    
    # Generate mixture assignments and observations
    for i in eachindex(y)
        z[i] ~ Categorical(s)
        comp1[i] ~ NormalMeanVariance(m[1], w[1])  # Use variance parameterization
        comp2[i] ~ NormalMeanVariance(m[2], w[2])
        μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
        y[i] ~ NormalMeanVariance(μ[i], obs_precision)
    end
end

So the end-to-end script to run is the following:

using RxInfer, Distributions
using Random
using Plots
using StatsPlots
using ReactiveMP: getlogscale

# Add getlogscale methods for the distribution types used in this model
import ReactiveMP: getlogscale
function getlogscale(d::NormalMeanPrecision{T}) where {T}
    μ, τ = mean_precision(d)
    # Log of normalization constant for Normal(μ, 1/√τ)
    return -0.5 * (log(2π) - log(τ))
end

function getlogscale(d::NormalWeightedMeanPrecision{T}) where {T}
    μ, τ = mean_precision(d)
    # Log of normalization constant for Normal(μ, 1/√τ)
    return -0.5 * (log(2π) - log(τ))
end

function getlogscale(d::NormalMeanVariance{T}) where {T}
    μ, τ = mean_precision(d)
    # Log of normalization constant for Normal(μ, 1/√τ)
    return -0.5 * (log(2π) - log(τ))
end

function getlogscale(d::Categorical{T}) where {T}
    # Categorical distribution is already normalized
    return 0.0
end

function getlogscale(_::Nothing)
    return 0.0
end

# include("mixture_entropy.jl") # include this to use free energy

@model function mixture_model(y)
    # Use fixed mixing proportions
    s = [1/3, 2/3]
    
    # Parameters for the two Gaussian components
    m[1] ~ Normal(mean = -20.0, variance = 1e2)
    w[1] ~ InverseGamma(2.0, 1.0)  # InverseGamma prior, used as the component variance below
    
    m[2] ~ Normal(mean = 20.0, variance = 1e2)
    w[2] ~ InverseGamma(2.0, 1.0)  # InverseGamma prior, used as the component variance below
    
    obs_precision ~ InverseGamma(2.0, 1.0)
    
    # Generate mixture assignments and observations
    for i in eachindex(y)
        z[i] ~ Categorical(s)
        comp1[i] ~ NormalMeanVariance(m[1], w[1])  # Use variance parameterization
        comp2[i] ~ NormalMeanVariance(m[2], w[2])
        μ[i] ~ Mixture(switch = z[i], inputs = (comp1[i], comp2[i]))
        y[i] ~ NormalMeanVariance(μ[i], obs_precision)
    end
end

# Update constraints
@constraints function mixture_constraints()
    q(z, m, w, μ, comp1, comp2, obs_precision) = q(z)q(m)q(w)q(μ)q(comp1)q(comp2)q(obs_precision)
    q(m) = q(m[1])q(m[2])
    q(w) = q(w[1])q(w[2])
    q(z) = q(z[1]) .. q(z[end])
    q(μ) = q(μ[1]) .. q(μ[end])
    q(comp1) = q(comp1[1]) .. q(comp1[end])
    q(comp2) = q(comp2[1]) .. q(comp2[end])
end

# Generate synthetic data with unequal proportions
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3]  # Unequal proportions as in test
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
    data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end

# Update initialization with better starting points
init = @initialization begin
    # Initialize means further apart
    q(m[1]) = NormalMeanVariance(-30.0, 10.0)  # More uncertainty in initial means
    q(m[2]) = NormalMeanVariance(30.0, 10.0)
    
    # Initialize variances with more informative priors
    q(w[1]) = InverseGamma(3.0, 2.0)  # Mode around 1.0
    q(w[2]) = InverseGamma(3.0, 2.0)
    q(obs_precision) = InverseGamma(3.0, 2.0)
    
    for i in 1:N
        # Initialize assignments closer to true proportions
        q(z[i]) = Categorical([0.4, 0.6])
        
        # Initialize components with wider separation
        q(comp1[i]) = NormalMeanVariance(-20.0, 5.0)
        q(comp2[i]) = NormalMeanVariance(20.0, 5.0)
        q(μ[i]) = NormalMeanVariance(0.0, 100.0)  # Very uncertain about mixture means
    end
end

result = infer(
    model = mixture_model(),
    constraints = mixture_constraints(),
    initialization = init,
    data = (y = data,),
    iterations = 10,
    allow_node_contraction = true,
    options = (limit_stack_depth = 100,),
    # free_energy = true
)

# Create a range for plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)

# Get the final parameters
m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = mean(result.posteriors[:w][1][end])  # This is variance now
v2 = mean(result.posteriors[:w][2][end])

# Create the plot
p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range, 
    x -> pdf(Normal(m1, sqrt(v1)), x),
    label="Component 1", linestyle=:dash)
plot!(x_range, 
    x -> pdf(Normal(m2, sqrt(v2)), x),
    label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit")
xlabel!("x")
ylabel!("Density")

# Save the plot
savefig(p, "mixture_fit.png")

# Print the fitted parameters
println("\nFitted Parameters:")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))
println("Observation variance: ", round(mean(result.posteriors[:obs_precision][end]), digits=3))

# plot(1:10, result.free_energy)

Interestingly, it shows different (and, I would say, more interesting) behavior compared with NormalMixture.

[Image from the original issue: fitted mixture plot]

The NormalMixture model (at least the following one) shows collapsing behavior:

using RxInfer, Distributions
using Random
using Plots
using StatsPlots
using ReactiveMP: getlogscale


@model function mixture_model_normal_mixture(y)
    # Use fixed mixing proportions
    s = [1/3, 2/3]
    
    # Parameters for the two Gaussian components
    m[1] ~ Normal(mean = -20.0, variance = 1e2)
    w[1] ~ GammaShapeRate(2.0, 1.0)
    
    m[2] ~ Normal(mean = 20.0, variance = 1e2)
    w[2] ~ GammaShapeRate(2.0, 1.0)
    
    # Generate mixture assignments and observations
    for i in eachindex(y)
        z[i] ~ Categorical(s)
        # Using p for precision interface as shown in the tests
        y[i] ~ NormalMixture(
            switch = z[i],
            m = (m[1], m[2]),
            p = (w[1], w[2])  # Changed from v to p to match test code
        )
    end
end

@constraints function mixture_constraints()
    q(z, m, w) = q(z)q(m)q(w)
    q(m) = q(m[1])q(m[2])
    q(w) = q(w[1])q(w[2])
    q(z) = q(z[1]) .. q(z[end])
end

# Use same data generation
rng = MersenneTwister(42)
true_means = [-20.0, 20.0]
true_precisions = [1.0, 1.0]
N = 1000
switch = [1/3, 2/3]
z = rand(rng, Categorical(switch), N)
data = zeros(N)
for i in 1:N
    data[i] = randn(rng)/sqrt(true_precisions[z[i]]) + true_means[z[i]]
end

# Use same initialization strategy
init = @initialization begin
    q(m[1]) = NormalMeanVariance(-20.0, 5.0)
    q(m[2]) = NormalMeanVariance(20.0, 5.0)
    
    q(w[1]) = GammaShapeRate(3.0, 2.0)
    q(w[2]) = GammaShapeRate(3.0, 2.0)
    
    for i in 1:N
        q(z[i]) = Categorical([0.4, 0.6])
    end
end

result = infer(
    model = mixture_model_normal_mixture(),
    constraints = mixture_constraints(),
    initialization = init,
    data = (y = data,),
    iterations = 10,
    allow_node_contraction = true,
    options = (limit_stack_depth = 100,),
    free_energy = true
)

# Plotting
x_range = range(minimum(data) - 1, maximum(data) + 1, length=200)

m1 = mean(result.posteriors[:m][1][end])
m2 = mean(result.posteriors[:m][2][end])
v1 = mean(result.posteriors[:w][1][end])
v2 = mean(result.posteriors[:w][2][end])

p = histogram(data, normalize=true, alpha=0.3, label="Data", bins=50)
plot!(x_range, 
    x -> pdf(Normal(m1, sqrt(v1)), x),
    label="Component 1", linestyle=:dash)
plot!(x_range, 
    x -> pdf(Normal(m2, sqrt(v2)), x),
    label="Component 2", linestyle=:dash)
title!("Gaussian Mixture Model Fit (NormalMixture)")
xlabel!("x")
ylabel!("Density")

savefig(p, "normal_mixture_fit.png")

println("\nFitted Parameters (NormalMixture):")
println("Mean 1: ", round(m1, digits=3))
println("Mean 2: ", round(m2, digits=3))
println("Variance 1: ", round(v1, digits=3))
println("Variance 2: ", round(v2, digits=3))


@show result.free_energy[end]

plot(1:10, result.free_energy)

[Image from the original issue]
