Nested submodels slow down sampling #2844

Description

@simonsteiger

Hi!

I have recently simplified a model by nesting a @submodel inside another @submodel and realised that sampling was much slower.

I used the eight schools data to test this in a much simpler scenario, and found that the nested submodel implementation is slower here, too.

# Example adapted from
# https://gist.github.com/penelopeysm/5656697ea20c94d80a285f5f6a69b8ab

using Turing, Mooncake, LinearAlgebra
using DynamicPPL, Distributions
using ADTypes, ForwardDiff, ReverseDiff, Enzyme
using DynamicPPL.TestUtils.AD: run_ad

# Eight schools data
J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]

# Submodels

@model function tau_prior()
    @inline
    p ~ truncated(Cauchy(0, 5); lower = 0)
    return (; p)
end

@model function mu_prior()
    @inline
    p ~ Normal(0, 5)
    return (; p)
end

@model function z_prior(J)
    @inline
    p ~ MvNormal(zeros(J), I)
    return (; p)
end

@model function theta_prior(J)
    @inline
    mu ~ to_submodel(mu_prior())
    tau ~ to_submodel(tau_prior())
    z ~ to_submodel(z_prior(J))
    p := z.p .* tau.p .+ mu.p
    return (; p)
end

# Nested submodel

@model function esc_subm_nested(J, y, sigma)
    @inline
    theta ~ to_submodel(theta_prior(J))
    for i in 1:J
        y[i] ~ Normal(theta.p[i], sigma[i])
    end
end

# Single submodel

@model function esc_subm(J, y, sigma)
    mu ~ to_submodel(mu_prior())
    tau ~ to_submodel(tau_prior())
    z ~ to_submodel(z_prior(J))
    theta := z.p .* tau.p .+ mu.p
    for i in 1:J
        y[i] ~ Normal(theta[i], sigma[i])
    end
end

# No submodel

@model function esc(J, y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower = 0)
    z ~ MvNormal(zeros(J), I)
    theta := z .* tau .+ mu
    for i in 1:J
        y[i] ~ Normal(theta[i], sigma[i])
    end
end

# Same sampler for all three models, deliberately high adapt_delta
sampler = NUTS(10000, 0.99; adtype=AutoMooncake())

model1 = esc_subm_nested(J, y, sigma)
model2 = esc_subm(J, y, sigma)
model3 = esc(J, y, sigma)

adtypes = [AutoForwardDiff(), AutoReverseDiff(), AutoEnzyme(), AutoMooncake()]
models = [model1, model2, model3]

res = NamedTuple[]
for adtype in adtypes
    for (i, model) in enumerate(models)
        adr = run_ad(model, adtype; test=false, benchmark=true)
        push!(res, (; n=i, adtype=string(adtype), grad=adr.grad_time, primal=adr.primal_time))
    end
end

# ... and some DataFrames.jl code to build the table below ...

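Gradient and primal times are roughly the same for the flat and the single-submodel versions (within about 20%), but both take a considerable hit once submodels are nested: the primal is roughly 14x slower for every AD backend, and the gradient is between 5x and 82x slower for every backend except ReverseDiff, which is unchanged.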

| Model | ADType | Grad | Slowdown: grad | Primal | Slowdown: primal |
|---|---|---|---|---|---|
| Flat | ForwardDiff | 837.82 ns | ref | 201.20 ns | ref |
| Flat | ReverseDiff | 54.17 μs | ref | 212.99 ns | ref |
| Flat | Enzyme | 635.64 ns | ref | 210.85 ns | ref |
| Flat | Mooncake | 1.57 μs | ref | 203.02 ns | ref |
| Submodel | ForwardDiff | 977.77 ns | 1.2x | 193.69 ns | 1.0x |
| Submodel | ReverseDiff | 56.10 μs | 1.0x | 201.15 ns | 0.9x |
| Submodel | Enzyme | 632.26 ns | 1.0x | 196.35 ns | 0.9x |
| Submodel | Mooncake | 1.73 μs | 1.1x | 208.90 ns | 1.0x |
| Nested submodel | ForwardDiff | 4.56 μs | 5.4x | 2.84 μs | 14.1x |
| Nested submodel | Enzyme | 52.04 μs | 81.9x | 3.15 μs | 14.9x |
| Nested submodel | ReverseDiff | 55.17 μs | 1.0x | 2.87 μs | 13.5x |
| Nested submodel | Mooncake | 51.04 μs | 32.5x | 3.13 μs | 15.4x |
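
For reference, each slowdown figure is the time for a model divided by the flat model's time under the same AD backend. The following is only a minimal sketch of that arithmetic, not the DataFrames.jl code used to build the table: the numbers are copied by hand from the flat and nested rows (converted to nanoseconds), and `slowdown` is a hypothetical helper name.

```julia
# Times in nanoseconds, copied by hand from the table above.
# The flat model is the reference for each AD backend.
flat = (
    ForwardDiff = (grad = 837.82,  primal = 201.20),
    ReverseDiff = (grad = 54.17e3, primal = 212.99),
    Enzyme      = (grad = 635.64,  primal = 210.85),
    Mooncake    = (grad = 1.57e3,  primal = 203.02),
)
nested = (
    ForwardDiff = (grad = 4.56e3,  primal = 2.84e3),
    ReverseDiff = (grad = 55.17e3, primal = 2.87e3),
    Enzyme      = (grad = 52.04e3, primal = 3.15e3),
    Mooncake    = (grad = 51.04e3, primal = 3.13e3),
)

# Hypothetical helper: ratio of a time to its reference, rounded to one decimal.
slowdown(t, ref) = round(t / ref; digits = 1)

# Print the nested-vs-flat slowdown for every backend.
for backend in keys(flat)
    g = slowdown(nested[backend].grad, flat[backend].grad)
    p = slowdown(nested[backend].primal, flat[backend].primal)
    println(rpad(string(backend), 12), " grad: ", g, "x  primal: ", p, "x")
end
```

The same helper applied to the single-submodel rows would give the ratios in the middle block of the table.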

Here's a link to a brief discussion about this on Discourse.
