diff --git a/tests/clv/models/test_basic.py b/tests/clv/models/test_basic.py index 42993415a..afec31b73 100644 --- a/tests/clv/models/test_basic.py +++ b/tests/clv/models/test_basic.py @@ -7,6 +7,7 @@ from arviz import InferenceData, from_dict from pymc_marketing.clv.models.basic import CLVModel +from tests.clv.utils import set_model_fit class CLVModelTest(CLVModel): @@ -54,10 +55,10 @@ def test_create_distribution_from_wrong_prior(self): model = CLVModelTest() with pytest.raises( ValueError, - match="Distribution definately_not_PyMC_dist does not exist in PyMC", + match="Distribution definitely_not_PyMC_dist does not exist in PyMC", ): model._create_distribution( - {"dist": "definately_not_PyMC_dist", "kwargs": {"alpha": 1, "beta": 1}} + {"dist": "definitely_not_PyMC_dist", "kwargs": {"alpha": 1, "beta": 1}} ) def test_fit_mcmc(self): @@ -165,9 +166,7 @@ def test_thin_fit_result(self): model = CLVModelTest(data=data) model.build_model() fake_idata = from_dict(dict(x=np.random.normal(size=(4, 1000)))) - fake_idata.add_groups(dict(fit_data=data.to_xarray())) - model.set_idata_attrs(fake_idata) - model.idata = fake_idata + set_model_fit(model, fake_idata) thin_model = model.thin_fit_result(keep_every=20) assert thin_model is not model diff --git a/tests/clv/models/test_gamma_gamma.py b/tests/clv/models/test_gamma_gamma.py index b53298241..c057bfb35 100644 --- a/tests/clv/models/test_gamma_gamma.py +++ b/tests/clv/models/test_gamma_gamma.py @@ -10,6 +10,7 @@ GammaGammaModel, GammaGammaModelIndividual, ) +from tests.clv.utils import set_model_fit class BaseTestGammaGammaModel: @@ -232,8 +233,7 @@ def test_new_customer_spend(self, distribution): fake_fit = pm.sample_prior_predictive( samples=1000, model=model.model, random_seed=self.rng ) - fake_fit.add_groups(dict(posterior=fake_fit.prior)) - model.idata = fake_fit + set_model_fit(model, fake_fit.prior) # Closed formula solution for the mean and var of the population spend (eqs 3, 4 from [1]) # noqa: E501 expected_preds_mean = p_mean * v_mean / (q_mean - 1) expected_preds_std = np.sqrt( diff --git a/tests/clv/models/test_pareto_nbd.py b/tests/clv/models/test_pareto_nbd.py index 71337ef79..1f81c2c99 100644 --- a/tests/clv/models/test_pareto_nbd.py +++ b/tests/clv/models/test_pareto_nbd.py @@ -9,6 +9,7 @@ from pymc_marketing.clv import ParetoNBDModel from pymc_marketing.clv.distributions import ParetoNBD +from tests.clv.utils import set_model_fit class TestParetoNBDModel: @@ -36,8 +37,6 @@ def setup_class(cls): # Instantiate model with CDNOW data for testing cls.model = ParetoNBDModel(cls.data) - # TODO: This can be removed after build_model() is called internally with __init__ - cls.model.build_model() # Also instantiate lifetimes model for comparison cls.lifetimes_model = ParetoNBDFitter() @@ -64,8 +63,7 @@ def setup_class(cls): ), } ) - - cls.model.idata = cls.mock_fit + set_model_fit(cls.model, cls.mock_fit) @pytest.fixture(scope="class") def model_config(self): diff --git a/tests/clv/test_utils.py b/tests/clv/test_utils.py index 0ef2fd7af..ca008de8a 100644 --- a/tests/clv/test_utils.py +++ b/tests/clv/test_utils.py @@ -15,6 +15,7 @@ rfm_summary, to_xarray, ) +from tests.clv.utils import set_model_fit def test_to_xarray(): @@ -74,11 +75,8 @@ def fitted_bg(test_summary_data) -> BetaGeoModel: model.build_model() fake_fit = pm.sample_prior_predictive( samples=50, model=model.model, random_seed=rng - ) - fake_fit.add_groups(dict(posterior=fake_fit.prior)) - model.idata = fake_fit - model.set_idata_attrs(model.idata) - model._add_fit_data_group(model.data) + ).prior + set_model_fit(model, fake_fit) return model @@ -103,11 +101,8 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel: # Mock an idata object for tests requiring a fitted model fake_fit = pm.sample_prior_predictive( samples=50, model=pnbd_model.model, random_seed=rng - ) - fake_fit.add_groups(dict(posterior=fake_fit.prior)) - pnbd_model.idata = fake_fit - pnbd_model.set_idata_attrs(pnbd_model.idata) - pnbd_model._add_fit_data_group(pnbd_model.data) + ).prior + set_model_fit(pnbd_model, fake_fit) return pnbd_model @@ -136,11 +131,8 @@ def fitted_gg(test_summary_data) -> GammaGammaModel: model.build_model() fake_fit = pm.sample_prior_predictive( samples=50, model=model.model, random_seed=rng - ) - fake_fit.add_groups(dict(posterior=fake_fit.prior)) - model.idata = fake_fit - model.set_idata_attrs(model.idata) - model._add_fit_data_group(model.data) + ).prior + set_model_fit(model, fake_fit) return model diff --git a/tests/clv/utils.py b/tests/clv/utils.py new file mode 100644 index 000000000..db329cdf2 --- /dev/null +++ b/tests/clv/utils.py @@ -0,0 +1,18 @@ +from typing import Union + +from arviz import InferenceData +from xarray import Dataset + +from pymc_marketing.clv.models import CLVModel + + +def set_model_fit(model: CLVModel, fit: Union[InferenceData, Dataset]): + if isinstance(fit, InferenceData): + assert "posterior" in fit.groups() + else: + fit = InferenceData(posterior=fit) + if model.model is None: + model.build_model() + model.idata = fit + model.idata.add_groups(fit_data=model.data.to_xarray()) + model.set_idata_attrs(fit)