Bambi + HSGP + nutpie crashes during arviz conversion. #982

@juanitorduz

Summary

Fitting a Bambi model whose formula contains an hsgp(...) term with
inference_method="nutpie" raises:

KeyError: ' m=10'

at the very end of sampling, while nutpie is converting its arrow-backed
draws into an arviz InferenceData. The default PyMC sampler
(inference_method="pymc") works fine on the same model.

Root cause (best guess from the traceback)

Bambi names the HSGP weights variable after the formula call literally:

hsgp(x, m=10, c=1.5)_weights

That name contains commas. nutpie's _arrow_to_arviz passes a list/tuple
of dim names to arviz's from_dict, which eventually reaches
numpy_to_data_array. Somewhere in that chain the variable name appears
to be split on commas, producing strings like ' m=10' and
' c=1.5)_weights' that are then looked up in the coords dict. They
aren't there, hence:

File ".../arviz/data/base.py", line 305, in numpy_to_data_array
    coords = {key: xr.IndexVariable((key,), ...) for key in dims}
KeyError: ' m=10'

The default PyMC sampler builds the InferenceData via a different path
(pm.to_inference_data), which only emits a UserWarning about the
weights dim and otherwise succeeds. So this is specific to nutpie's
arviz conversion, not to Bambi's HSGP naming itself.
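
The suspected comma split can be illustrated in isolation. This is only a sketch of the hypothesis, not the actual nutpie or arviz code: the coords dict and its values are hypothetical stand-ins, and only the variable name is taken from this report.

```python
# Variable name as reported above; the coords content is a made-up stand-in.
var_name = "hsgp(x, m=10, c=1.5)_weights"
coords = {var_name: list(range(10))}

# Treated as an opaque string, the name is found in coords.
opaque_dims = [var_name]
opaque_ok = all(dim in coords for dim in opaque_dims)

# Split on commas (the suspected behaviour), the fragments are not coords keys.
split_dims = var_name.split(",")
missing = [dim for dim in split_dims if dim not in coords]

print(opaque_ok)
print(split_dims)
print(missing)
```

One of the fragments is ' m=10', including the leading space, which matches the key in the KeyError. Where exactly the split happens is still unconfirmed.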

Likely fix locations (in order of fewest side effects)

  1. nutpie._arrow_to_arviz: stop passing the variable name through
    anything that splits on commas; treat the name as opaque.
  2. Bambi: sanitize HSGP term names to avoid commas (e.g. hsgp_x_m10_c15
    instead of hsgp(x, m=10, c=1.5)). Cleanest long-term, but renames
    user-visible posterior variables.
  3. arviz: defensive handling of dim-name strings containing commas.
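
For option 2, a rough sketch of what name sanitization could look like. The helper sanitize_term_name is hypothetical and not part of Bambi; it just reproduces the hsgp_x_m10_c15 style used as an example above.

```python
import re


def sanitize_term_name(name: str) -> str:
    """Hypothetical helper: turn a formula call into a comma-free identifier."""
    # Drop "=" and "." so that "m=10" becomes "m10" and "c=1.5" becomes "c15".
    compact = re.sub(r"[=.]", "", name)
    # Collapse every remaining run of non-alphanumeric characters into "_".
    return re.sub(r"[^0-9A-Za-z]+", "_", compact).strip("_")


term = sanitize_term_name("hsgp(x, m=10, c=1.5)")
weights = sanitize_term_name("hsgp(x, m=10, c=1.5)_weights")
print(term)
print(weights)
```

Note that dropping the decimal point makes c=1.5 and c=15 map to the same name, so a real implementation would need a collision-free encoding.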

Workaround

Drop the inference_method="nutpie" argument and use the default PyMC
sampler. You lose nutpie's speed but the model fits cleanly.

Versions observed

bambi == 0.17.2
pymc  == 5.27.x
nutpie == 0.16.8
arviz == 0.23.4
Python 3.13
macOS arm64

Reproduction

import bambi as bmb
import numpy as np
import pandas as pd

# Tiny synthetic dataset — the bug is independent of data shape; we just need
# *any* fit that includes an HSGP term so Bambi emits a variable named with
# commas (e.g. `hsgp(x, m=10, c=1.5)_weights`).
rng = np.random.default_rng(seed=0)
n = 200
x = np.linspace(-3, 3, n)
y = np.sin(x) + 0.1 * rng.standard_normal(n)
df = pd.DataFrame({"x": x, "y": y})

# The trigger is the `hsgp(...)` term in the formula. The keyword args inside
# the call (`m=10`, `c=1.5`) are baked into the resulting variable name verbatim,
# commas included.
model = bmb.Model(
    bmb.Formula("y ~ hsgp(x, m=10, c=1.5)"),
    df,
    family="gaussian",
)

# Sampling itself completes — the failure happens in nutpie's post-sampling
# conversion to InferenceData. Tiny draws/tune is enough to demonstrate.
idata = model.fit(
    draws=200,
    tune=200,
    chains=2,
    inference_method="nutpie",  # remove this kwarg to use PyMC's NUTS and avoid the bug.
    random_seed=0,
)
print(idata)

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 100
     96 )
     97 
     98 # Sampling itself completes — the failure happens in nutpie's post-sampling
     99 # conversion to InferenceData. Tiny draws/tune is enough to demonstrate.
--> 100 idata = model.fit(
    101     draws=200,
    102     tune=200,
    103     chains=2,

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/models.py:347, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    340     warnings.warn(
    341         "'include_mean' has been replaced by 'include_response_params' and "
    342         "is not going to work in the future",
    343         FutureWarning,
    344     )
    345     include_response_params = include_mean
--> 347 return self.backend.run(
    348     draws=draws,
    349     tune=tune,
    350     discard_tuned_samples=discard_tuned_samples,
    351     omit_offsets=omit_offsets,
    352     include_response_params=include_response_params,
    353     inference_method=inference_method,
    354     init=init,
    355     n_init=n_init,
    356     chains=chains,
    357     cores=cores,
    358     random_seed=random_seed,
    359     **kwargs,
    360 )

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/backend/pymc.py:153, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
    147     result = self._run_laplace(
    148         draws=draws,
    149         omit_offsets=omit_offsets,
    150         include_response_params=include_response_params,
    151     )
    152 else:
--> 153     result = self._run_mcmc(
    154         draws=draws,
    155         tune=tune,
    156         discard_tuned_samples=discard_tuned_samples,
    157         omit_offsets=omit_offsets,
    158         include_response_params=include_response_params,
    159         init=init,
    160         n_init=n_init,
    161         chains=chains,
    162         cores=cores,
    163         random_seed=random_seed,
    164         sampler_backend=inference_method,
    165         **kwargs,
    166     )
    168 self.fit = True
    169 return result

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/backend/pymc.py:224, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
    222 with self.model:
    223     try:
--> 224         idata = pm.sample(
    225             draws=draws,
    226             tune=tune,
    227             discard_tuned_samples=discard_tuned_samples,
    228             init=init,
    229             n_init=n_init,
    230             chains=chains,
    231             cores=cores,
    232             random_seed=random_seed,
    233             var_names=vars_to_sample,
    234             nuts_sampler=sampler_backend,
    235             **kwargs,
    236         )
    237     except (RuntimeError, ValueError):
    238         if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto":

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, quiet, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    827         raise ValueError(
    828             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    829         )
    831     with joined_blas_limiter():
--> 832         return _sample_external_nuts(
    833             sampler=nuts_sampler,
    834             draws=draws,
    835             tune=tune,
    836             chains=chains,
    837             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    838             random_seed=random_seed,
    839             initvals=initvals,
    840             model=model,
    841             var_names=var_names,
    842             progressbar=progress_bool,
    843             quiet=quiet,
    844             idata_kwargs=idata_kwargs,
    845             compute_convergence_checks=compute_convergence_checks,
    846             nuts_sampler_kwargs=nuts_sampler_kwargs,
    847             **kwargs,
    848         )
    850 if exclusive_nuts and not provided_steps:
    851     # Special path for NUTS initialization
    852     if "nuts" in kwargs:

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/pymc/sampling/mcmc.py:377, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, quiet, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    371 compiled_model = nutpie.compile_pymc_model(
    372     model,
    373     var_names=var_names,
    374     **compile_kwargs,
    375 )
    376 t_start = time.time()
--> 377 idata = nutpie.sample(
    378     compiled_model,
    379     draws=draws,
    380     tune=tune,
    381     chains=chains,
    382     target_accept=target_accept,
    383     seed=_get_seeds_per_chain(random_seed, 1)[0],
    384     progress_bar=progressbar,
    385     **nuts_sampler_kwargs,
    386 )
    387 t_sample = time.time() - t_start
    388 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
    389 # gather observed and constant data as nutpie.sample() has no access to the PyMC model

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:865, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, zarr_store, **kwargs)
    862     return sampler
    864 try:
--> 865     result = sampler.wait()
    866 except KeyboardInterrupt:
    867     result = sampler.abort()

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:528, in _BackgroundSampler.wait(self, timeout)
    526 self._sampler.wait(timeout)
    527 results = self._sampler.take_results()
--> 528 return self._extract(results)

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:573, in _BackgroundSampler._extract(self, results)
    570             skip_vars.extend(names)
    572     draw_batches, stat_batches = results.get_arrow_trace()
--> 573     return _arrow_to_arviz(
    574         draw_batches,
    575         stat_batches,
    576         skip_vars=skip_vars,
    577         coords={
    578             name: pd.Index(vals)
    579             for name, vals in self._compiled_model.coords.items()
    580         },
    581         save_warmup=self._save_warmup,
    582     )
    583 else:
    584     raise ValueError("Unknown results type")

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:99, in _arrow_to_arviz(draw_batches, stat_batches, skip_vars, **kwargs)
     94     stat_posterior = stat.slice(num_tuning[i], stat.num_rows - num_tuning[i])
     95     _add_arrow_data(
     96         stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars
     97     )
---> 99 return arviz.from_dict(
    100     data_posterior,
    101     sample_stats=stats_posterior,
    102     warmup_posterior=data_tune,
    103     warmup_sample_stats=stats_tune,
    104     dims=dims,
    105     **kwargs,
    106 )

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:459, in from_dict(posterior, posterior_predictive, predictions, sample_stats, log_likelihood, prior, prior_predictive, sample_stats_prior, observed_data, constant_data, predictions_constant_data, warmup_posterior, warmup_posterior_predictive, warmup_predictions, warmup_log_likelihood, warmup_sample_stats, save_warmup, index_origin, coords, dims, pred_dims, pred_coords, attrs, **kwargs)
    352 def from_dict(
    353     posterior=None,
    354     *,
   (...)    377     **kwargs,
    378 ):
    379     """Convert Dictionary data into an InferenceData object.
    380 
    381     For a usage example read the
   (...)    432     InferenceData
    433     """
    434     return DictConverter(
    435         posterior=posterior,
    436         posterior_predictive=posterior_predictive,
    437         predictions=predictions,
    438         sample_stats=sample_stats,
    439         log_likelihood=log_likelihood,
    440         prior=prior,
    441         prior_predictive=prior_predictive,
    442         sample_stats_prior=sample_stats_prior,
    443         observed_data=observed_data,
    444         constant_data=constant_data,
    445         predictions_constant_data=predictions_constant_data,
    446         warmup_posterior=warmup_posterior,
    447         warmup_posterior_predictive=warmup_posterior_predictive,
    448         warmup_predictions=warmup_predictions,
    449         warmup_log_likelihood=warmup_log_likelihood,
    450         warmup_sample_stats=warmup_sample_stats,
    451         save_warmup=save_warmup,
    452         index_origin=index_origin,
    453         coords=coords,
    454         dims=dims,
    455         pred_dims=pred_dims,
    456         pred_coords=pred_coords,
    457         attrs=attrs,
    458         **kwargs,
--> 459     ).to_inference_data()

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:334, in DictConverter.to_inference_data(self)
    326 def to_inference_data(self):
    327     """Convert all available data to an InferenceData object.
    328 
    329     Note that if groups can not be created, then the InferenceData
    330     will not have those groups.
    331     """
    332     return InferenceData(
    333         **{
--> 334             "posterior": self.posterior_to_xarray(),
    335             "sample_stats": self.sample_stats_to_xarray(),
    336             "log_likelihood": self.log_likelihood_to_xarray(),
    337             "posterior_predictive": self.posterior_predictive_to_xarray(),
    338             "predictions": self.predictions_to_xarray(),
    339             "prior": self.prior_to_xarray(),
    340             "sample_stats_prior": self.sample_stats_prior_to_xarray(),
    341             "prior_predictive": self.prior_predictive_to_xarray(),
    342             "observed_data": self.observed_data_to_xarray(),
    343             "constant_data": self.constant_data_to_xarray(),
    344             "predictions_constant_data": self.predictions_constant_data_to_xarray(),
    345             "save_warmup": self.save_warmup,
    346             "attrs": self.attrs,
    347         }
    348     )

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:71, in requires.__call__.<locals>.wrapped(cls)
     69     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     70         return None
---> 71 return func(cls)

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:97, in DictConverter.posterior_to_xarray(self)
     94 posterior_attrs = self._kwargs.get("posterior_attrs")
     95 posterior_warmup_attrs = self._kwargs.get("posterior_warmup_attrs")
     96 return (
---> 97     dict_to_dataset(
     98         data,
     99         library=None,
    100         coords=self.coords,
    101         dims=self.dims,
    102         attrs=posterior_attrs,
    103         index_origin=self.index_origin,
    104     ),
    105     dict_to_dataset(
    106         data_warmup,
    107         library=None,
    108         coords=self.coords,
    109         dims=self.dims,
    110         attrs=posterior_warmup_attrs,
    111         index_origin=self.index_origin,
    112     ),
    113 )

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:406, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    402     except TypeError:  # probably unsortable keys -- the function will still work if
    403         pass  # it is an honest dictionary.
    405 data_vars = {
--> 406     key: numpy_to_data_array(
    407         values,
    408         var_name=key,
    409         coords=coords,
    410         dims=dims.get(key),
    411         default_dims=default_dims,
    412         index_origin=index_origin,
    413         skip_event_dims=skip_event_dims,
    414     )
    415     for key, values in data.items()
    416 }
    417 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:305, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    302     coords["draw"] = np.arange(index_origin, n_samples + index_origin)
    304 # filter coords based on the dims
--> 305 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
    306 return xr.DataArray(ary, coords=coords, dims=dims)

KeyError: ' m=10'
