Skip to content

Commit 6234599

Browse files
committed
Explain a parameter option and update blackjax dependency
1 parent 47bf969 commit 6234599

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

probdiffeq/probdiffeq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ def make_model(s):
260260
return ssm.conditional.to_derivative(tcoeff_index, s)
261261

262262
model = func.vmap(make_model)(std)
263+
264+
# Use solve_triu=lstsq because for noise-free observations, the initial state
265+
# of the ODE solution tends to be noise-free,
266+
# which clashes and returns NaNs if we use exact solvers.
263267
return posterior.evaluate_lml(
264268
u, model=model, ssm=ssm, average_pdfs=average_pdfs, solve_triu=solve_triu
265269
)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ doc = [
4848
"tueplots",
4949
"tqdm",
5050
"optax",
51-
"blackjax>=1.0.0",
51+
"blackjax[progress]>=1.0.0",
5252
"diffrax",
5353
"numba",
5454
# mkdocs 2.0 is a backwards-incompatible rewrite
@@ -141,7 +141,7 @@ ignore = [
141141
"D105",
142142
# Some backend's names shadow builtins, eg backend.abc. Ignore the warnings.
143143
"A005",
144-
# Line-too-long is handled by the formatter, not the linter
144+
# Line-too-long is handled by the formatter, not the linter
145145
"E501",
146146
]
147147

0 commit comments

Comments
 (0)