Skip to content

Commit dc74e85

Browse files
seabbs-botseabbs
andcommitted
fix: use minimal fit for pre-compilation instead of empty = TRUE
update() doesn't work on empty brmsfits, so use a minimal fit (chains = 1, iter = 5) for pre-compilation instead. Co-authored-by: Sam Abbott <contact@samabbott.co.uk>
1 parent df856a3 commit dc74e85

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

vignettes/approx-inference.Rmd

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,17 @@ data <- as_epidist_marginal_model(linelist_data)
154154
155155
# Pre-compile the model so compilation time is not included
156156
# in the timing comparisons
157-
fit_empty <- epidist(
158-
data = data, backend = "cmdstanr", empty = TRUE
157+
fit_compile <- epidist(
158+
data = data, backend = "cmdstanr",
159+
chains = 1, iter = 5
159160
)
160161
```
161162

162163
We now perform inference with HMC:
163164

164165
```{r results='hide'}
165166
t <- proc.time()
166-
fit_hmc <- update(fit_empty, chains = 4, iter = 2000)
167+
fit_hmc <- update(fit_compile, chains = 4, iter = 2000)
167168
time_hmc <- proc.time() - t
168169
```
169170

@@ -175,13 +176,13 @@ To match the four Markov chains of length 1000 in HMC above, we then draw 4000 s
175176
```{r results='hide'}
176177
t <- proc.time()
177178
fit_laplace <- update(
178-
fit_empty, algorithm = "laplace", draws = 4000
179+
fit_compile, algorithm = "laplace", draws = 4000
179180
)
180181
time_laplace <- proc.time() - t
181182
182183
t <- proc.time()
183184
fit_advi <- update(
184-
fit_empty, algorithm = "meanfield", draws = 4000
185+
fit_compile, algorithm = "meanfield", draws = 4000
185186
)
186187
time_advi <- proc.time() - t
187188
```
@@ -191,7 +192,7 @@ For the Pathfinder algorithm we will set `num_paths = 1`.
191192
```{r results='hide'}
192193
t <- proc.time()
193194
fit_pathfinder <- update(
194-
fit_empty, algorithm = "pathfinder", draws = 4000,
195+
fit_compile, algorithm = "pathfinder", draws = 4000,
195196
chains = 1
196197
)
197198
time_pathfinder <- proc.time() - t

0 commit comments

Comments
 (0)