Skip to content

Commit 8d47718

Browse files
fix: pin PyMC version below 5.20.1 to avoid TypeError (closes #1397) (#1697)
* fix: pin PyMC version below 5.20.1 to avoid TypeError (closes #1397) * fix: Disable PyMC progress bar to resolve TypeError * Add note and warning and update tests
1 parent a1f0343 commit 8d47718

File tree

7 files changed

+21
-63
lines changed

7 files changed

+21
-63
lines changed

sbi/inference/posteriors/mcmc_posterior.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,16 @@ def _pymc_mcmc(
908908
num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains
909909
steps = dict(slice_pymc="slice", hmc_pymc="hmc", nuts_pymc="nuts")
910910

911+
if show_progress_bars:
912+
warn(
913+
"Note: progress bars for PyMC sampling are disabled due to an "
914+
"incompatibility with PyMC>=5.20.1; "
915+
"progressbar will be set to False. See PR #1697: "
916+
"https://github.com/sbi-dev/sbi/pull/1697",
917+
UserWarning,
918+
stacklevel=2,
919+
)
920+
911921
sampler = PyMCSampler(
912922
potential_fn=potential_function,
913923
step=steps[mcmc_method],
@@ -916,7 +926,7 @@ def _pymc_mcmc(
916926
tune=warmup_steps,
917927
chains=num_chains,
918928
mp_ctx=mp_context,
919-
progressbar=show_progress_bars,
929+
progressbar=False,
920930
param_name=self.param_name,
921931
device=self._device,
922932
)

sbi/samplers/mcmc/pymc_wrapper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
tune: int = 1000,
113113
chains: Optional[int] = None,
114114
mp_ctx: str = "spawn",
115-
progressbar: bool = True,
115+
progressbar: bool = False,
116116
param_name: str = "theta",
117117
device: str = "cpu",
118118
):
@@ -127,6 +127,10 @@ def __init__(
127127
chains: Number of MCMC chains to run in parallel.
128128
mp_ctx: Multiprocessing context for parallel sampling.
129129
progressbar: Whether to show/hide progress bars.
130+
Note: Progress bars are disabled for PyMC sampling in this
131+
project due to an incompatibility from PyMC >= 5.20.1 and
132+
that affects progress display.
133+
See PR #1697 for details: https://github.com/sbi-dev/sbi/pull/1697
130134
param_name: Name for parameter variable, for PyMC and ArviZ structures
131135
device: The device to which to move the parameters for potential_fn.
132136
"""

tests/inference_on_device_test.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33

44
from __future__ import annotations
55

6-
import sys
76
from dataclasses import asdict
87
from typing import Tuple, Union
98

10-
import pymc
119
import pytest
1210
import torch
1311
import torch.distributions.transforms as torch_tf
@@ -76,19 +74,7 @@
7674
pytest.param(NRE_B, "resnet", "slice_np", marks=pytest.mark.mcmc),
7775
(NRE_C, "resnet", "rejection"),
7876
(NRE_C, "resnet", "importance"),
79-
pytest.param(
80-
NRE_C,
81-
"resnet",
82-
"nuts_pymc",
83-
marks=(
84-
pytest.mark.mcmc,
85-
pytest.mark.skipif(
86-
condition=sys.version_info >= (3, 10)
87-
and pymc.__version__ >= "5.20.1",
88-
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
89-
),
90-
),
91-
),
77+
pytest.param(NRE_C, "resnet", "nuts_pymc", marks=pytest.mark.mcmc),
9278
],
9379
)
9480
@pytest.mark.parametrize(

tests/linearGaussian_snle_test.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
from __future__ import annotations
55

6-
import sys
76
from dataclasses import asdict
87

9-
import pymc
108
import pytest
119
import torch
1210
from torch import eye, ones, zeros
@@ -397,18 +395,7 @@ def simulator(theta):
397395
pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
398396
pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
399397
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
400-
pytest.param(
401-
"nuts_pymc",
402-
"gaussian",
403-
marks=(
404-
pytest.mark.mcmc,
405-
pytest.mark.skipif(
406-
condition=sys.version_info >= (3, 10)
407-
and pymc.__version__ >= "5.20.1",
408-
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
409-
),
410-
),
411-
),
398+
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
412399
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
413400
pytest.param("hmc_pymc", "gaussian", marks=pytest.mark.mcmc),
414401
("rejection", "uniform"),

tests/linearGaussian_snre_test.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
from __future__ import annotations
55

6-
import sys
76
from dataclasses import asdict
87

9-
import pymc
108
import pytest
119
from torch import eye, ones, zeros
1210
from torch.distributions import MultivariateNormal
@@ -319,18 +317,7 @@ def simulator(theta):
319317
pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
320318
pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
321319
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
322-
pytest.param(
323-
"nuts_pymc",
324-
"gaussian",
325-
marks=(
326-
pytest.mark.mcmc,
327-
pytest.mark.skipif(
328-
condition=sys.version_info >= (3, 10)
329-
and pymc.__version__ >= "5.20.1",
330-
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
331-
),
332-
),
333-
),
320+
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
334321
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
335322
pytest.param("hmc_pyro", "gaussian", marks=pytest.mark.mcmc),
336323
("rejection", "uniform"),

tests/mcmc_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33

44
from __future__ import annotations
55

6-
import sys
76
from dataclasses import asdict
87

98
import numpy as np
10-
import pymc
119
import pytest
1210
import torch
1311
from torch import eye, ones, zeros
@@ -193,13 +191,7 @@ def lp_f(x, track_gradients=True):
193191
(
194192
"nuts_pyro",
195193
"hmc_pyro",
196-
pytest.param(
197-
"nuts_pymc",
198-
marks=pytest.mark.skipif(
199-
condition=sys.version_info >= (3, 10) and pymc.__version__ >= "5.20.1",
200-
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
201-
),
202-
),
194+
"nuts_pymc",
203195
"hmc_pymc",
204196
"slice_pymc",
205197
"slice_np",

tests/posterior_sampler_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
from __future__ import annotations
55

6-
import sys
76
from dataclasses import asdict
87

9-
import pymc
108
import pytest
119
from pyro.infer.mcmc import MCMC
1210
from torch import Tensor, eye, zeros
@@ -30,13 +28,7 @@
3028
"slice_np_vectorized",
3129
"nuts_pyro",
3230
"hmc_pyro",
33-
pytest.param(
34-
"nuts_pymc",
35-
marks=pytest.mark.skipif(
36-
condition=sys.version_info >= (3, 10) and pymc.__version__ >= "5.20.1",
37-
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
38-
),
39-
),
31+
"nuts_pymc",
4032
"hmc_pymc",
4133
"slice_pymc",
4234
),

0 commit comments

Comments
 (0)