
(ta)scCODA make_arviz performance regression in v.1.0.1 #833

@jpintar

Description


Please make sure these conditions are met

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of pertpy.
  • (optional) I have confirmed this bug exists on the main branch.

Report

There is a significant performance regression in `make_arviz` between pertpy versions 1.0.0 and 1.0.1. It appears to be caused mostly by changes in one or more dependencies rather than by changes to `make_arviz` itself, but I haven't had time for a more thorough investigation to pin it down.

Running the following scCODA example multiple times on the same machine:

```python
import pertpy as pt
import time

haber_cells = pt.dt.haber_2017_regions()

sccoda_model = pt.tl.Sccoda()
sccoda_data = sccoda_model.load(
    haber_cells,
    type="cell_level",
    generate_sample_level=True,
    cell_type_identifier="cell_label",
    sample_identifier="batch",
    covariate_obs=["condition"],
)
sccoda_data = sccoda_model.prepare(
    sccoda_data,
    modality_key="coda",
    formula="condition",
    reference_cell_type="Endocrine",
)
sccoda_model.run_nuts(sccoda_data, modality_key="coda", rng_key=42)

# Time only the conversion to an ArviZ InferenceData object
t0 = time.perf_counter()
sccoda_arviz = sccoda_model.make_arviz(sccoda_data, modality_key="coda")
t1 = time.perf_counter()

elapsed = t1 - t0
print(f"make_arviz took {elapsed:.1f} s")
```

I get ~5 s with version 1.0.0, but ~55 s with version 1.0.1 (roughly 11× slower).
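Since these timings come from running the example several times, a small helper can make the repetition explicit and report a spread rather than a single number. The sketch below is hypothetical and dependency-free (`time_repeated` and `summarize` are illustrative names, not part of pertpy); it just wraps `time.perf_counter()` around each call, as in the snippet above.

```python
import statistics
import time
from typing import Callable


def time_repeated(fn: Callable[[], object], repeats: int = 3) -> list[float]:
    """Call fn() `repeats` times and return each wall-clock duration in seconds."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    durations = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - t0)
    return durations


def summarize(durations: list[float]) -> dict[str, float]:
    """Reduce a list of durations to min / median / max."""
    return {
        "min": min(durations),
        "median": statistics.median(durations),
        "max": max(durations),
    }


# Hypothetical usage with the objects from the scCODA example:
# durations = time_repeated(
#     lambda: sccoda_model.make_arviz(sccoda_data, modality_key="coda"), repeats=3
# )
# print(summarize(durations))
```

Reporting the minimum next to the median may help separate one-off costs (for example, compilation on a first call, if any) from the steady-state time.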

With the following tascCODA example:

```python
import pertpy as pt
import time

smillie_counts = pt.dt.smillie_2019()

tasccoda_model = pt.tl.Tasccoda()
smillie_data = tasccoda_model.load(
    smillie_counts,
    type="cell_level",
    cell_type_identifier="Cluster",
    sample_identifier=["Subject", "Sample"],
    covariate_obs=["Location", "Health"],
    levels_orig=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
    add_level_name=True,
)
# Keep only healthy / non-inflamed lamina propria (LP) samples
smillie_data.mod["coda_LP"] = smillie_data["coda"][
    (smillie_data["coda"].obs["Health"].isin(["Healthy", "Non-inflamed"]))
    & (smillie_data["coda"].obs["Location"] == "LP")
].copy()
smillie_data = tasccoda_model.prepare(
    smillie_data,
    modality_key="coda_LP",
    tree_key="tree",
    reference_cell_type="automatic",
    formula="Health",
    pen_args={"phi": 0},
)
tasccoda_model.run_nuts(smillie_data, modality_key="coda_LP", rng_key=42)

# Time only the conversion to an ArviZ InferenceData object
t0 = time.perf_counter()
tasccoda_arviz = tasccoda_model.make_arviz(smillie_data, modality_key="coda_LP")
t1 = time.perf_counter()

elapsed = t1 - t0
print(f"make_arviz took {elapsed:.1f} s")
```

I get ~4.5 min with version 1.0.0, but with version 1.0.1 I gave up after it had not finished within one hour.

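I poked around the code a bit. At first it looked like the slow step was `az.from_numpyro(...)`, but for testing purposes I added the following right before that call:

```python
# Force the pending JAX computations to finish before az.from_numpyro(...)
prior = jax.tree_util.tree_map(lambda x: x.block_until_ready(), prior)
posterior_predictive = jax.tree_util.tree_map(lambda x: x.block_until_ready(), posterior_predictive)
```

With these lines in place, they became the bottleneck, and `az.from_numpyro` still finishes in a flash afterwards. Because JAX dispatches computations asynchronously, this suggests the time is actually spent computing the `prior` and `posterior_predictive` samples; it only showed up in `az.from_numpyro` because that was the first call to wait on the results.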
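The lesson from the `block_until_ready` experiment is that, with asynchronous dispatch, a timer has to wait for the results before it stops, or the cost gets attributed to whichever later call first consumes them. A minimal sketch of a per-stage timer with an explicit synchronization hook follows; `timed` is a hypothetical helper, not part of pertpy or JAX, and it assumes the caller supplies a `sync` callable such as `lambda: jax.block_until_ready(prior)` (`jax.block_until_ready` accepts a pytree).

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, Optional


@contextmanager
def timed(
    label: str,
    timings: dict[str, float],
    sync: Optional[Callable[[], object]] = None,
) -> Iterator[None]:
    """Store the wall-clock duration of the `with` body in timings[label].

    `sync` (hypothetical hook) runs after the body but before the clock
    stops, so asynchronously dispatched work is included in the measurement.
    """
    t0 = time.perf_counter()
    try:
        yield
        if sync is not None:
            sync()
    finally:
        # Recorded even if the body (or sync) raises.
        timings[label] = time.perf_counter() - t0


# Hypothetical usage inside make_arviz, around whatever code produces `prior`:
# timings = {}
# with timed("prior", timings, sync=lambda: jax.block_until_ready(prior)):
#     ...  # existing code that computes `prior`
```

Without `sync`, the same timer would mostly measure dispatch, which is how the slowdown initially appeared to be inside `az.from_numpyro`.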

To replicate the testing environments:

  • for version 1.0.1, I did:
    ```shell
    micromamba env create -n pertpy-1.0.1 python=3.12.10
    micromamba activate pertpy-1.0.1
    uv pip install "pertpy[tcoda]" ipywidgets ipykernel
    ```
  • for version 1.0.0, I did:
    ```shell
    micromamba env create -n pertpy-1.0.0 python=3.12.10
    micromamba activate pertpy-1.0.0
    uv pip install --exclude-newer 2025-06-03 "pertpy[tcoda]" ipywidgets ipykernel
    ```
    Installing just `"pertpy[tcoda]==1.0.0"` without `--exclude-newer` produced a broken environment in which the examples would not run, so some newer dependency versions are apparently incompatible with the older pertpy.
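To confirm which versions actually ended up in each environment, the key packages can be queried from inside that environment with the standard library's `importlib.metadata`. This is a minimal sketch with a hypothetical helper name (`installed_versions`); `session_info2`, used for the tables below, reports the same information more completely.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_versions(names: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # A few of the packages listed in the Versions tables.
    for name, ver in installed_versions(
        ["pertpy", "jax", "jaxlib", "numpyro", "arviz", "xarray"]
    ).items():
        print(f"{name}: {ver or 'N/A'}")
```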

Versions

Collated from `session_info2` outputs for the two environments.

| Package    | Version (1.0.0 env) | Version (1.0.1 env) |
| ---------- | ------------------- | ------------------- |
| anndata    | 0.11.4              | 0.12.1              |
| arviz      | 0.21.0              | 0.22.0              |
| ipykernel  | 6.29.5              | 6.30.0              |
| ipywidgets | 8.1.7               | 8.1.7               |
| mudata     | 0.3.2               | 0.3.2               |
| numpy      | 2.2.6               | 2.2.6               |
| pandas     | 2.2.3               | 2.3.1               |
| pertpy     | 1.0.0               | 1.0.1               |
| torch      | 2.7.0 (2.7.0+cu126) | 2.7.1 (2.7.1+cu126) |
| xarray     | 2025.4.0            | 2025.7.1            |

| Dependency               | Version (1.0.0 env)    | Version (1.0.1 env)   |
| ------------------------ | ---------------------- | --------------------- |
| absl-py                  | 2.3.0                  | 2.3.1                 |
| adjustText               | 1.3.0                  | 1.3.0                 |
| arrow                    | 1.3.0                  | N/A                   |
| asttokens                | 3.0.0                  | 3.0.0                 |
| attrs                    | 25.3.0                 | 25.3.0                |
| blitzgsea                | 1.3.54                 | 1.3.54                |
| bottle                   | 0.13.3                 | N/A                   |
| Brotli                   | 1.1.0                  | 1.1.0                 |
| certifi                  | 2025.4.26 (2025.04.26) | 2025.8.3 (2025.08.03) |
| chardet                  | 5.2.0                  | N/A                   |
| charset-normalizer       | 3.4.2                  | 3.4.2                 |
| chex                     | 0.1.89                 | 0.1.90                |
| comm                     | 0.2.2                  | 0.2.3                 |
| custom-inherit           | 2.4.1                  | N/A                   |
| crc32c                   | N/A                    | 2.7.1                 |
| cycler                   | 0.12.1                 | 0.12.1                |
| debugpy                  | 1.8.14                 | 1.8.15                |
| decorator                | 5.2.1                  | 5.2.1                 |
| docrep                   | 0.3.2                  | 0.3.2                 |
| donfig                   | N/A                    | 0.8.1.post1           |
| equinox                  | 0.12.2                 | 0.13.0                |
| ete4                     | 4.3.0                  | N/A                   |
| etils                    | 1.12.2                 | 1.13.0                |
| executing                | 2.2.0                  | 2.2.0                 |
| fast-array-utils         | 1.2.1                  | 1.2.1                 |
| filelock                 | 3.18.0                 | 3.18.0                |
| flax                     | 0.10.6                 | 0.11.0                |
| fsspec                   | 2025.5.1               | 2025.7.0              |
| h5py                     | 3.13.0                 | 3.14.0                |
| idna                     | 3.10                   | 3.10                  |
| importlib_resources      | 6.5.2                  | 6.5.2                 |
| ipython                  | 9.3.0                  | 9.4.0                 |
| jax                      | 0.6.1                  | 0.6.2                 |
| jaxlib                   | 0.6.1                  | 0.6.2                 |
| jaxopt                   | 0.8.5                  | 0.8.5                 |
| jaxtyping                | 0.3.2                  | 0.3.2                 |
| jedi                     | 0.19.2                 | 0.19.2                |
| joblib                   | 1.5.1                  | 1.5.1                 |
| jupyter_client           | 8.6.3                  | 8.6.3                 |
| jupyter_core             | 5.8.1                  | 5.8.1                 |
| kiwisolver               | 1.4.8                  | 1.4.8                 |
| lamin_utils              | 0.14.0                 | 0.15.0                |
| legacy-api-wrap          | 1.4.1                  | 1.4.1                 |
| lightning                | 2.5.1.post0            | 2.5.2                 |
| lightning-utilities      | 0.14.3                 | 0.15.0                |
| lineax                   | 0.0.8                  | 0.0.8                 |
| llvmlite                 | 0.44.0                 | 0.44.0                |
| loguru                   | 0.7.3                  | N/A                   |
| matplotlib               | 3.10.3                 | 3.10.5                |
| ml_collections           | 1.1.0                  | 1.1.0                 |
| ml_dtypes                | 0.5.1                  | 0.5.3                 |
| mpmath                   | 1.3.0                  | 1.3.0                 |
| msgpack                  | 1.1.0                  | 1.1.1                 |
| multipledispatch         | 1.0.0 (0.6.0)          | 1.0.0 (0.6.0)         |
| natsort                  | 8.4.0                  | 8.4.0                 |
| networkx                 | 3.5                    | 3.5                   |
| numba                    | 0.61.2                 | 0.61.2                |
| numcodecs                | N/A                    | 0.16.1                |
| numpyro                  | 0.18.0                 | 0.18.0                |
| nvidia-cublas-cu12       | 12.6.4.1               | 12.6.4.1              |
| nvidia-cuda-cupti-cu12   | 12.6.80                | 12.6.80               |
| nvidia-cuda-nvrtc-cu12   | 12.6.77                | 12.6.77               |
| nvidia-cuda-runtime-cu12 | 12.6.77                | 12.6.77               |
| nvidia-cudnn-cu12        | 9.5.1.17               | 9.5.1.17              |
| nvidia-cufft-cu12        | 11.3.0.4               | 11.3.0.4              |
| nvidia-cufile-cu12       | 1.11.1.6               | 1.11.1.6              |
| nvidia-curand-cu12       | 10.3.7.77              | 10.3.7.77             |
| nvidia-cusolver-cu12     | 11.7.1.2               | 11.7.1.2              |
| nvidia-cusparse-cu12     | 12.5.4.2               | 12.5.4.2              |
| nvidia-nccl-cu12         | 2.26.2                 | 2.26.2                |
| nvidia-nvjitlink-cu12    | 12.6.85                | 12.6.85               |
| nvidia-nvtx-cu12         | 12.6.77                | 12.6.77               |
| opt_einsum               | 3.4.0                  | 3.4.0                 |
| optax                    | 0.2.4                  | 0.2.5                 |
| ott-jax                  | 0.5.0                  | 0.5.1                 |
| packaging                | 24.2                   | 25.0                  |
| parso                    | 0.8.4                  | 0.8.4                 |
| patsy                    | 1.0.1                  | 1.0.1                 |
| pillow                   | 11.2.1                 | 11.3.0                |
| platformdirs             | 4.3.8                  | 4.3.8                 |
| ply                      | 3.11                   | 3.11                  |
| prompt_toolkit           | 3.0.51                 | 3.0.51                |
| psutil                   | 7.0.0                  | 7.0.0                 |
| PubChemPy                | 1.0.4                  | 1.0.4                 |
| pure_eval                | 0.2.3                  | 0.2.3                 |
| pyarrow                  | 20.0.0                 | 21.0.0                |
| Pygments                 | 2.19.1                 | 2.19.2                |
| pyomo                    | 6.9.2                  | 6.9.2                 |
| pyparsing                | 3.2.3                  | 3.2.3                 |
| pypng                    | 0.20220715.0           | N/A                   |
| PyQt6                    | 6.9.0                  | N/A                   |
| PyQt6_sip                | 13.10.2                | N/A                   |
| PyQt6-Qt6                | 6.9.1                  | N/A                   |
| pyro-ppl                 | 1.9.1                  | 1.9.1                 |
| python-dateutil          | 2.9.0.post0            | 2.9.0.post0           |
| pytorch-lightning        | 2.5.1.post0            | 2.5.2                 |
| pytz                     | 2025.2                 | 2025.2                |
| PyYAML                   | 6.0.2                  | 6.0.2                 |
| pyzmq                    | 26.4.0                 | 27.0.1                |
| reportlab                | 4.4.1                  | N/A                   |
| requests                 | 2.32.3                 | 2.32.4                |
| rich                     | 14.0.0                 | 14.1.0                |
| scanpy                   | 1.11.2                 | 1.11.4                |
| scikit-learn             | 1.6.1                  | 1.7.1                 |
| scikit-misc              | 0.5.1                  | 0.5.1                 |
| scipy                    | 1.15.3                 | 1.16.1                |
| scvi-tools               | 1.3.1.post1            | 1.3.3                 |
| seaborn                  | 0.13.2                 | 0.13.2                |
| session-info2            | 0.1.2                  | 0.2                   |
| setuptools               | 80.9.0                 | 80.9.0                |
| simplejson               | 3.20.1                 | 3.20.1                |
| six                      | 1.17.0                 | 1.17.0                |
| sparse                   | 0.17.0                 | 0.17.0                |
| sparsecca                | 0.3.1                  | 0.3.1                 |
| stack-data               | 0.6.3                  | 0.6.3                 |
| statsmodels              | 0.14.4                 | 0.14.5                |
| sympy                    | 1.14.0                 | 1.14.0                |
| threadpoolctl            | 3.6.0                  | 3.6.0                 |
| toolz                    | 1.0.0                  | 1.0.0                 |
| torchmetrics             | 1.7.2                  | 1.8.0                 |
| tornado                  | 6.5.1                  | 6.5.1                 |
| toyplot                  | 2.0.0                  | N/A                   |
| toytree                  | 3.0.10                 | N/A                   |
| tqdm                     | 4.67.1                 | 4.67.1                |
| traitlets                | 5.14.3                 | 5.14.3                |
| triton                   | 3.3.0                  | 3.3.1                 |
| typing_extensions        | 4.14.0                 | 4.14.1                |
| urllib3                  | 2.4.0                  | 2.5.0                 |
| wadler_lindig            | 0.1.6                  | 0.1.7                 |
| wcwidth                  | 0.2.13                 | 0.2.13                |
| xarray-einstats          | 0.9.0                  | 0.9.1                 |
| zarr                     | N/A                    | 3.1.1                 |

| Component | Info                                                                           |
| --------- | ------------------------------------------------------------------------------ |
| Python    | 3.12.10 \| packaged by conda-forge \| (main, Apr 10 2025, 22:21:13) [GCC 13.3.0] |
| OS        | Linux-4.18.0-425.19.2.el8_7.x86_64-x86_64-with-glibc2.28                       |
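To narrow down candidate culprits, the two version columns can be diffed mechanically instead of by eye. The sketch below is hypothetical (`diff_versions` is an illustrative name) and treats the `N/A` placeholder from the tables as "not installed"; the example rows are copied from the tables above. For instance, it surfaces that `jax`/`jaxlib` moved from 0.6.1 to 0.6.2 while `numpyro` stayed at 0.18.0.

```python
def diff_versions(
    old: dict[str, str], new: dict[str, str], missing: str = "N/A"
) -> tuple[dict[str, tuple[str, str]], list[str], list[str]]:
    """Compare two {package: version} mappings.

    Entries equal to `missing` count as absent. Returns (changed, added,
    removed): `changed` maps package -> (old, new) in sorted key order;
    `added` and `removed` are sorted package names.
    """
    old = {k: v for k, v in old.items() if v != missing}
    new = {k: v for k, v in new.items() if v != missing}
    changed = {
        k: (old[k], new[k])
        for k in sorted(old.keys() & new.keys())
        if old[k] != new[k]
    }
    added = sorted(new.keys() - old.keys())
    removed = sorted(old.keys() - new.keys())
    return changed, added, removed


if __name__ == "__main__":
    # A few rows copied from the tables above.
    env_100 = {"arviz": "0.21.0", "jax": "0.6.1", "jaxlib": "0.6.1",
               "numpyro": "0.18.0", "xarray": "2025.4.0",
               "ete4": "4.3.0", "zarr": "N/A"}
    env_101 = {"arviz": "0.22.0", "jax": "0.6.2", "jaxlib": "0.6.2",
               "numpyro": "0.18.0", "xarray": "2025.7.1",
               "ete4": "N/A", "zarr": "3.1.1"}
    changed, added, removed = diff_versions(env_100, env_101)
    for name, (a, b) in changed.items():
        print(f"{name}: {a} -> {b}")
    print("added:", added)
    print("removed:", removed)
```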

Labels: bug, triage