@@ -76,6 +76,46 @@ function getcontext(f::LogDensityFunction)
76
76
return f. context === nothing ? leafcontext (f. model. context) : f. context
77
77
end
78
78
79
+ """
80
+ getmodel(f)
81
+
82
+ Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
83
+ """
84
+ getmodel (f:: LogDensityProblemsAD.ADGradientWrapper ) =
85
+ getmodel (LogDensityProblemsAD. parent (f))
86
+ getmodel (f:: DynamicPPL.LogDensityFunction ) = f. model
87
+
88
+ """
89
+ setmodel(f, model[, adtype])
90
+
91
+ Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
92
+
93
+ !!! warning
94
+ Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
95
+ `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
96
+ might require recompilation of the gradient tape, depending on the AD backend.
97
+ """
98
+ function setmodel (
99
+ f:: LogDensityProblemsAD.ADGradientWrapper ,
100
+ model:: DynamicPPL.Model ,
101
+ adtype:: ADTypes.AbstractADType ,
102
+ )
103
+ # TODO : Should we handle `SciMLBase.NoAD`?
104
+ # For an `ADGradientWrapper` we do the following:
105
+ # 1. Update the `Model` in the underlying `LogDensityFunction`.
106
+ # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
107
+ # to ensure that the recompilation of gradient tapes, etc. also occur. For example,
108
+ # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
109
+ # replacing the corresponding field with the new model won't be sufficient to obtain
110
+ # the correct gradients.
111
+ return LogDensityProblemsAD. ADgradient (
112
+ adtype, setmodel (LogDensityProblemsAD. parent (f), model)
113
+ )
114
+ end
115
+ function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
116
+ return Accessors. @set f. model = model
117
+ end
118
+
79
119
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
80
120
# we need to define these annoying methods to ensure that we stay compatible with everything.
81
121
getsampler (f:: LogDensityFunction ) = getsampler (getcontext (f))
0 commit comments