Description
Please make sure these conditions are met
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest version of pertpy.
- (optional) I have confirmed this bug exists on the main branch.
Report
There is a significant performance regression in `make_arviz` going from `pertpy` version `1.0.0` to `1.0.1`. This is apparently mostly due to changes in one or more dependencies, rather than changes to `make_arviz` itself, but I haven't had time to do a more thorough investigation and pin it down...
Running the following scCODA example multiple times on the same machine:

    import pertpy as pt
    import time

    haber_cells = pt.dt.haber_2017_regions()
    sccoda_model = pt.tl.Sccoda()
    sccoda_data = sccoda_model.load(
        haber_cells,
        type="cell_level",
        generate_sample_level=True,
        cell_type_identifier="cell_label",
        sample_identifier="batch",
        covariate_obs=["condition"],
    )
    sccoda_data = sccoda_model.prepare(
        sccoda_data,
        modality_key="coda",
        formula="condition",
        reference_cell_type="Endocrine",
    )
    sccoda_model.run_nuts(sccoda_data, modality_key="coda", rng_key=42)

    t0 = time.perf_counter()
    sccoda_arviz = sccoda_model.make_arviz(sccoda_data, modality_key="coda")
    t1 = time.perf_counter()
    elapsed = t1 - t0

I get ~5s with version `1.0.0`, but ~55s with version `1.0.1`.
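Since the numbers above come from repeated runs, a small helper can make the repetition explicit and summarise the spread. This is only a sketch: `time_repeated` and `summarize` are hypothetical names, not part of pertpy, and they use only the standard library.

```python
import statistics
import time


def time_repeated(fn, repeats=3):
    """Call fn() `repeats` times and return each wall-clock duration in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations):
    """Reduce a list of durations to min / median / max."""
    return {
        "min": min(durations),
        "median": statistics.median(durations),
        "max": max(durations),
    }
```

With the objects from the example above, this could be called as `summarize(time_repeated(lambda: sccoda_model.make_arviz(sccoda_data, modality_key="coda")))`.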
With the following tascCODA example:

    import pertpy as pt
    import time

    smillie_counts = pt.dt.smillie_2019()
    tasccoda_model = pt.tl.Tasccoda()
    smillie_data = tasccoda_model.load(
        smillie_counts,
        type="cell_level",
        cell_type_identifier="Cluster",
        sample_identifier=["Subject", "Sample"],
        covariate_obs=["Location", "Health"],
        levels_orig=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
        add_level_name=True,
    )
    smillie_data.mod["coda_LP"] = smillie_data["coda"][
        (smillie_data["coda"].obs["Health"].isin(["Healthy", "Non-inflamed"]))
        & (smillie_data["coda"].obs["Location"] == "LP")
    ].copy()
    smillie_data = tasccoda_model.prepare(
        smillie_data,
        modality_key="coda_LP",
        tree_key="tree",
        reference_cell_type="automatic",
        formula="Health",
        pen_args={"phi": 0},
    )
    tasccoda_model.run_nuts(smillie_data, modality_key="coda_LP", rng_key=42)

    t0 = time.perf_counter()
    tasccoda_arviz = tasccoda_model.make_arviz(smillie_data, modality_key="coda_LP")
    t1 = time.perf_counter()
    elapsed = t1 - t0

I get ~4.5min with version `1.0.0`, but with version `1.0.1` I gave up when it didn't finish in one hour.
I poked around the code a bit, and at first it looked like the slow step was `az.from_numpyro(...)`, but when I added

    prior = jax.tree_util.tree_map(lambda x: x.block_until_ready(), prior)
    posterior_predictive = jax.tree_util.tree_map(lambda x: x.block_until_ready(), posterior_predictive)

for testing purposes right before that call, it turned out that these two blocking calls were the bottleneck, and `az.from_numpyro` still finishes in a flash afterwards. So the time appears to be spent computing `prior` and `posterior_predictive`, not in the ArviZ conversion...
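The experiment above amounts to timing "wait for the samples" and "convert the samples" as separate stages, so that deferred work is not charged to whichever call happens to consume it first. A minimal sketch of that idea follows. `stage`, `block_all` and `Deferred` are hypothetical names; `Deferred` is a toy stand-in for a lazily evaluated array, not a JAX type, and exists only so the snippet runs without JAX installed.

```python
import time
from contextlib import contextmanager


@contextmanager
def stage(name, timings):
    """Record the wall-clock duration of the enclosed block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = time.perf_counter() - start


def block_all(samples):
    """Wait on every value in a dict that exposes block_until_ready()."""
    for value in samples.values():
        wait = getattr(value, "block_until_ready", None)
        if callable(wait):
            wait()
    return samples


class Deferred:
    """Toy stand-in for a lazily evaluated array (NOT a JAX type)."""

    def __init__(self, seconds):
        self._seconds = seconds
        self._ready = False

    def block_until_ready(self):
        if not self._ready:
            time.sleep(self._seconds)  # simulate the pending computation
            self._ready = True
        return self


timings = {}
with stage("create", timings):
    prior = {"alpha": Deferred(0.05)}  # returns immediately, work is pending
with stage("materialize", timings):
    block_all(prior)  # the deferred cost shows up here
with stage("convert", timings):
    block_all(prior)  # a later consumer finds the values already computed

for name, seconds in timings.items():
    print(f"{name}: {seconds:.3f}s")
```

Since `block_all` only relies on a `block_until_ready()` method, the same pattern should apply to a dict of JAX arrays; whether `prior` and `posterior_predictive` in `make_arviz` are plain dicts of such arrays is an assumption here.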
To replicate the testing environments:
- for version `1.0.1`, I did:

      micromamba env create -n pertpy-1.0.1 python=3.12.10
      micromamba activate pertpy-1.0.1
      uv pip install "pertpy[tcoda]" ipywidgets ipykernel

- for version `1.0.0`, I did:

      micromamba env create -n pertpy-1.0.0 python=3.12.10
      micromamba activate pertpy-1.0.0
      uv pip install --exclude-newer 2025-06-03 "pertpy[tcoda]" ipywidgets ipykernel

  If I just installed `pertpy[tcoda]==1.0.0` without `--exclude-newer`, the environment ended up broken and the examples wouldn't run, so some newer dependency versions are apparently not compatible with the older `pertpy`.
Versions
Collated from `session_info2` outputs for the two environments.
| Package | Version (1.0.0 env) | Version (1.0.1 env) |
| ---------- | ------------------- | ------------------- |
| anndata | 0.11.4 | 0.12.1 |
| arviz | 0.21.0 | 0.22.0 |
| ipykernel | 6.29.5 | 6.30.0 |
| ipywidgets | 8.1.7 | 8.1.7 |
| mudata | 0.3.2 | 0.3.2 |
| numpy | 2.2.6 | 2.2.6 |
| pandas | 2.2.3 | 2.3.1 |
| pertpy | 1.0.0 | 1.0.1 |
| torch | 2.7.0 (2.7.0+cu126) | 2.7.1 (2.7.1+cu126) |
| xarray | 2025.4.0 | 2025.7.1 |

| Dependency | Version (1.0.0 env) | Version (1.0.1 env) |
| ------------------------ | ---------------------- | --------------------- |
| absl-py | 2.3.0 | 2.3.1 |
| adjustText | 1.3.0 | 1.3.0 |
| arrow | 1.3.0 | N/A |
| asttokens | 3.0.0 | 3.0.0 |
| attrs | 25.3.0 | 25.3.0 |
| blitzgsea | 1.3.54 | 1.3.54 |
| bottle | 0.13.3 | N/A |
| Brotli | 1.1.0 | 1.1.0 |
| certifi | 2025.4.26 (2025.04.26) | 2025.8.3 (2025.08.03) |
| chardet | 5.2.0 | N/A |
| charset-normalizer | 3.4.2 | 3.4.2 |
| chex | 0.1.89 | 0.1.90 |
| comm | 0.2.2 | 0.2.3 |
| custom-inherit | 2.4.1 | N/A |
| crc32c | N/A | 2.7.1 |
| cycler | 0.12.1 | 0.12.1 |
| debugpy | 1.8.14 | 1.8.15 |
| decorator | 5.2.1 | 5.2.1 |
| docrep | 0.3.2 | 0.3.2 |
| donfig | N/A | 0.8.1.post1 |
| equinox | 0.12.2 | 0.13.0 |
| ete4 | 4.3.0 | N/A |
| etils | 1.12.2 | 1.13.0 |
| executing | 2.2.0 | 2.2.0 |
| fast-array-utils | 1.2.1 | 1.2.1 |
| filelock | 3.18.0 | 3.18.0 |
| flax | 0.10.6 | 0.11.0 |
| fsspec | 2025.5.1 | 2025.7.0 |
| h5py | 3.13.0 | 3.14.0 |
| idna | 3.10 | 3.10 |
| importlib_resources | 6.5.2 | 6.5.2 |
| ipython | 9.3.0 | 9.4.0 |
| jax | 0.6.1 | 0.6.2 |
| jaxlib | 0.6.1 | 0.6.2 |
| jaxopt | 0.8.5 | 0.8.5 |
| jaxtyping | 0.3.2 | 0.3.2 |
| jedi | 0.19.2 | 0.19.2 |
| joblib | 1.5.1 | 1.5.1 |
| jupyter_client | 8.6.3 | 8.6.3 |
| jupyter_core | 5.8.1 | 5.8.1 |
| kiwisolver | 1.4.8 | 1.4.8 |
| lamin_utils | 0.14.0 | 0.15.0 |
| legacy-api-wrap | 1.4.1 | 1.4.1 |
| lightning | 2.5.1.post0 | 2.5.2 |
| lightning-utilities | 0.14.3 | 0.15.0 |
| lineax | 0.0.8 | 0.0.8 |
| llvmlite | 0.44.0 | 0.44.0 |
| loguru | 0.7.3 | N/A |
| matplotlib | 3.10.3 | 3.10.5 |
| ml_collections | 1.1.0 | 1.1.0 |
| ml_dtypes | 0.5.1 | 0.5.3 |
| mpmath | 1.3.0 | 1.3.0 |
| msgpack | 1.1.0 | 1.1.1 |
| multipledispatch | 1.0.0 (0.6.0) | 1.0.0 (0.6.0) |
| natsort | 8.4.0 | 8.4.0 |
| networkx | 3.5 | 3.5 |
| numba | 0.61.2 | 0.61.2 |
| numcodecs | N/A | 0.16.1 |
| numpyro | 0.18.0 | 0.18.0 |
| nvidia-cublas-cu12 | 12.6.4.1 | 12.6.4.1 |
| nvidia-cuda-cupti-cu12 | 12.6.80 | 12.6.80 |
| nvidia-cuda-nvrtc-cu12 | 12.6.77 | 12.6.77 |
| nvidia-cuda-runtime-cu12 | 12.6.77 | 12.6.77 |
| nvidia-cudnn-cu12 | 9.5.1.17 | 9.5.1.17 |
| nvidia-cufft-cu12 | 11.3.0.4 | 11.3.0.4 |
| nvidia-cufile-cu12 | 1.11.1.6 | 1.11.1.6 |
| nvidia-curand-cu12 | 10.3.7.77 | 10.3.7.77 |
| nvidia-cusolver-cu12 | 11.7.1.2 | 11.7.1.2 |
| nvidia-cusparse-cu12 | 12.5.4.2 | 12.5.4.2 |
| nvidia-nccl-cu12 | 2.26.2 | 2.26.2 |
| nvidia-nvjitlink-cu12 | 12.6.85 | 12.6.85 |
| nvidia-nvtx-cu12 | 12.6.77 | 12.6.77 |
| opt_einsum | 3.4.0 | 3.4.0 |
| optax | 0.2.4 | 0.2.5 |
| ott-jax | 0.5.0 | 0.5.1 |
| packaging | 24.2 | 25.0 |
| parso | 0.8.4 | 0.8.4 |
| patsy | 1.0.1 | 1.0.1 |
| pillow | 11.2.1 | 11.3.0 |
| platformdirs | 4.3.8 | 4.3.8 |
| ply | 3.11 | 3.11 |
| prompt_toolkit | 3.0.51 | 3.0.51 |
| psutil | 7.0.0 | 7.0.0 |
| PubChemPy | 1.0.4 | 1.0.4 |
| pure_eval | 0.2.3 | 0.2.3 |
| pyarrow | 20.0.0 | 21.0.0 |
| Pygments | 2.19.1 | 2.19.2 |
| pyomo | 6.9.2 | 6.9.2 |
| pyparsing | 3.2.3 | 3.2.3 |
| pypng | 0.20220715.0 | N/A |
| PyQt6 | 6.9.0 | N/A |
| PyQt6_sip | 13.10.2 | N/A |
| PyQt6-Qt6 | 6.9.1 | N/A |
| pyro-ppl | 1.9.1 | 1.9.1 |
| python-dateutil | 2.9.0.post0 | 2.9.0.post0 |
| pytorch-lightning | 2.5.1.post0 | 2.5.2 |
| pytz | 2025.2 | 2025.2 |
| PyYAML | 6.0.2 | 6.0.2 |
| pyzmq | 26.4.0 | 27.0.1 |
| reportlab | 4.4.1 | N/A |
| requests | 2.32.3 | 2.32.4 |
| rich | 14.0.0 | 14.1.0 |
| scanpy | 1.11.2 | 1.11.4 |
| scikit-learn | 1.6.1 | 1.7.1 |
| scikit-misc | 0.5.1 | 0.5.1 |
| scipy | 1.15.3 | 1.16.1 |
| scvi-tools | 1.3.1.post1 | 1.3.3 |
| seaborn | 0.13.2 | 0.13.2 |
| session-info2 | 0.1.2 | 0.2 |
| setuptools | 80.9.0 | 80.9.0 |
| simplejson | 3.20.1 | 3.20.1 |
| six | 1.17.0 | 1.17.0 |
| sparse | 0.17.0 | 0.17.0 |
| sparsecca | 0.3.1 | 0.3.1 |
| stack-data | 0.6.3 | 0.6.3 |
| statsmodels | 0.14.4 | 0.14.5 |
| sympy | 1.14.0 | 1.14.0 |
| threadpoolctl | 3.6.0 | 3.6.0 |
| toolz | 1.0.0 | 1.0.0 |
| torchmetrics | 1.7.2 | 1.8.0 |
| tornado | 6.5.1 | 6.5.1 |
| toyplot | 2.0.0 | N/A |
| toytree | 3.0.10 | N/A |
| tqdm | 4.67.1 | 4.67.1 |
| traitlets | 5.14.3 | 5.14.3 |
| triton | 3.3.0 | 3.3.1 |
| typing_extensions | 4.14.0 | 4.14.1 |
| urllib3 | 2.4.0 | 2.5.0 |
| wadler_lindig | 0.1.6 | 0.1.7 |
| wcwidth | 0.2.13 | 0.2.13 |
| xarray-einstats | 0.9.0 | 0.9.1 |
| zarr | N/A | 3.1.1 |

| Component | Info |
| --------- | ------------------------------------------------------------------------------ |
| Python | 3.12.10 \| packaged by conda-forge \| (main, Apr 10 2025, 22:21:13) [GCC 13.3.0] |
| OS | Linux-4.18.0-425.19.2.el8_7.x86_64-x86_64-with-glibc2.28 |
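To narrow down which dependency is responsible, it may help to reduce the tables above to only the packages whose versions differ between the two environments. A rough sketch, using a hand-copied subset of the rows above; `changed_packages` is a hypothetical helper, not part of `session_info2` or pertpy.

```python
def changed_packages(old, new):
    """Return {name: (old_version, new_version)} for entries that differ."""
    names = sorted(set(old) | set(new), key=str.lower)
    return {
        name: (old.get(name, "N/A"), new.get(name, "N/A"))
        for name in names
        if old.get(name, "N/A") != new.get(name, "N/A")
    }


# Subset of the version tables above, copied by hand.
env_100 = {
    "arviz": "0.21.0",
    "jax": "0.6.1",
    "jaxlib": "0.6.1",
    "numpyro": "0.18.0",
    "xarray": "2025.4.0",
}
env_101 = {
    "arviz": "0.22.0",
    "jax": "0.6.2",
    "jaxlib": "0.6.2",
    "numpyro": "0.18.0",
    "xarray": "2025.7.1",
    "zarr": "3.1.1",
}

for name, (before, after) in changed_packages(env_100, env_101).items():
    print(f"{name}: {before} -> {after}")
```

Packages with identical versions in both environments (for example `numpyro` at 0.18.0) drop out, leaving a shorter list of candidates to pin one at a time.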