That name contains commas. nutpie's _arrow_to_arviz path passes a
list/tuple of dim names through to arviz's from_dict →
numpy_to_data_array. Somewhere in that chain the variable name appears
to be split on commas, producing strings like m=10 and c=1.5)_weights
that are then looked up in the coords dict — and they aren't there,
hence:
import bambi as bmb
import numpy as np
import pandas as pd
# Minimal reproduction: fit a Bambi model containing an HSGP term with the
# nutpie sampler. Sampling finishes, then the conversion to InferenceData
# raises KeyError: ' m=10'.
#
# Tiny synthetic dataset — the bug is independent of data shape; we just need
# *any* fit that includes an HSGP term so Bambi emits a variable named with
# commas (e.g. `hsgp(x, m=10, c=1.5)_weights`).
# Seeded generator so the synthetic data is the same on every run.
rng = np.random.default_rng(seed=0)
n = 200
# n evenly spaced points on [-3, 3]; y is sin(x) plus standard-normal noise
# scaled by 0.1.
x = np.linspace(-3, 3, n)
y = np.sin(x) + 0.1 * rng.standard_normal(n)
df = pd.DataFrame({"x": x, "y": y})
# The trigger is the `hsgp(...)` term in the formula. The keyword args inside
# the call (`m=10`, `c=1.5`) are baked into the resulting variable name verbatim,
# commas included.
model = bmb.Model(
bmb.Formula("y ~ hsgp(x, m=10, c=1.5)"),
df,
family="gaussian",
)
# Sampling itself completes — the failure happens in nutpie's post-sampling
# conversion to InferenceData. Tiny draws/tune is enough to demonstrate.
idata = model.fit(
draws=200,
tune=200,
chains=2,
inference_method="nutpie", # selects the nutpie sampler; remove this kwarg to use PyMC's NUTS and avoid the bug.
random_seed=0,
)
# Not reached when the bug triggers: the KeyError is raised inside model.fit().
print(idata)
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[1], line 100
96 )
97
98 # Sampling itself completes — the failure happens in nutpie's post-sampling
99 # conversion to InferenceData. Tiny draws/tune is enough to demonstrate.
--> 100 idata = model.fit(
101 draws=200,
102 tune=200,
103 chains=2,
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/models.py:347, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
340 warnings.warn(
341 "'include_mean' has been replaced by 'include_response_params' and "
342 "is not going to work in the future",
343 FutureWarning,
344 )
345 include_response_params = include_mean
--> 347 return self.backend.run(
348 draws=draws,
349 tune=tune,
350 discard_tuned_samples=discard_tuned_samples,
351 omit_offsets=omit_offsets,
352 include_response_params=include_response_params,
353 inference_method=inference_method,
354 init=init,
355 n_init=n_init,
356 chains=chains,
357 cores=cores,
358 random_seed=random_seed,
359 **kwargs,
360 )
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/backend/pymc.py:153, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
147 result = self._run_laplace(
148 draws=draws,
149 omit_offsets=omit_offsets,
150 include_response_params=include_response_params,
151 )
152 else:
--> 153 result = self._run_mcmc(
154 draws=draws,
155 tune=tune,
156 discard_tuned_samples=discard_tuned_samples,
157 omit_offsets=omit_offsets,
158 include_response_params=include_response_params,
159 init=init,
160 n_init=n_init,
161 chains=chains,
162 cores=cores,
163 random_seed=random_seed,
164 sampler_backend=inference_method,
165 **kwargs,
166 )
168 self.fit = True
169 return result
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/bambi/backend/pymc.py:224, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_response_params, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
222 with self.model:
223 try:
--> 224 idata = pm.sample(
225 draws=draws,
226 tune=tune,
227 discard_tuned_samples=discard_tuned_samples,
228 init=init,
229 n_init=n_init,
230 chains=chains,
231 cores=cores,
232 random_seed=random_seed,
233 var_names=vars_to_sample,
234 nuts_sampler=sampler_backend,
235 **kwargs,
236 )
237 except (RuntimeError, ValueError):
238 if "ValueError: Mass matrix contains" in traceback.format_exc() and init == "auto":
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, quiet, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
827 raise ValueError(
828 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
829 )
831 with joined_blas_limiter():
--> 832 return _sample_external_nuts(
833 sampler=nuts_sampler,
834 draws=draws,
835 tune=tune,
836 chains=chains,
837 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
838 random_seed=random_seed,
839 initvals=initvals,
840 model=model,
841 var_names=var_names,
842 progressbar=progress_bool,
843 quiet=quiet,
844 idata_kwargs=idata_kwargs,
845 compute_convergence_checks=compute_convergence_checks,
846 nuts_sampler_kwargs=nuts_sampler_kwargs,
847 **kwargs,
848 )
850 if exclusive_nuts and not provided_steps:
851 # Special path for NUTS initialization
852 if "nuts" in kwargs:
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/pymc/sampling/mcmc.py:377, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, quiet, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
371 compiled_model = nutpie.compile_pymc_model(
372 model,
373 var_names=var_names,
374 **compile_kwargs,
375 )
376 t_start = time.time()
--> 377 idata = nutpie.sample(
378 compiled_model,
379 draws=draws,
380 tune=tune,
381 chains=chains,
382 target_accept=target_accept,
383 seed=_get_seeds_per_chain(random_seed, 1)[0],
384 progress_bar=progressbar,
385 **nuts_sampler_kwargs,
386 )
387 t_sample = time.time() - t_start
388 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
389 # gather observed and constant data as nutpie.sample() has no access to the PyMC model
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:865, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, zarr_store, **kwargs)
862 return sampler
864 try:
--> 865 result = sampler.wait()
866 except KeyboardInterrupt:
867 result = sampler.abort()
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:528, in _BackgroundSampler.wait(self, timeout)
526 self._sampler.wait(timeout)
527 results = self._sampler.take_results()
--> 528 return self._extract(results)
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:573, in _BackgroundSampler._extract(self, results)
570 skip_vars.extend(names)
572 draw_batches, stat_batches = results.get_arrow_trace()
--> 573 return _arrow_to_arviz(
574 draw_batches,
575 stat_batches,
576 skip_vars=skip_vars,
577 coords={
578 name: pd.Index(vals)
579 for name, vals in self._compiled_model.coords.items()
580 },
581 save_warmup=self._save_warmup,
582 )
583 else:
584 raise ValueError("Unknown results type")
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/nutpie/sample.py:99, in _arrow_to_arviz(draw_batches, stat_batches, skip_vars, **kwargs)
94 stat_posterior = stat.slice(num_tuning[i], stat.num_rows - num_tuning[i])
95 _add_arrow_data(
96 stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars
97 )
---> 99 return arviz.from_dict(
100 data_posterior,
101 sample_stats=stats_posterior,
102 warmup_posterior=data_tune,
103 warmup_sample_stats=stats_tune,
104 dims=dims,
105 **kwargs,
106 )
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:459, in from_dict(posterior, posterior_predictive, predictions, sample_stats, log_likelihood, prior, prior_predictive, sample_stats_prior, observed_data, constant_data, predictions_constant_data, warmup_posterior, warmup_posterior_predictive, warmup_predictions, warmup_log_likelihood, warmup_sample_stats, save_warmup, index_origin, coords, dims, pred_dims, pred_coords, attrs, **kwargs)
352 def from_dict(
353 posterior=None,
354 *,
(...) 377 **kwargs,
378 ):
379 """Convert Dictionary data into an InferenceData object.
380
381 For a usage example read the
(...) 432 InferenceData
433 """
434 return DictConverter(
435 posterior=posterior,
436 posterior_predictive=posterior_predictive,
437 predictions=predictions,
438 sample_stats=sample_stats,
439 log_likelihood=log_likelihood,
440 prior=prior,
441 prior_predictive=prior_predictive,
442 sample_stats_prior=sample_stats_prior,
443 observed_data=observed_data,
444 constant_data=constant_data,
445 predictions_constant_data=predictions_constant_data,
446 warmup_posterior=warmup_posterior,
447 warmup_posterior_predictive=warmup_posterior_predictive,
448 warmup_predictions=warmup_predictions,
449 warmup_log_likelihood=warmup_log_likelihood,
450 warmup_sample_stats=warmup_sample_stats,
451 save_warmup=save_warmup,
452 index_origin=index_origin,
453 coords=coords,
454 dims=dims,
455 pred_dims=pred_dims,
456 pred_coords=pred_coords,
457 attrs=attrs,
458 **kwargs,
--> 459 ).to_inference_data()
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:334, in DictConverter.to_inference_data(self)
326 def to_inference_data(self):
327 """Convert all available data to an InferenceData object.
328
329 Note that if groups can not be created, then the InferenceData
330 will not have those groups.
331 """
332 return InferenceData(
333 **{
--> 334 "posterior": self.posterior_to_xarray(),
335 "sample_stats": self.sample_stats_to_xarray(),
336 "log_likelihood": self.log_likelihood_to_xarray(),
337 "posterior_predictive": self.posterior_predictive_to_xarray(),
338 "predictions": self.predictions_to_xarray(),
339 "prior": self.prior_to_xarray(),
340 "sample_stats_prior": self.sample_stats_prior_to_xarray(),
341 "prior_predictive": self.prior_predictive_to_xarray(),
342 "observed_data": self.observed_data_to_xarray(),
343 "constant_data": self.constant_data_to_xarray(),
344 "predictions_constant_data": self.predictions_constant_data_to_xarray(),
345 "save_warmup": self.save_warmup,
346 "attrs": self.attrs,
347 }
348 )
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:71, in requires.__call__.<locals>.wrapped(cls)
69 if all((getattr(cls, prop_i) is None for prop_i in prop)):
70 return None
---> 71 return func(cls)
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/io_dict.py:97, in DictConverter.posterior_to_xarray(self)
94 posterior_attrs = self._kwargs.get("posterior_attrs")
95 posterior_warmup_attrs = self._kwargs.get("posterior_warmup_attrs")
96 return (
---> 97 dict_to_dataset(
98 data,
99 library=None,
100 coords=self.coords,
101 dims=self.dims,
102 attrs=posterior_attrs,
103 index_origin=self.index_origin,
104 ),
105 dict_to_dataset(
106 data_warmup,
107 library=None,
108 coords=self.coords,
109 dims=self.dims,
110 attrs=posterior_warmup_attrs,
111 index_origin=self.index_origin,
112 ),
113 )
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:406, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
402 except TypeError: # probably unsortable keys -- the function will still work if
403 pass # it is an honest dictionary.
405 data_vars = {
--> 406 key: numpy_to_data_array(
407 values,
408 var_name=key,
409 coords=coords,
410 dims=dims.get(key),
411 default_dims=default_dims,
412 index_origin=index_origin,
413 skip_event_dims=skip_event_dims,
414 )
415 for key, values in data.items()
416 }
417 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ~/Documents/website_projects/.pixi/envs/default/lib/python3.13/site-packages/arviz/data/base.py:305, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
302 coords["draw"] = np.arange(index_origin, n_samples + index_origin)
304 # filter coords based on the dims
--> 305 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
306 return xr.DataArray(ary, coords=coords, dims=dims)
KeyError: ' m=10'
Summary
Fitting a Bambi model whose formula contains an
`hsgp(...)` term with `inference_method="nutpie"` raises:

    KeyError: ' m=10'

