
Add Bass Diffusion Model #1328


Merged
merged 41 commits into from
Apr 25, 2025

Conversation

williambdean
Contributor

@williambdean williambdean commented Jan 3, 2025

Description

Adding Bass Diffusion Model

Related Issue

Checklist

Modules affected

  • MMM
  • CLV
  • Customer Choice
  • Product Development

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc-marketing--1328.org.readthedocs.build/en/1328/

@github-actions github-actions bot added the enhancement New feature or request label Jan 3, 2025
@williambdean
Contributor Author

Here is a crude implementation of the Bass model. Feel free to take over @juanitorduz

I forget why I didn't wrap it in the ModelBuilder. Maybe I was just trying it out and familiarizing myself with the model. Wrapping it may be straightforward, but I remember having some concern.

As for the magnitude of m compared to the other parameters and NUTS, maybe a scaling constant can be used in the model. You might have some more ideas.

model = Bass(m_scaling=5000)

# Under the hood
m = m_scaling * Prior.create_variable(...)
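
As a minimal sketch of why a scaling constant fits naturally here: in the standard Bass formulation, the cumulative adoption fraction F(t) depends only on p and q, and the market size m only multiplies the result. The function names and parameter values below are illustrative assumptions and are not taken from pymc-marketing.

```python
import math


def bass_cdf(t: float, p: float, q: float) -> float:
    """Cumulative fraction of adopters at time t (standard Bass form)."""
    decay = math.exp(-(p + q) * t)
    return (1.0 - decay) / (1.0 + (q / p) * decay)


def bass_adopters(t: float, m: float, p: float, q: float) -> float:
    """Expected adoption rate at time t: m times the Bass density f(t)."""
    decay = math.exp(-(p + q) * t)
    density = ((p + q) ** 2 / p) * decay / (1.0 + (q / p) * decay) ** 2
    return m * density


# m enters only as a multiplier, so a curve for m=5000 is 5000 times the
# curve for m=1 with the same p and q (illustrative values).
unit_curve = [bass_adopters(t, 1.0, 0.03, 0.38) for t in range(10)]
scaled_curve = [5000.0 * value for value in unit_curve]
```

Because m is purely multiplicative, sampling an order-one variable and rescaling it by a fixed constant describes the same model while keeping the sampled parameter on a scale similar to p and q.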


codecov bot commented Jan 3, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.39%. Comparing base (e80894d) to head (4df7b6f).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1328      +/-   ##
==========================================
+ Coverage   93.35%   93.39%   +0.03%     
==========================================
  Files          55       56       +1     
  Lines        6287     6325      +38     
==========================================
+ Hits         5869     5907      +38     
  Misses        418      418              


@juanitorduz
Collaborator

Thanks @wd60622! I will familiarize myself with the model and push it forward 🙌

@williambdean
Contributor Author

williambdean commented Jan 3, 2025

Sounds good! Thanks

I had the idea that we could parameterize time as days/weeks/months since product launch, which would make pooling across dims/products easier. I feel like that would give an interesting insight.
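
A rough sketch of that idea, assuming weekly data: each observation date is converted to whole weeks since its product's launch, so every product shares the same relative time axis. The helper name, the product names, and the dates are made up for illustration.

```python
from datetime import date
from typing import Optional


def weeks_since_launch(observed: date, launch: date) -> Optional[int]:
    """Whole weeks elapsed since launch, or None if observed is before launch."""
    days = (observed - launch).days
    if days < 0:
        return None
    return days // 7


# Hypothetical launch dates for two products.
launches = {"P0": date(2023, 1, 2), "P1": date(2023, 3, 6)}
observed = date(2023, 3, 20)

# The same calendar date maps to a different relative time per product.
relative_t = {
    product: weeks_since_launch(observed, launch)
    for product, launch in launches.items()
}
```

With time expressed this way, products launched on different calendar dates line up at week 0, which is what would make sharing p and q information across products easier.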

@juanitorduz juanitorduz self-assigned this Jan 3, 2025
@williambdean
Contributor Author

Visual produced from the main block:

[Image: bass]

@juanitorduz
Collaborator

I will pick up this one so that we can merge the base model and iterate

@williambdean
Contributor Author

williambdean commented Apr 2, 2025

The magnitude of m might benefit from a wrapper like the following, which the VariableFactory protocol allows for:

import pymc as pm
from pymc_marketing.prior import Prior

class Scaled:
    """Wrap a Prior and multiply the variable it creates by a constant factor."""

    def __init__(self, dist: Prior, factor: float):
        self.dist = dist
        self.factor = factor

    @property
    def dims(self):
        return self.dist.dims

    def create_variable(self, name: str):
        # Sample on the unscaled scale, then rescale deterministically.
        var = self.dist.create_variable(f"{name}_unscaled")
        return pm.Deterministic(name, var * self.factor, dims=self.dims)

This would keep the model itself free of this scaling logic while still allowing scaling when needed.

from pymc_marketing.prior import sample_prior

m = Scaled(Prior("HalfNormal", sigma=1), 5000)
prior = sample_prior(m)

[Image: scaled]

@github-actions github-actions bot added the tests label Apr 8, 2025
@github-actions github-actions bot added the docs Improvements or additions to documentation label Apr 8, 2025
Juan Orduz added 2 commits April 8, 2025 15:10
@juanitorduz juanitorduz marked this pull request as draft April 8, 2025 13:22
@williambdean
Contributor Author

williambdean commented Apr 22, 2025

@jwilkinson88 and I had a bit of motivation here and created a synthetic dataset

Generate Data
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import pymc as pm

from pymc_marketing.bass.model import create_bass_model
from pymc_marketing.prior import Prior


def setup_simulation_parameters(
    n_weeks=52,
    n_products=15,
    start_date="2023-01-01",
    cutoff_start_date="2023-12-01",
):
    """Set up initial parameters for the Bass diffusion model simulation."""
    seed = sum(map(ord, "Bass Model"))
    rng = np.random.default_rng(seed)

    # Create time array and date range
    T = np.arange(n_weeks)
    possible_dates = pd.date_range(start_date, freq="W-MON", periods=n_weeks)
    cutoff_start_date = pd.to_datetime(cutoff_start_date)
    cutoff_start_date = cutoff_start_date + pd.DateOffset(weeks=1)
    possible_start_dates = possible_dates[possible_dates < cutoff_start_date]

    # Generate product names and random start dates
    products = [f"P{i}" for i in range(n_products)]
    product_start = pd.Series(
        rng.choice(possible_start_dates, size=len(products)),
        index=pd.Index(products, name="product"),
    )

    coords = {"date": T, "product": products}
    return T, possible_dates, possible_start_dates, products, product_start, coords


class Scaled:
    """Scaled distribution for numerical stability."""

    def __init__(self, dist: Prior, factor: float):
        self.dist = dist
        self.factor = factor

    @property
    def dims(self):
        return self.dist.dims

    def create_variable(self, name: str):
        var = self.dist.create_variable(f"{name}_unscaled")
        return pm.Deterministic(name, var * self.factor, dims=self.dims)


def create_bass_priors():
    """Define prior distributions for the Bass model parameters."""
    return {
        "m": Scaled(Prior("Gamma", mu=1, sigma=0.001, dims="product"), factor=50_000),
        "p": Prior("Beta", mu=0.38, sigma=0.05, dims="product"),
        "q": Prior("Beta", mu=0.35, sigma=0.3, dims="product"),
        "likelihood": Prior("NegativeBinomial", n=1.5, dims=("date", "product")),
    }


def sample_and_plot_prior_predictive(model):
    """Sample from the prior predictive distribution and return a single draw."""
    idata = pm.sample_prior_predictive(model=model)

    bass_data = idata.prior.y.sel(chain=0, draw=0)
    return bass_data


def transform_to_actual_dates(bass_data, product_start, possible_dates):
    """Transform simulation data to actual calendar dates."""
    bass_data = bass_data.to_dataset()
    bass_data["product_start"] = product_start.to_xarray()

    df_bass_data = (
        bass_data.to_dataframe().drop(columns=["chain", "draw"]).reset_index()
    )
    df_bass_data["actual_date"] = df_bass_data["product_start"] + pd.to_timedelta(
        7 * df_bass_data["date"], unit="days"
    )

    return (
        df_bass_data.set_index(["actual_date", "product"])
        .y.unstack(fill_value=0)
        .reindex(possible_dates, fill_value=0)
    )


def main():
    """Run the Bass diffusion model simulation."""
    # Setup simulation parameters
    T, possible_dates, _, products, product_start, coords = (
        setup_simulation_parameters()
    )

    # Create and configure the Bass model
    priors = create_bass_priors()
    model = create_bass_model(t=T, coords=coords, observed=None, priors=priors)

    # Sample and visualize results
    bass_data = sample_and_plot_prior_predictive(model)
    actual_data = transform_to_actual_dates(bass_data, product_start, possible_dates)

    return bass_data, actual_data, model


if __name__ == "__main__":
    plot = True
    bass_data, actual_data, generative_model = main()

    with pm.observe(generative_model, {"y": bass_data.values}) as model:
        idata = pm.sample(nuts_sampler="nutpie", compile_kwargs={"mode": "NUMBA"})

    print(idata.sample_stats.diverging.sum())

    pm.sample_posterior_predictive(idata, model=model, extend_inferencedata=True)

    if plot:
        from pymc_marketing.plot import plot_curve

        fig, axes = plt.subplots(2, 1)
        bass_data.to_series().unstack().plot(ax=axes[0])

        actual_data.plot(ax=axes[1])
        plt.show()

        idata.posterior_predictive.y.sel(product=["P0", "P1"]).pipe(
            plot_curve, {"date"}
        )
        plt.show()

[Image: bass-data]
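
As a quick arithmetic check of what the prior on m in the script implies: multiplying a random variable by a constant multiplies its mean and standard deviation by that constant, and a Gamma distribution with mean mu and standard deviation sigma has shape (mu / sigma)^2 and rate mu / sigma^2. The helper names below are made up for illustration and are not part of pymc-marketing.

```python
import math
from typing import Tuple


def gamma_shape_rate(mu: float, sigma: float) -> Tuple[float, float]:
    """Convert a Gamma mean and standard deviation to shape and rate."""
    alpha = (mu / sigma) ** 2
    beta = mu / sigma**2
    return alpha, beta


def scaled_moments(mu: float, sigma: float, factor: float) -> Tuple[float, float]:
    """Mean and standard deviation of factor * X, given those of X."""
    return mu * factor, sigma * factor


# Values taken from the "m" prior in the script: Gamma(mu=1, sigma=0.001) times 50_000.
alpha, beta = gamma_shape_rate(mu=1.0, sigma=0.001)
m_mean, m_sd = scaled_moments(mu=1.0, sigma=0.001, factor=50_000)
```

Under these values the scaled m is centered at 50,000 with a standard deviation of 50, so the sampler works with an unscaled variable near 1 while the model sees a market size in the tens of thousands.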

@juanitorduz
Collaborator

Fantastic! I'll add it into the example notebook (wip) 🙏

@juanitorduz juanitorduz added the Bass model Dealing with the Bass Defusion model label Apr 23, 2025
Juan Orduz and others added 11 commits April 23, 2025 22:05
> Co-authored-by: jwilkinson88
> Co-authored-by: williambdean
@juanitorduz juanitorduz marked this pull request as ready for review April 24, 2025 16:25
@drbenvincent drbenvincent self-requested a review April 24, 2025 16:38
@williambdean
Contributor Author

The injected code for the notebook runner will use 10 samples which causes a key error in the notebook

@juanitorduz
Collaborator

> The injected code for the notebook runner will use 10 samples which causes a key error in the notebook

Thanks! Fixed!

Contributor Author

@williambdean williambdean left a comment

Looks great to me!
I cannot approve it since I started the PR!

I was just waiting to see if the documentation figure renders
https://pymc-marketing--1328.org.readthedocs.build/en/1328/api/generated/pymc_marketing.bass.model.html#module-pymc_marketing.bass.model

@juanitorduz
Collaborator

Great team work!!!

@juanitorduz juanitorduz merged commit 53538d0 into main Apr 25, 2025
33 checks passed
@juanitorduz juanitorduz deleted the bass-model branch April 25, 2025 16:56
Labels
  • Bass model (Dealing with the Bass Defusion model)
  • docs (Improvements or additions to documentation)
  • enhancement (New feature or request)
  • Prior class
  • tests
Development

Successfully merging this pull request may close these issues.

Add Bass Diffusion Model
2 participants