[WIP] New libtask interface #114
base: main
Conversation
Thanks for having a look at this!
Does this have any implications for integration with Turing.jl? I.e. does not passing an RNG to the model cause any trouble downstream? (To be clear, I have no idea -- I'm not suggesting that it does / doesn't in particular.)
I agree re not wanting to dig into […]. As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.
```julia
using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems

mutable struct Model <: AdvancedPS.AbstractGenericModel
    x::Float64
    y::Float64
    Model() = new()
end

function (model::Model)()
    rng = Libtask.get_dynamic_scope()
    model.x = rand(rng, Beta(1, 1))
    Libtask.produce(model.x)
    rng = Libtask.get_dynamic_scope()
    model.y = rand(rng, Normal(0, model.x))
    Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()
trace = AdvancedPS.Trace(model, rng)

# Sample `x`
AdvancedPS.advance!(trace)
trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)
Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)

println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`

# Note, this is only a problem when creating a reference trajectory;
# sampled values are properly captured during the execution of the task
```
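As a side note for readers, the key-splitting step in the example can be mimicked with only stdlib tools. The sketch below is hypothetical plain Julia, not AdvancedPS's implementation (which uses a counter-based Philox RNG): it derives independent child seeds from a parent key, which is the role `AdvancedPS.split` plays when decoupling particle streams.

```julia
using Random

# Hypothetical sketch (not AdvancedPS's implementation): derive n child
# seeds from one parent key, then give each particle its own RNG.
split_key(key::UInt64, n::Integer) = [hash((key, i)) for i in 1:n]

seeds = split_key(UInt64(10), 2)
# Each particle gets an RNG seeded from its own child key, so the two
# streams are decoupled from each other and from the parent.
rng1, rng2 = Xoshiro(seeds[1]), Xoshiro(seeds[2])
```

The point of deriving seeds from `(key, i)` pairs rather than reusing the parent stream is that forked particles must not replay each other's randomness.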
@FredericWantiez can we store […]?
@willtebbutt I think 2) might also be a problem for Turing, when looking at this part: […]
Two small issues I found cleaning up the tests.

Libtask returns a value after the last produce statement:

```julia
function f()
    Libtask.produce(1)
    Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1) # 1
consume(t1) # 2
consume(t1) # 2 (?)
```

Libtask doesn't catch some of the produce statements:

```julia
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
    return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())
consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)
```

This works fine if I call […]

EDIT: Changing […] to the following fixed it:

```julia
function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end
```
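For context, the exhaustion behaviour one would expect from the first issue can be illustrated with a plain-Julia `Channel`. This is only an analogy, not Libtask's implementation: once the producing body returns, further consumption fails rather than repeating the last value.

```julia
# Plain-Julia analogy (not Libtask): a producer that yields two values
# and then finishes. The channel closes when the body returns.
ch = Channel{Int}() do c
    put!(c, 1)
    put!(c, 2)
end

take!(ch)  # 1
take!(ch)  # 2
# The producer body has returned and the channel closes, so a third
# take! throws instead of returning 2 again.
```

Under this analogy, `consume` returning the final value twice corresponds to a channel that never signals closure to its consumers.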
If we store both […]
That should work. I have a branch against Turing that tries to do this, but it seems like one copy is not quite correct. The other solution is to use one […]:

```julia
new_particle = AdvancedPS.replay(particle)
transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
state = SMCState(particles, 2, logevidence)
return transition, state
```
@willtebbutt running models against this PR I see a large performance drop:

```julia
using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)
```

On master: […]

On this PR: […]
Thanks for the data point. Essentially the final item on my todo list is sorting out various type inference issues in the current implementation. Once they're done, we should see substantially improved performance.
The […]
@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate the performance of your example on the current versions of packages, because I find that it errors. My environment is:

```
(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0
```

I tried it on LTS and 1.11.4. In particular, I'm seeing the error:

```
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.
```

Any idea whether I'm doing something wrong?
But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change though: you need to pass a type to […]
@willtebbutt if you're testing against the released version of Libtask/AdvancedPS you need to explicitly pass the RNG in the model definition, something like this:

```julia
function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as argument
    model.sig = rand(rng, Beta(1, 1))
    Libtask.produce(model.sig)
    model.mu = rand(rng, Normal())
    Libtask.produce(model.mu)
end
```
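A quick standalone way to see why threading the RNG through as an argument matters. This is a hypothetical sketch using only the stdlib `Random` module (the `draw` helper is invented for illustration, it is not part of either package): a body that takes its RNG explicitly is reproducible given the same seed.

```julia
using Random

# Hypothetical helper: a "model body" that draws from an explicitly
# passed RNG, mirroring the released Libtask/AdvancedPS calling pattern.
draw(rng::Random.AbstractRNG) = (rand(rng), rand(rng))

# Same seed, same draws: the particle filter can replay a trajectory
# deterministically by reseeding and rerunning.
@assert draw(Xoshiro(10)) == draw(Xoshiro(10))
```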
Integrate refactor from TuringLang/Libtask.jl#179

Two things worth noting:

[…] and now: `tapedtask.fargs`

Embedded snippet: AdvancedPS.jl/ext/AdvancedPSLibtaskExt.jl, lines 89 to 91 in 50d493c

@willtebbutt