at the very end of sampling, while nutpie is converting its arrow-backed
draws into an arviz InferenceData. The default PyMC sampler
(`inference_method="pymc"`) works fine on the same model.

Root cause (best guess from the traceback)
Bambi names the HSGP weights variable after the formula call literally:

    hsgp(x, m=10, c=1.5)_weights

That name contains commas. nutpie's `_arrow_to_arviz` path passes a
list/tuple of dim names through to arviz's `from_dict` →
`numpy_to_data_array`. Somewhere in that chain the variable name appears
to be split on commas, producing strings like `m=10` and `c=1.5)_weights`
that are then looked up in the `coords` dict — and they aren't there,
hence:

    KeyError: ' m=10'
The default PyMC sampler builds the InferenceData via a different path
(`pm.to_inference_data`), which only emits a UserWarning about the
weights dim and otherwise succeeds — so this is specific to
nutpie's arviz conversion, not to Bambi's HSGP naming itself.
Likely fix locations (in order of fewest side effects)
anything that splits on commas; treat the name as opaque.
`hsgp_x_m10_c15` instead of
`hsgp(x, m=10, c=1.5)`). Cleanest long-term, but renames
user-visible posterior variables.
Workaround
Drop the `inference_method="nutpie"` argument and use the default PyMC
sampler. You lose nutpie's speed but the model fits cleanly.
Versions observed
Run