From 642cb2e21b2ddb13dd6f7fa7acc63df0db43e2a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Fri, 14 Jun 2024 13:26:40 +0200 Subject: [PATCH] Add classes_ to cfe --- metalearners/cross_fit_estimator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py index e26d8981..97ef1057 100644 --- a/metalearners/cross_fit_estimator.py +++ b/metalearners/cross_fit_estimator.py @@ -101,6 +101,7 @@ class CrossFitEstimator: _overall_estimator: _ScikitModel | None = field(init=False) _test_indices: tuple[np.ndarray] | None = field(init=False) _n_classes: int | None = field(init=False) + classes_: np.ndarray | None = field(init=False) def __post_init__(self): _validate_n_folds(self.n_folds) @@ -115,6 +116,7 @@ def __post_init__(self): self._overall_estimator: _ScikitModel | None = None self._test_indices: tuple[np.ndarray] | None = None self._n_classes: int | None = None + self.classes_: np.ndarray | None = None def _train_overall_estimator( self, X: Matrix, y: Matrix | Vector, fit_params: dict | None = None @@ -189,7 +191,14 @@ def fit( if is_classifier(self): self._n_classes = len(np.unique(y)) - + self.classes_ = np.unique(y) + for e in self._estimators: + if set(e.classes_) != set(self.classes_): # type: ignore + raise ValueError( + "Some cross fit estimators training data had less classes than " + "the overall estimator. Please check the cv parameter. If you are " + "synchronizing the folds in a MetaLearner consider not doing it." + ) return self def _initialize_prediction_tensor(