Skip to content

Commit a72d750

Browse files
build: drop py3.9, fix deps, fix types, lc2st tests (#1412)
* chore: drop Python 3.9 support, update workflows and docs. * lint: checks and linting with py3.10. * remove uv lock, update docs action * fix pymc x-failing tests * fix python 3.13 induced errors * fix torch deps and plot pyright * Fix circular import test * test more python versions * add numfocus badge and credits in readme * remove testmon temporarily * add testmon adain * test python version in CD * change locally passing but ci failing test * turn off fast fail * remote test still failing, increas n --------- Co-authored-by: manuelgloeckler <manu.gloeckler@hotmail.de>
1 parent 69b7f38 commit a72d750

28 files changed

+108
-87
lines changed

.github/workflows/build_docs.yml

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ on:
55
- main
66
release:
77
types: [ published ]
8+
workflow_dispatch:
89

910
jobs:
1011
docs:
@@ -17,31 +18,24 @@ jobs:
1718
fetch-depth: 0
1819
lfs: false
1920

20-
- name: Set up Python
21-
uses: actions/setup-python@v2
21+
- name: Install uv and set the python version
22+
uses: astral-sh/setup-uv@v5
2223
with:
2324
python-version: '3.10'
25+
enable-cache: true
26+
cache-dependency-glob: "pyproject.toml"
2427

25-
- name: Cache dependency
26-
id: cache-dependencies
27-
uses: actions/cache@v4
28-
with:
29-
path: ~/.cache/pip
30-
key: ${{ runner.os }}-pip
31-
32-
- name: Install sbi and dependencies
33-
run: |
34-
python -m pip install --upgrade pip
35-
python -m pip install .[doc]
28+
- name: Install dependencies with uv
29+
run: uv sync --all-extras --doc
3630

3731
- name: strip output except plots and prints from tutorial notebooks
3832
run: |
39-
python tests/strip_notebook_outputs.py tutorials/
33+
uv run python tests/strip_notebook_outputs.py tutorials/
4034
4135
- name: convert notebooks to markdown
4236
run: |
4337
cd docs
44-
jupyter nbconvert --to markdown ../tutorials/*.ipynb --output-dir docs/tutorials/
38+
uv run jupyter nbconvert --to markdown ../tutorials/*.ipynb --output-dir docs/tutorials/
4539
4640
- name: Configure Git user for bot
4741
run: |
@@ -52,10 +46,10 @@ jobs:
5246
if: ${{ github.event_name == 'push' }}
5347
run: |
5448
cd docs
55-
mike deploy dev --push
49+
uv run mike deploy dev --push
5650
5751
- name: Build and deploy the lastest documentation upon new release
5852
if: ${{ github.event_name == 'release' }}
5953
run: |
6054
cd docs
61-
mike deploy ${{ github.event.release.name }} latest -u --push
55+
uv run mike deploy ${{ github.event.release.name }} latest -u --push

.github/workflows/cd.yml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ jobs:
1717
cd:
1818
name: Continuous Deployment
1919
runs-on: ubuntu-latest
20+
strategy:
21+
fail-fast: false
22+
matrix:
23+
python-version: ['3.10', '3.11', '3.12', '3.13']
2024

2125
steps:
2226
- name: Checkout
@@ -28,16 +32,15 @@ jobs:
2832
- name: Install uv and set the python version
2933
uses: astral-sh/setup-uv@v5
3034
with:
31-
python-version: '3.9'
35+
python-version: ${{ matrix.python-version }}
3236
enable-cache: true
37+
cache-dependency-glob: "pyproject.toml"
3338

3439
- name: Install dependencies with uv
35-
run: |
36-
uv pip install -e .[dev]
40+
run: uv sync --all-extras --dev
3741

3842
- name: Run the fast and the slow CPU tests with coverage
39-
run: |
40-
uv run pytest -v -x -n auto -m "not gpu" --cov=sbi --cov-report=xml tests/
43+
run: uv run pytest -v -x -n auto -m "not gpu" --cov=sbi --cov-report=xml tests/
4144

4245
- name: Upload coverage to Codecov
4346
uses: codecov/codecov-action@v4-beta

.github/workflows/ci.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Tests
1+
name: Continuous Integration
22

33
on: [pull_request, workflow_dispatch]
44

@@ -12,15 +12,15 @@ concurrency:
1212

1313
jobs:
1414
ci:
15-
name: CI
15+
name: Continuous Integration
1616
runs-on: ubuntu-latest
1717
if: |
1818
github.event_name == 'push' ||
1919
(github.event_name == 'pull_request' && github.event.pull_request.draft == false)
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ['3.9', '3.12']
23+
python-version: ['3.10']
2424

2525
steps:
2626
- name: Checkout
@@ -34,6 +34,7 @@ jobs:
3434
with:
3535
python-version: ${{ matrix.python-version }}
3636
enable-cache: true
37+
cache-dependency-glob: "pyproject.toml"
3738

3839
- name: Install dependencies with uv
3940
run: |
@@ -56,7 +57,7 @@ jobs:
5657
restore-keys: |
5758
testmon-${{ runner.os }}-${{ matrix.python-version }}-
5859
59-
- name: Run the fast CPU tests with coverage
60+
- name: Run fast CPU tests with coverage
6061
run: uv run pytest --testmon-forceselect -v -n auto -m "not slow and not gpu" --cov=sbi --cov-report=xml tests/
6162

6263
- name: Upload coverage to Codecov

.github/workflows/lint.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ jobs:
2222
- uses: actions/checkout@v4
2323
- uses: actions/setup-python@v5
2424
with:
25-
python-version: '3.9'
25+
python-version: '3.10'
2626
- uses: pre-commit/action@v3.0.1
2727
with:
2828
extra_args: --all-files --show-diff-on-failure
2929

3030
pyright:
31-
name: Check types
31+
name: type checking.
3232
runs-on: ubuntu-latest
3333
steps:
3434
- name: Checkout
@@ -40,14 +40,12 @@ jobs:
4040
- name: Install uv and set the python version
4141
uses: astral-sh/setup-uv@v5
4242
with:
43-
python-version: '3.9'
43+
python-version: '3.10'
4444
enable-cache: true
45+
cache-dependency-glob: "pyproject.toml"
4546

4647
- name: Install dependencies with uv
47-
run: |
48-
uv pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
49-
uv pip install -e .[dev]
48+
run: uv sync --all-extras --dev
5049

5150
- name: Check types with pyright
52-
run: |
53-
uv run pyright sbi
51+
run: uv run pyright sbi

.github/workflows/publish.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Set up Python
1616
uses: actions/setup-python@v5
1717
with:
18-
python-version: "3.9"
18+
python-version: "3.10"
1919
- name: Install pypa/build
2020
run: >-
2121
python3 -m
@@ -25,7 +25,7 @@ jobs:
2525
- name: Build a binary wheel and a source tarball
2626
run: python3 -m build
2727
- name: Store the distribution packages
28-
uses: actions/upload-artifact@v3
28+
uses: actions/upload-artifact@v4
2929
with:
3030
name: python-package-distributions
3131
path: dist/
@@ -45,7 +45,7 @@ jobs:
4545

4646
steps:
4747
- name: Download all the dists
48-
uses: actions/download-artifact@v3
48+
uses: actions/download-artifact@v4
4949
with:
5050
name: python-package-distributions
5151
path: dist/
@@ -66,7 +66,7 @@ jobs:
6666

6767
steps:
6868
- name: Download all the dists
69-
uses: actions/download-artifact@v3
69+
uses: actions/download-artifact@v4
7070
with:
7171
name: python-package-distributions
7272
path: dist/

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
[![codecov](https://codecov.io/gh/sbi-dev/sbi/branch/main/graph/badge.svg)](https://codecov.io/gh/sbi-dev/sbi)
66
[![GitHub license](https://img.shields.io/github/license/sbi-dev/sbi)](https://github.com/sbi-dev/sbi/blob/master/LICENSE.txt)
77
[![DOI](https://joss.theoj.org/papers/10.21105/joss.02505/status.svg)](https://doi.org/10.21105/joss.02505)
8+
[![NumFOCUS affiliated](https://camo.githubusercontent.com/a0f197cee66ccd8ed498cf64e9f3f384c78a072fe1e65bada8d3015356ac7599/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f4e756d464f4355532d616666696c696174656425323070726f6a6563742d6f72616e67652e7376673f7374796c653d666c617426636f6c6f72413d45313532334426636f6c6f72423d303037443841)]
89

910
## `sbi`: Simulation-Based Inference
1011

@@ -59,14 +60,14 @@ posterior = inference.build_posterior()
5960

6061
### Installation
6162

62-
`sbi` requires Python 3.9 or higher. While a GPU isn't necessary, it can improve
63+
`sbi` requires Python 3.10 or higher. While a GPU isn't necessary, it can improve
6364
performance in some cases. We recommend using a virtual environment with
6465
[`conda`](https://docs.conda.io/en/latest/miniconda.html) for an easy setup.
6566

6667
If `conda` is installed on the system, an environment for installing `sbi` can be created as follows:
6768

6869
```
69-
conda create -n sbi_env python=3.9 && conda activate sbi_env
70+
conda create -n sbi_env python=3.10 && conda activate sbi_env
7071
```
7172

7273
### From PyPI
@@ -205,7 +206,7 @@ Durkan's `lfi`. `sbi` runs as a community project. See also
205206
`sbi` has been supported by the German Federal Ministry of Education and Research (BMBF)
206207
through project ADIMEM (FKZ 01IS18052 A-D), project SiMaLeSAM (FKZ 01IS21055A) and the
207208
Tübingen AI Center (FKZ 01IS18039A). Since 2024, `sbi` is supported by the appliedAI
208-
Institute for Europe.
209+
Institute for Europe, and by NumFOCUS.
209210

210211
## License
211212

docs/docs/install.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Installation
22

3-
`sbi` requires Python 3.9 or higher. A GPU is not required, but can lead to
3+
`sbi` requires Python 3.10 or higher. A GPU is not required, but can lead to
44
speed-up in some cases. We recommend using a
55
[`conda`](https://docs.conda.io/en/latest/miniconda.html) virtual environment
66
([Miniconda installation

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
"Programming Language :: Python :: 3",
2525
"Development Status :: 3 - Alpha",
2626
]
27-
requires-python = ">=3.9"
27+
requires-python = ">=3.10"
2828
dynamic = ["version"]
2929
readme = "README.md"
3030
keywords = ["Bayesian inference", "simulation-based inference", "PyTorch"]
@@ -33,12 +33,12 @@ dependencies = [
3333
"joblib>=1.0.0",
3434
"matplotlib",
3535
"notebook <= 6.4.12",
36-
"numpy<2.0.0",
36+
"numpy",
3737
"pillow",
3838
"pyknos>=0.16.0",
3939
"pyro-ppl>=1.3.1",
4040
"scikit-learn",
41-
"scipy<1.13",
41+
"scipy",
4242
"tensorboard",
4343
"torch>=1.13.0, <2.6.0",
4444
"tqdm",
@@ -140,7 +140,7 @@ xfail_strict = true
140140
[tool.pyright]
141141
include = ["sbi"]
142142
exclude = ["**/__pycache__", "**/__node_modules__", ".git", "docs", "tutorials", "tests"]
143-
python_version = "3.9"
143+
python_version = "3.10"
144144
reportUnsupportedDunderAll = false
145145
reportGeneralTypeIssues = false
146146
reportInvalidTypeForm = false

sbi/analysis/plot.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,18 @@ def plt_hist_2d(
124124
ax: Axes,
125125
samples_col: np.ndarray,
126126
samples_row: np.ndarray,
127-
limits_col: torch.Tensor,
128-
limits_row: torch.Tensor,
127+
limits_col: Union[torch.Tensor, List],
128+
limits_row: Union[torch.Tensor, List],
129129
offdiag_kwargs: Dict,
130130
):
131-
hist_kwargs = copy.deepcopy(offdiag_kwargs)
132131
"""Plot 2D histogram."""
132+
hist_kwargs = copy.deepcopy(offdiag_kwargs)
133+
134+
if isinstance(limits_col, torch.Tensor):
135+
limits_col = limits_col.tolist()
136+
if isinstance(limits_row, torch.Tensor):
137+
limits_row = limits_row.tolist()
138+
133139
if (
134140
"bins" not in hist_kwargs["np_hist_kwargs"]
135141
or hist_kwargs["np_hist_kwargs"]["bins"] is None
@@ -787,7 +793,9 @@ def pairplot(
787793
diag_kwargs_list = to_list_kwargs(diag_kwargs, len(samples))
788794
diag_func = get_diag_funcs(diag_list)
789795
diag_kwargs_filled = []
790-
for i, (diag_i, diag_kwargs_i) in enumerate(zip(diag_list, diag_kwargs_list)):
796+
for i, (diag_i, diag_kwargs_i) in enumerate(
797+
zip(diag_list, diag_kwargs_list, strict=False)
798+
):
791799
diag_kwarg_filled_i = _get_default_diag_kwargs(diag_i, i)
792800
# update the defaults dictionary with user provided values
793801
diag_kwarg_filled_i = _update(diag_kwarg_filled_i, diag_kwargs_i)
@@ -798,7 +806,9 @@ def pairplot(
798806
upper_kwargs_list = to_list_kwargs(upper_kwargs, len(samples))
799807
upper_func = get_offdiag_funcs(upper_list)
800808
upper_kwargs_filled = []
801-
for i, (upper_i, upper_kwargs_i) in enumerate(zip(upper_list, upper_kwargs_list)):
809+
for i, (upper_i, upper_kwargs_i) in enumerate(
810+
zip(upper_list, upper_kwargs_list, strict=False)
811+
):
802812
upper_kwarg_filled_i = _get_default_offdiag_kwargs(upper_i, i)
803813
# update the defaults dictionary with user provided values
804814
upper_kwarg_filled_i = _update(upper_kwarg_filled_i, upper_kwargs_i)
@@ -809,7 +819,9 @@ def pairplot(
809819
lower_kwargs_list = to_list_kwargs(lower_kwargs, len(samples))
810820
lower_func = get_offdiag_funcs(lower_list)
811821
lower_kwargs_filled = []
812-
for i, (lower_i, lower_kwargs_i) in enumerate(zip(lower_list, lower_kwargs_list)):
822+
for i, (lower_i, lower_kwargs_i) in enumerate(
823+
zip(lower_list, lower_kwargs_list, strict=False)
824+
):
813825
lower_kwarg_filled_i = _get_default_offdiag_kwargs(lower_i, i)
814826
# update the defaults dictionary with user provided values
815827
lower_kwarg_filled_i = _update(lower_kwarg_filled_i, lower_kwargs_i)
@@ -910,7 +922,9 @@ def marginal_plot(
910922
diag_kwargs_list = to_list_kwargs(diag_kwargs, len(samples))
911923
diag_func = get_diag_funcs(diag_list)
912924
diag_kwargs_filled = []
913-
for i, (diag_i, diag_kwargs_i) in enumerate(zip(diag_list, diag_kwargs_list)):
925+
for i, (diag_i, diag_kwargs_i) in enumerate(
926+
zip(diag_list, diag_kwargs_list, strict=False)
927+
):
914928
diag_kwarg_filled_i = _get_default_diag_kwargs(diag_i, i)
915929
diag_kwarg_filled_i = _update(diag_kwarg_filled_i, diag_kwargs_i)
916930
diag_kwargs_filled.append(diag_kwarg_filled_i)
@@ -2031,7 +2045,7 @@ def marginal_plot_with_probs_intensity(
20312045
# normalize color intensity
20322046
norm = Normalize(vmin=vmin, vmax=vmax)
20332047
# set color intensity
2034-
for w, p in zip(weights, patches):
2048+
for w, p in zip(weights, patches, strict=False):
20352049
p.set_facecolor(cmap(w))
20362050
if show_colorbar:
20372051
plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax_, label=label)

sbi/diagnostics/sbc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _run_sbc(
138138
ranks = torch.zeros((num_sbc_samples, len(reduce_fns)))
139139
# Iterate over all sbc samples and calculate ranks.
140140
for sbc_idx, (true_theta, x_i) in tqdm(
141-
enumerate(zip(thetas, xs)),
141+
enumerate(zip(thetas, xs, strict=False)),
142142
total=num_sbc_samples,
143143
disable=not show_progress_bar,
144144
desc=f"Calculating ranks for {num_sbc_samples} sbc samples.",
@@ -188,7 +188,7 @@ def get_nltp(thetas: Tensor, xs: Tensor, posterior: NeuralPosterior) -> Tensor:
188188
nltp = torch.zeros(thetas.shape[0])
189189
unnormalized_log_prob = not isinstance(posterior, (DirectPosterior, ScorePosterior))
190190

191-
for idx, (tho, xo) in enumerate(zip(thetas, xs)):
191+
for idx, (tho, xo) in enumerate(zip(thetas, xs, strict=False)):
192192
# Log prob of true params under posterior.
193193
if unnormalized_log_prob:
194194
nltp[idx] = -posterior.potential(tho, x=xo)
@@ -266,7 +266,7 @@ def check_prior_vs_dap(prior_samples: Tensor, dap_samples: Tensor) -> Tensor:
266266

267267
return torch.tensor([
268268
c2st(s1.unsqueeze(1), s2.unsqueeze(1))
269-
for s1, s2 in zip(prior_samples.T, dap_samples.T)
269+
for s1, s2 in zip(prior_samples.T, dap_samples.T, strict=False)
270270
])
271271

272272

0 commit comments

Comments (0)