[WIP] New libtask interface #114

Status: Open · wants to merge 6 commits into base `main`

Conversation

@FredericWantiez (Member) commented Mar 23, 2025

Integrate refactor from TuringLang/Libtask.jl#179

Two things worth noting:

  1. Dealing with the RNG will be the user's responsibility. Before:
mutable struct Model <: AdvancedPS.AbstractGenericModel
  mu::Float64
  sig::Float64

  Model() = new()
end


function (model::Model)(rng::Random.AbstractRNG)
  model.sig = rand(rng, Beta(1, 1))  # AdvancedPS took care of syncing these
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

and now:

function (model::Model)()
  rng = Libtask.get_dynamic_scope() # We now need to query the RNG explicitly
  model.sig = rand(rng, Beta(1, 1))
  Libtask.produce(model.sig)

  rng = Libtask.get_dynamic_scope() # and do it every time we want to sample random values
  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end
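For intuition, here is a minimal Python sketch of the dynamic-scoping mechanism the new interface relies on, with `contextvars` standing in for Libtask's dynamic scope (none of these names are the Libtask API): the scheduler installs a value into the ambient scope before each resumption, and the model re-queries it at every sample site.

```python
# Hypothetical sketch, not the Libtask API: dynamic scoping lets the scheduler
# swap the RNG between resumptions without threading it through arguments.
import contextvars
import random

_scope = contextvars.ContextVar("dynamic_scope")

def get_dynamic_scope():
    # The model queries the ambient scope at each sample site.
    return _scope.get()

def model():
    rng = get_dynamic_scope()
    sig = rng.random()         # first sample site
    yield sig
    rng = get_dynamic_scope()  # re-query: the scheduler may have reseeded
    mu = rng.random()
    yield mu

# The "scheduler" installs a fresh RNG before each resumption.
task = model()
_scope.set(random.Random(1))
first = next(task)
_scope.set(random.Random(2))   # e.g. after resampling particles
second = next(task)
```

The point of re-querying at every site is that the scope may have been replaced between produces, e.g. after a resampling step.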
  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs:

function AdvancedPS.forkr(trace::LibtaskTrace)
  newf = AdvancedPS.reset_model(trace.model.ctask.fargs[1])
  Random123.set_counter!(trace.rng, 1)


@FredericWantiez FredericWantiez changed the title New libtask interface [WIP] New libtask interface Mar 23, 2025
@willtebbutt (Member)

Thanks for having a look at this!

  1. Dealing with the RNG will be the user's responsibility. Before

Does this have any implications for integration with Turing.jl? I.e., does not passing an RNG to the model cause any trouble downstream? (To be clear, I have no idea -- I'm not suggesting that it does or doesn't.)

  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs

I agree re not wanting to dig into tapedtask.fargs. Could you elaborate a little on what is required here? My understanding was that task copying would handle this -- i.e. when you copy a task, all references to the model get updated, so from the perspective of the code inside the task, things just continue as normal.

As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.

@FredericWantiez (Member, Author) commented Mar 25, 2025

  1. We can drop this one; it really only applies when AdvancedPS is used with Libtask outside of Turing. We will probably sunset that (or target people who presumably know enough about Libtask).

  2. Still not 100% sure about Turing, but we need something like this to manage the reference particle in the Particle Gibbs loop. Here's an MVP that should replicate a simple loop of the algorithm:

using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems


mutable struct Model <: AdvancedPS.AbstractGenericModel
  x::Float64
  y::Float64

  Model() = new()
end


function (model::Model)()
  rng = Libtask.get_dynamic_scope()
  model.x = rand(rng, Beta(1,1))
  Libtask.produce(model.x)

  rng = Libtask.get_dynamic_scope()
  model.y = rand(rng, Normal(0, model.x))
  Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()

trace = AdvancedPS.Trace(model, rng)
# Sample `x`
AdvancedPS.advance!(trace)

trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)

Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)
println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`
# Note: this is only a problem when creating a reference trajectory;
# sampled values are properly captured during the execution of the task
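The key-splitting step above (`AdvancedPS.split` on a counter-based Random123 RNG) is what keeps the parent and child particles on uncorrelated random streams after a fork. A rough Python sketch of the idea, with a hash standing in for the counter-based construction (`split` and the seed handling here are illustrative, not the AdvancedPS implementation):

```python
# Hypothetical sketch: deterministically split one RNG key into independent
# child seeds, then reseed the forked particles with them.
import hashlib
import random

def split(key, n):
    # Derive n child seeds from one parent key, deterministically.
    return [
        int.from_bytes(hashlib.sha256(f"{key}:{i}".encode()).digest()[:8], "big")
        for i in range(n)
    ]

parent_key = 10
seed1, seed2 = split(parent_key, 2)

# Reseed the forked particles so their future draws are independent.
rng_parent = random.Random(seed1)
rng_child = random.Random(seed2)
draw_parent = rng_parent.random()
draw_child = rng_child.random()
```

Determinism matters here: replaying the same parent key must reproduce the same child streams, which is what makes the reference trajectory in Particle Gibbs replayable.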

@yebai (Member) commented Mar 26, 2025

println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the Model
Note, this is only a problem when creating a reference trajectory,

@FredericWantiez can we store trace.rng inside TapedTask instead of trace? That way, when copying a TapedTask, we will copy the trace.rng.

@FredericWantiez (Member, Author)

@willtebbutt I think 2) might also be a problem for Turing, when looking at this part:
https://github.com/TuringLang/Turing.jl/blob/afb5c44d6dc1736831f45620328c9d5681748111/src/mcmc/particle_mcmc.jl#L140-L142

@FredericWantiez (Member, Author) commented Apr 5, 2025

Two small issues I found while cleaning up the tests.

Libtask returns a value after the last produce statement:

function f()
  Libtask.produce(1)
  Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1)  # 1
consume(t1)  # 2
consume(t1)  # 2 (?)
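For comparison, a Python generator (the closest built-in analogue of produce/consume) signals exhaustion explicitly rather than repeating its last value, which is arguably the behaviour one would expect here:

```python
# Generator analogue of the produce/consume pattern: yield plays produce,
# next plays consume, and exhaustion is signalled rather than repeated.
def f():
    yield 1  # produce(1)
    yield 2  # produce(2)

t1 = f()
print(next(t1))        # 1
print(next(t1))        # 2
print(next(t1, None))  # None -- the task is done, not "2" again
```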

Libtask doesn't catch some of the produce statements:

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
  a::Float64
  b::Float64

  NormalModel() = new()
end

function (m::NormalModel)()
  # First latent variable.
  rng = Libtask.get_dynamic_scope()
  m.a = a = rand(rng, Normal(4, 5))

  # First observation.
  AdvancedPS.observe(Normal(a, 2), 3)

  # Second latent variable.
  rng = Libtask.get_dynamic_scope()
  m.b = b = rand(rng, Normal(a, 1))

  # Second observation.
  AdvancedPS.observe(Normal(b, 2), 1.5)
  return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())

consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)

This works fine if I call Libtask.produce explicitly instead of observe.

EDIT: Changing observe to something like this seems to work:

function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end
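The fix makes observe a resumption point by handing the log-likelihood back to the scheduler as a produced value. A small Python sketch of that contract, with a generator `yield` playing the role of `Libtask.produce` (`normal_logpdf` is a hand-rolled helper here, not a library call):

```python
# Hypothetical sketch: each observation produces its log-weight so the
# scheduler can reweight particles at every observe site.
import math

def normal_logpdf(x, mu, sigma):
    # log-density of Normal(mu, sigma) at x
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def observe(logp):
    # In generator terms, "observe" simply produces the log-weight.
    return logp

def model():
    a = 4.0  # stands in for a sampled latent value
    yield observe(normal_logpdf(3.0, a, 2.0))  # first observation
    b = 3.5  # stands in for the second latent value
    yield observe(normal_logpdf(1.5, b, 2.0))  # second observation

# The scheduler consumes exactly one log-weight per observation.
weights = list(model())
```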

@yebai (Member) commented Apr 8, 2025

If we store both rng and varinfo in the scoped variable, then the following suggestions will address (2):

  • store varinfo in the Trace struct, then change here to Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))
  • change here and here to rng, varinfo = Libtask.get_dynamic_scope()
  • change here to transition = SMCTransition(model, particle.varinfo, weight)

@FredericWantiez (Member, Author) commented Apr 8, 2025

That should work. I have a branch against Turing that tries to do this, but it seems like one copy is not quite correct.

The other solution is to use one replay step before the transition, to repopulate the varinfo properly:

    new_particle = AdvancedPS.replay(particle)
    transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
    state = SMCState(particles, 2, logevidence)
    return transition, state

@FredericWantiez (Member, Author) commented Apr 8, 2025

@willtebbutt running models against this PR I see a large performance drop:

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

On master:

1.816623 seconds (5.92 M allocations: 311.647 MiB, 1.52% gc time, 96.09% compilation time)

On this PR:

72.085056 seconds (369.62 M allocations: 17.322 GiB, 2.83% gc time, 77.21% compilation time)

@willtebbutt (Member)

Thanks for the data point. Essentially the final item on my todo list is sorting out various type inference issues in the current implementation. Once they're done, we should see substantially improved performance.

@yebai (Member) commented Apr 9, 2025

That should work, I have a branch against Turing that tries to do this but seems like one copy is not quite correct.

The varinfo variable is updated during inference. I think we have to carefully ensure the correct varinfo is stored in the scoped variable.

cc @mhauru @FredericWantiez

@willtebbutt (Member)

@willtebbutt running models against this PR I see a large performance drop:

@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate your example on the current released versions of the packages, because I find that it errors. My environment is:

(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0

I tried it on LTS and 1.11.4.

In particular, I'm seeing the error:

ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.

Any idea whether I'm doing something wrong?

@willtebbutt (Member)

But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change though: you need to pass a type to Libtask.get_dynamic_scope, which should be the type of the thing that it's going to return. We need this because there's no way to make the container typed (I assume that the previous implementation had a similar limitation). The docstring has been updated to reflect the changes.

@FredericWantiez (Member, Author) commented Apr 15, 2025

@willtebbutt if you're testing against the released versions of Libtask/AdvancedPS, you need to explicitly pass the RNG in the model definition, something like this:

function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as argument
  model.sig = rand(rng, Beta(1, 1)) 
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end
