Skip to content
This repository was archived by the owner on Jan 19, 2026. It is now read-only.

Commit befb6b9

Browse files
fehiepsicopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 742274292 Change-Id: Ifbaefaf361142577e03bca03d0d1d197d7f140fd
1 parent fac9eb9 commit befb6b9

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

lightweight_mmm/core/baseline/intercept.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ def simple_intercept(
5252
fn=custom_priors.get(priors.INTERCEPT,
5353
default_priors[priors.INTERCEPT]),
5454
)
55-
return intercept
55+
return jnp.asarray(intercept)

lightweight_mmm/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,8 +420,9 @@ def media_mix_model(
420420
plate_prefixes = ("extra_feature", "geo")
421421
extra_features_einsum = "tfg, fg -> tg" # t = time, f = feature, g = geo
422422
extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
423-
with numpyro.plate_stack(plate_prefixes,
424-
sizes=extra_features_plates_shape):
423+
with numpyro.plate_stack(
424+
str(plate_prefixes), sizes=list(extra_features_plates_shape)
425+
):
425426
coef_extra_features = numpyro.sample(
426427
name=_COEF_EXTRA_FEATURES,
427428
fn=custom_priors.get(

0 commit comments

Comments
 (0)