@@ -334,90 +334,53 @@ function AbstractMCMC.step(
334334end
335335
336336"""
337- get_trace_local_varinfo_maybe(vi::AbstractVarInfo )
337+ get_trace_local_varinfo( )
338338
339339Get the varinfo stored in the 'taped globals' of a `Libtask.TapedTask`, if one exists.
340- If this function is not called within a TapedTask, return the provided `varinfo`.
341340"""
342- function get_trace_local_varinfo_maybe (varinfo:: AbstractVarInfo )
343- trace = try
344- Libtask. get_taped_globals (Any). other
345- catch e
346- e == KeyError (:task_variable ) ? nothing : rethrow (e)
347- end
348- has_trace = trace != = nothing
349- return (has_trace ? trace. model. f. varinfo : varinfo):: AbstractVarInfo
341+ function get_trace_local_varinfo ()
342+ trace = Libtask. get_taped_globals (Any). other
343+ return trace. model. f. varinfo:: AbstractVarInfo
350344end
351345
352346"""
353- get_trace_local_resampled_maybe(fallback_resampled::Bool )
347+ get_trace_local_resampled( )
354348
355- Get the Boolean `resample` value stored in the 'taped globals' of a `Libtask.TapedTask`, if
356- one exists. If this function is not called within a TapedTask, return the provided
357- `fallback_resampled` value.
349+ Get the Boolean `resample` value stored in the 'taped globals' of a `Libtask.TapedTask`.
358350"""
359- function get_trace_local_resampled_maybe (fallback_resampled:: Bool )
360- trace = try
361- Libtask. get_taped_globals (Any). other
362- catch e
363- e == KeyError (:task_variable ) ? nothing : rethrow (e)
364- end
365- has_trace = trace != = nothing
366- return (has_trace ? trace. model. f. resample : fallback_resampled):: Bool
351+ function get_trace_local_resampled ()
352+ trace = Libtask. get_taped_globals (Any). other
353+ return trace. model. f. resample:: Bool
367354end
368355
369356"""
370- get_trace_local_rng_maybe(rng::Random.AbstractRNG )
357+ get_trace_local_rng( )
371358
372- Get the `Trace` local rng if one exists.
373-
374- If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
375- task, otherwise return `vi`.
359+ Get the RNG stored in the 'taped globals' of a `Libtask.TapedTask`, if one exists.
376360"""
377- function get_trace_local_rng_maybe (rng:: Random.AbstractRNG )
378- return try
379- Libtask. get_taped_globals (Any). rng
380- catch e
381- e == KeyError (:task_variable ) ? rng : rethrow (e)
382- end
361+ function get_trace_local_rng ()
362+ return Libtask. get_taped_globals (Any). rng
383363end
384364
385365"""
386- set_trace_local_varinfo_maybe (vi::AbstractVarInfo)
366+ set_trace_local_varinfo (vi::AbstractVarInfo)
387367
388368Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
389369
390370If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
391371task. Otherwise do nothing.
392372"""
393- function set_trace_local_varinfo_maybe (vi:: AbstractVarInfo )
394- # TODO (mhauru) This should be done in a try-catch block, as in the commented out code.
395- # However, Libtask currently can't handle this block.
396- trace = # try
397- Libtask. get_taped_globals (Any). other
398- # catch e
399- # e == KeyError(:task_variable) ? nothing : rethrow(e)
400- # end
401- if trace != = nothing
402- # trace isa AdvancedPS.Trace (defined in AdvancedPS src/model.jl)
403- # trace.model isa AdvancedPS.LibtaskModel (defined in AdvancedPSLibtaskExt)
404- # trace.model.f isa TracedModel (defined above)
405- trace. model. f. varinfo = vi
406- end
407- return nothing
373+ function set_trace_local_varinfo (vi:: AbstractVarInfo )
374+ trace = Libtask. get_taped_globals (Any). other
375+ return trace. model. f. varinfo = vi
408376end
409377
410378function DynamicPPL. tilde_assume!! (
411- ctx:: ParticleMCMCContext ,
412- dist:: Distribution ,
413- vn:: VarName ,
414- template:: Any ,
415- vi:: AbstractVarInfo ,
379+ :: ParticleMCMCContext , dist:: Distribution , vn:: VarName , template:: Any , :: AbstractVarInfo
416380)
417- vi = get_trace_local_varinfo_maybe (vi)
418-
419- trng = get_trace_local_rng_maybe (ctx. rng)
420- resample = get_trace_local_resampled_maybe (true )
381+ vi = get_trace_local_varinfo ()
382+ trng = get_trace_local_rng ()
383+ resample = get_trace_local_resampled ()
421384
422385 dispatch_ctx = if ~ haskey (vi, vn) || resample
423386 DynamicPPL. InitContext (trng, DynamicPPL. InitFromPrior ())
@@ -426,7 +389,7 @@ function DynamicPPL.tilde_assume!!(
426389 end
427390 x, vi = DynamicPPL. tilde_assume!! (dispatch_ctx, dist, vn, template, vi)
428391
429- set_trace_local_varinfo_maybe (vi)
392+ set_trace_local_varinfo (vi)
430393 return x, vi
431394end
432395
@@ -437,9 +400,9 @@ function DynamicPPL.tilde_observe!!(
437400 vn:: Union{VarName,Nothing} ,
438401 vi:: AbstractVarInfo ,
439402)
440- vi = get_trace_local_varinfo_maybe (vi )
403+ vi = get_trace_local_varinfo ( )
441404 left, vi = DynamicPPL. tilde_observe!! (DefaultContext (), right, left, vn, vi)
442- set_trace_local_varinfo_maybe (vi)
405+ set_trace_local_varinfo (vi)
443406 return left, vi
444407end
445408
0 commit comments