Add Input Validation for CLV Models#2305
Add Input Validation for CLV Models#2305shivamlalakiya wants to merge 12 commits intopymc-labs:mainfrom
Conversation
PR SummaryMedium Risk Overview Aligns model-specific constraints and CLV computation plumbing. Tests expanded to cover the new validation errors across BG/NBD, MBG/NBD, Pareto/NBD, Shifted BG, and Gamma-Gamma, plus a check that Gamma-Gamma CLV always derives Written by Cursor Bugbot for commit 8e2a66b. This will update automatically on new commits. Configure here. |
ColtAllen
left a comment
There was a problem hiding this comment.
Hey @shivamlalakiya,
Thanks for opening this PR! Can we move this method into CLVModel so that all CLV models can use it? Supporting the data requirements of all models will require some modification, but the end result will look very similar to this legacy code.
I'm nearly done with a major PR involving changes to the CLV API. For this PR we can keep things simple and just make changes to CLVModel, and I'll update my PR with the model-specific validation calls.
test_model_multi
Outdated
There was a problem hiding this comment.
Curious how this file keeps getting re-created? Is this due to an MMM test?
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2305 +/- ##
==========================================
+ Coverage 93.05% 93.20% +0.14%
==========================================
Files 78 79 +1
Lines 12230 12491 +261
==========================================
+ Hits 11381 11642 +261
Misses 849 849 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
ColtAllen
left a comment
There was a problem hiding this comment.
Keep this in CLVModel despite what the Bugbot says. Yes, some changes are needed to support ShiftedBetaGeoModel, but we can add a new parameter for those conditions. The BG/NBD, Pareto/NBD, and MBG/NBD all have identical requirements.
|
Hi @ColtAllen, thanks for the guidance! I've updated the PR to address your feedback:
All tests and linting are passing locally. Let me know if there's anything else you'd like me to tweak before we merge! |
ColtAllen
left a comment
There was a problem hiding this comment.
Thanks @shivamlalakiya - this PR is on the right track! Can we add a monetary_value parameter for the Gamma-Gamma model as well? Tests will also need to be added for models other than BGNBD and sBG.
6fd23fe to
088f814
Compare
|
Thank you @ColtAllen , I have updated the PR to address all the feedback. Let me know if there's anything else you'd like me to tweak before we merge! |
| def expected_customer_lifetime_value( | ||
| self, | ||
| data: pd.DataFrame, | ||
| future_t: int = 12, | ||
| discount_rate: float = 0.00, | ||
| time_unit: str = "D", | ||
| monetary_value: xarray.DataArray | pd.Series | None = None, | ||
| ) -> xarray.DataArray: | ||
| """Compute the average lifetime value for a group of one or more customers.""" | ||
| data = data.copy() | ||
|
|
||
| # Inject custom monetary value if provided | ||
| if monetary_value is not None: | ||
| data["monetary_value"] = monetary_value | ||
|
|
||
| return customer_lifetime_value( | ||
| transaction_model=self, | ||
| data=data, | ||
| future_t=future_t, | ||
| discount_rate=discount_rate, | ||
| time_unit=time_unit, | ||
| ) | ||
|
|
There was a problem hiding this comment.
Why do we need this here? Are we overwriting anything on BetaGeoModel? Or is it the same method?
There was a problem hiding this comment.
@PabloRoque Thank you, ModifiedBetaGeoModel inherits directly from CLVModel, not from BetaGeoModel, so nothing is being overridden here. This method is needed — without it, ModifiedBetaGeoModel would have no way to call expected_customer_lifetime_value directly. That said, since the implementation is identical to ParetoNBDModel's, would you prefer I move it up to CLVModel as a shared base method?
There was a problem hiding this comment.
ModifiedBetaGeoModel would have no way to call expected_customer_lifetime_value directly
None of the transaction models call expected_customer_lifetime_value directly. It is only called from the GammaGammaModel for monetary value, which itself is a wrapper for clv.utils.customer_lifetime_value, which users can call directly if they want to use monetary predictions from an external model.
There was a problem hiding this comment.
@shivamlalakiya I am having a look at this line class ModifiedBetaGeoModel(BetaGeoModel) so indeed inherits from BetaGeoModel and that method is described in that class if I recall correctly.
There was a problem hiding this comment.
@shivamlalakiya I understand now. expected_customer_lifetime_value is an external helper, not attached to any of the classes. Since that method accepts transaction_model then I don't see the need to add this.
ColtAllen
left a comment
There was a problem hiding this comment.
Remove the CLV method from the MBG/NBD model.
I noticed the docstrings for the CLV estimation functions are outdated, which might be the source of the confusion. While you're at it, do you mind quickly updating these two lines to indicate the MBG/NBD model is supported as well?
https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/clv/models/gamma_gamma.py#L175
https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/clv/utils.py#L65
test_model_multi
Outdated
There was a problem hiding this comment.
Curious how this file keeps getting re-created? Is this due to an MMM test?
| def expected_customer_lifetime_value( | ||
| self, | ||
| data: pd.DataFrame, | ||
| future_t: int = 12, | ||
| discount_rate: float = 0.00, | ||
| time_unit: str = "D", | ||
| monetary_value: xarray.DataArray | pd.Series | None = None, | ||
| ) -> xarray.DataArray: | ||
| """Compute the average lifetime value for a group of one or more customers.""" | ||
| data = data.copy() | ||
|
|
||
| # Inject custom monetary value if provided | ||
| if monetary_value is not None: | ||
| data["monetary_value"] = monetary_value | ||
|
|
||
| return customer_lifetime_value( | ||
| transaction_model=self, | ||
| data=data, | ||
| future_t=future_t, | ||
| discount_rate=discount_rate, | ||
| time_unit=time_unit, | ||
| ) | ||
|
|
There was a problem hiding this comment.
ModifiedBetaGeoModel would have no way to call expected_customer_lifetime_value directly
None of the transaction models call expected_customer_lifetime_value directly. It is only called from the GammaGammaModel for monetary value, which itself is a wrapper for clv.utils.customer_lifetime_value, which users can call directly if they want to use monetary predictions from an external model.
PabloRoque
left a comment
There was a problem hiding this comment.
@shivamlalakiya Thanks for the contribution, and sorry for the confusion around expected_customer_lifetime_value. I left some comments on the PR. Let me know your thoughts.
ColtAllen
left a comment
There was a problem hiding this comment.
@PabloRoque what do you think the code logic would look like if this were in __init__.py? I agree the current approach of parametrizing specific requirements in each model is a code smell.
Another possibility could be to use structured pattern matching with the _model_type attribute of the model child classes. This should require only a single parameter for _check_inputs, and would look something like this pseudocode:
# declaration in CLVModel
def check_inputs(self, model_type):
match model_type:
case ["BG/NBD", "MBG/NBD", "Pareto/NBD":
# self.data["column"] checks
case ""Gamma-Gamma Model (Individual Transactions)"":
# frequency and monetary checks
case "BG/BB":
# input checks
case "sBG":
# input checks
case _:
raise ValueError("Invalid Model Type") # this will probably never happen, but we still need a wildcard case
# usage in child class
self._check_inputs(self._model_type).This way validation logic for all models is hard-coded and lives in this single method.
- Add _check_inputs static method to CLVModel to validate frequency >= 0, recency >= 0, T >= 0, and recency <= T - Integrate _check_inputs into BetaGeoModel (BG/NBD) initialization - Per review feedback: keep validation in CLVModel base class despite automated suggestions - ShiftedBetaGeoModel uses its own validation (1 <= recency <= T, T >= 2) as it has different requirements
58c44c3 to
49b8679
Compare
- Move _check_inputs() into CLVModel.__init__ to eliminate boilerplate - Add check_frequency param so ModifiedBetaGeoModel skips frequency >= 0 check - Remove expected_customer_lifetime_value from ModifiedBetaGeoModel and ParetoNBDModel - Remove monetary_value override param from GammaGammaModel.expected_customer_lifetime_value - Update docstrings to list ModifiedBetaGeoModel as supported transaction model
…atch to pytest.raises - Move test_check_inputs_validation out of nested scope in test_gamma_gamma.py - Add monetary_value column to GammaGamma invalid_data fixtures - Remove self._check_frequency = False from ModifiedBetaGeoModel (check_frequency=True already correctly allows zero frequency while rejecting negative values) - Pass match=expected_error_match to pytest.raises in all three test files
Drop the pass-through ModifiedBetaGeoModel.__init__ since it duplicated BetaGeoModel initialization behavior without adding custom logic.
Description
Addresses issue #175.
Following maintainer feedback, this PR implements a shared
_check_inputsutility method in the baseCLVModelrather than limiting it toBetaGeoModel. This centralizes data validation (frequency,recency,T) for all models that share identical input requirements (BG/NBD, Pareto/NBD, MBG/NBD).It also includes parameter handling to support the specific condition variations required by models like
ShiftedBetaGeoModel.Previously, invalid data (e.g., negative inputs or
recency > T) would cause obscure errors downstream during model fitting or sampling. Now, the base model catches these early and raises a clear, descriptiveValueError.Changes
_check_inputsto the baseCLVModelto validatefrequency,recency, andT.ShiftedBetaGeoModel).tests/clv/covering:recency > T.Related Issue
_check_inputs()Utility toclvModule? #175Checklist
pre-commit.ci autofixto auto-fix.📚 Documentation preview 📚: https://pymc-marketing--2305.org.readthedocs.build/en/2305/