diff --git a/mlquantify/base.py b/mlquantify/base.py index 4ae9780..9dba0f7 100644 --- a/mlquantify/base.py +++ b/mlquantify/base.py @@ -314,14 +314,14 @@ def fit_learner(self, X, y): y : array-like Training labels. """ - if mq.ARGUMENTS_SETTED: + if self.learner is not None: + if not self.learner_fitted: + self.learner_.fit(X, y) + elif mq.ARGUMENTS_SETTED: if self.is_probabilistic and mq.arguments["posteriors_test"] is not None: return elif not self.is_probabilistic and mq.arguments["y_pred"] is not None: return - else: - if not self.learner_fitted: - self.learner_.fit(X, y) def predict_learner(self, X): """Predict the class labels or probabilities for the given data.