Description
This is currently what check_model does (on main, i.e., DynamicPPL 0.38):
DynamicPPL.jl/src/debug_utils.jl, lines 427 to 428 in 052bc19:

    # Force single-threaded execution.
    _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)

The problem is that evaluate_threadunsafe!! does not actually force single-threaded execution. Julia still runs with however many threads it was launched with.
Rather, evaluate_threadunsafe!! disables the use of TSVI (ThreadSafeVarInfo). This means that data races can occur when several threads push to the same accumulator at the same time.
For example, launch Julia with more than 1 thread and run the following. check_model will error intermittently, so it's not actually thread-safe at all.

    using DynamicPPL, Distributions
    @model function f()
        Threads.@threads for i in 1:100
            x ~ Normal()
        end
    end
    model = DynamicPPL.setleafcontext(f(), InitContext())
    check_model(model, VarInfo())

So, with DPPL 0.38 on the way out and TSVI being reworked a little bit (#1151), what happens?
Well, instead of calling evaluate_threadunsafe!! we will now call evaluate!!, which DOES use TSVI if it's necessary.
DynamicPPL.jl/src/debug_utils.jl, lines 427 to 430 in 9d0edc4:

    # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
    # check on the merged accumulator, rather than checking it in the accumulate_assume
    # calls. That way we can also correctly support multi-threaded evaluation.
    _, varinfo = DynamicPPL.evaluate!!(model, varinfo)

This will actually stop it from erroring because we will have one DebugAccumulator per thread, which avoids the data race. However, it presents a different problem.
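To illustrate why one accumulator per thread avoids the race, here is a minimal sketch in plain Julia. It does not use DynamicPPL, and `record_per_task` and its chunking scheme are hypothetical names made up for this example; this is not how TSVI is implemented. Each task pushes only to a vector it owns, and the vectors are merged once all tasks have finished.

```julia
# Hypothetical illustration, not DynamicPPL code: give every task its own
# accumulator so that concurrent pushes never touch a shared container.
function record_per_task(n_iterations::Int, n_tasks::Int)
    # Split 1:n_iterations into at most n_tasks contiguous chunks.
    chunk_size = cld(n_iterations, n_tasks)
    chunks = collect(Iterators.partition(1:n_iterations, chunk_size))
    tasks = map(chunks) do chunk
        Threads.@spawn begin
            local_acc = Symbol[]      # owned by this task only
            for _ in chunk
                push!(local_acc, :x)  # stands in for `x ~ Normal()`
            end
            local_acc
        end
    end
    # Merge only after every task has finished.
    return reduce(vcat, fetch.(tasks))
end

seen = record_per_task(100, 4)
```

Because no container is mutated by more than one task, all 100 records arrive in `seen` regardless of how many threads Julia was started with.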
Now consider this code, which is the same as above, but with the number of loop iterations reduced and with TSVI explicitly enabled via setthreadsafe. Again, make sure to run this with more than one thread.

    using DynamicPPL, Distributions
    @model function f()
        Threads.@threads for i in 1:2
            x ~ Normal()
        end
    end
    model = setleafcontext(setthreadsafe(f(), true), InitContext())
    check_model(model, VarInfo())

check_model will now claim that there is nothing wrong with the model, even though it is nonsensical to have x ~ Normal() multiple times. The reason is that the two x ~ Normal() statements are sent to two different DebugAccumulators, each of which individually thinks that it's OK.
If you run Julia with 1 thread, it will detect the error because there will then be only one DebugAccumulator, which sees x twice. In general, if the number of loop iterations exceeds the number of threads, it will detect the error (because of the pigeonhole principle: at least one accumulator must then see x more than once).
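The pigeonhole argument can be checked with a small simulation in plain Julia. This is a simplified sketch, not DynamicPPL code: `simulate` and `any_local_duplicate` are hypothetical names, and the assignment of iterations to threads assumes a static schedule with contiguous chunks, which the real scheduler does not guarantee.

```julia
# Hypothetical stand-in for a per-thread duplicate check; not DynamicPPL code.
# Returns true if any single accumulator saw the same variable name twice.
any_local_duplicate(accs) = any(acc -> !allunique(acc), accs)

# Distribute `n_iterations` records of `:x` over `n_threads` accumulators,
# mimicking a static schedule where each thread takes a contiguous chunk.
function simulate(n_iterations::Int, n_threads::Int)
    accs = [Symbol[] for _ in 1:n_threads]
    for i in 1:n_iterations
        thread = cld(i * n_threads, n_iterations)  # chunk index in 1:n_threads
        push!(accs[thread], :x)
    end
    return any_local_duplicate(accs)
end

simulate(2, 2)  # false: each accumulator sees `x` once, so nothing is flagged
simulate(2, 1)  # true: the single accumulator sees `x` twice
simulate(3, 2)  # true: three iterations cannot fit into two accumulators
```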
I think this is better behaviour than before (it's a failure to catch a mistake, rather than a bug in the code); but it's obviously still not correct. What should really happen is that there needs to be a combine method for DebugAccumulator, and the check for correctness should be done on the combined one rather than each individual one.
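As a rough sketch of that idea, the toy example below combines per-thread records before checking them. Everything in it is an assumption for illustration: `ToyDebugAcc`, `combine_accs`, and `has_duplicates` are hypothetical names and do not reflect the actual DebugAccumulator fields or any existing DynamicPPL method.

```julia
# Hypothetical sketch; these names are illustrative, not the DynamicPPL API.
struct ToyDebugAcc
    seen::Vector{Symbol}  # variable names recorded by one thread
end

# Combining two accumulators concatenates their records.
combine_accs(a::ToyDebugAcc, b::ToyDebugAcc) = ToyDebugAcc(vcat(a.seen, b.seen))

# The duplicate check runs on whichever accumulator it is given.
has_duplicates(acc::ToyDebugAcc) = !allunique(acc.seen)

# Two threads each saw `x` exactly once.
per_thread = [ToyDebugAcc([:x]), ToyDebugAcc([:x])]

# Checking each accumulator on its own misses the repeated `x` ...
locally_flagged = any(has_duplicates, per_thread)
# ... whereas checking the combined accumulator catches it.
globally_flagged = has_duplicates(reduce(combine_accs, per_thread))
```

Here `locally_flagged` is `false` and `globally_flagged` is `true`, which is the difference between the current per-accumulator check and a check on the combined accumulator.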