Skip to content

Commit

Permalink
Add test util for setting fake data in CLV models
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 15, 2024
1 parent 6f7a3d4 commit 3e92610
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 26 deletions.
9 changes: 4 additions & 5 deletions tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from arviz import InferenceData, from_dict

from pymc_marketing.clv.models.basic import CLVModel
from tests.clv.utils import set_model_fit


class CLVModelTest(CLVModel):
Expand Down Expand Up @@ -54,10 +55,10 @@ def test_create_distribution_from_wrong_prior(self):
model = CLVModelTest()
with pytest.raises(
ValueError,
match="Distribution definately_not_PyMC_dist does not exist in PyMC",
match="Distribution definitely_not_PyMC_dist does not exist in PyMC",
):
model._create_distribution(
{"dist": "definately_not_PyMC_dist", "kwargs": {"alpha": 1, "beta": 1}}
{"dist": "definitely_not_PyMC_dist", "kwargs": {"alpha": 1, "beta": 1}}
)

def test_fit_mcmc(self):
Expand Down Expand Up @@ -165,9 +166,7 @@ def test_thin_fit_result(self):
model = CLVModelTest(data=data)
model.build_model()
fake_idata = from_dict(dict(x=np.random.normal(size=(4, 1000))))
fake_idata.add_groups(dict(fit_data=data.to_xarray()))
model.set_idata_attrs(fake_idata)
model.idata = fake_idata
set_model_fit(model, fake_idata)

thin_model = model.thin_fit_result(keep_every=20)
assert thin_model is not model
Expand Down
4 changes: 2 additions & 2 deletions tests/clv/models/test_gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
GammaGammaModel,
GammaGammaModelIndividual,
)
from tests.clv.utils import set_model_fit


class BaseTestGammaGammaModel:
Expand Down Expand Up @@ -232,8 +233,7 @@ def test_new_customer_spend(self, distribution):
fake_fit = pm.sample_prior_predictive(
samples=1000, model=model.model, random_seed=self.rng
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.idata = fake_fit
set_model_fit(model, fake_fit.prior)
# Closed formula solution for the mean and var of the population spend (eqs 3, 4 from [1]) # noqa: E501
expected_preds_mean = p_mean * v_mean / (q_mean - 1)
expected_preds_std = np.sqrt(
Expand Down
6 changes: 2 additions & 4 deletions tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pymc_marketing.clv import ParetoNBDModel
from pymc_marketing.clv.distributions import ParetoNBD
from tests.clv.utils import set_model_fit


class TestParetoNBDModel:
Expand Down Expand Up @@ -36,8 +37,6 @@ def setup_class(cls):

# Instantiate model with CDNOW data for testing
cls.model = ParetoNBDModel(cls.data)
# TODO: This can be removed after build_model() is called internally with __init__
cls.model.build_model()

# Also instantiate lifetimes model for comparison
cls.lifetimes_model = ParetoNBDFitter()
Expand All @@ -64,8 +63,7 @@ def setup_class(cls):
),
}
)

cls.model.idata = cls.mock_fit
set_model_fit(cls.model, cls.mock_fit)

@pytest.fixture(scope="class")
def model_config(self):
Expand Down
22 changes: 7 additions & 15 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
rfm_summary,
to_xarray,
)
from tests.clv.utils import set_model_fit


def test_to_xarray():
Expand Down Expand Up @@ -74,11 +75,8 @@ def fitted_bg(test_summary_data) -> BetaGeoModel:
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.idata = fake_fit
model.set_idata_attrs(model.idata)
model._add_fit_data_group(model.data)
).prior
set_model_fit(model, fake_fit)

return model

Expand All @@ -103,11 +101,8 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
# Mock an idata object for tests requiring a fitted model
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
pnbd_model.idata = fake_fit
pnbd_model.set_idata_attrs(pnbd_model.idata)
pnbd_model._add_fit_data_group(pnbd_model.data)
).prior
set_model_fit(pnbd_model, fake_fit)

return pnbd_model

Expand Down Expand Up @@ -136,11 +131,8 @@ def fitted_gg(test_summary_data) -> GammaGammaModel:
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.idata = fake_fit
model.set_idata_attrs(model.idata)
model._add_fit_data_group(model.data)
).prior
set_model_fit(model, fake_fit)

return model

Expand Down
18 changes: 18 additions & 0 deletions tests/clv/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Union

from arviz import InferenceData
from xarray import Dataset

from pymc_marketing.clv.models import CLVModel


def set_model_fit(model: CLVModel, fit: Union[InferenceData, Dataset]):
if isinstance(fit, InferenceData):
assert "posterior" in fit.groups()
else:
fit = InferenceData(posterior=fit)
if model.model is None:
model.build_model()
model.idata = fit
model.idata.add_groups(fit_data=model.data.to_xarray())
model.set_idata_attrs(fit)

0 comments on commit 3e92610

Please sign in to comment.