diff --git a/rehline/_class.py b/rehline/_class.py index ed4c2a8..6bdcc28 100644 --- a/rehline/_class.py +++ b/rehline/_class.py @@ -5,8 +5,11 @@ # License: MIT License +import warnings + import numpy as np from sklearn.base import BaseEstimator +from sklearn.exceptions import ConvergenceWarning from sklearn.utils.validation import (_check_sample_weight, check_array, check_is_fitted, check_X_y) @@ -167,6 +170,12 @@ def fit(self, X, sample_weight=None): self.dual_obj_ = result.dual_objfns self.primal_obj_ = result.primal_objfns + if self.n_iter_ >= self.max_iter: + warnings.warn( + "ReHLine failed to converge, increase the number of iterations: `max_iter`.", + ConvergenceWarning, + ) + def decision_function(self, X): """The decision function evaluated on the given dataset @@ -363,6 +372,12 @@ def fit(self, X, y, sample_weight=None): self.dual_obj_ = result.dual_objfns self.primal_obj_ = result.primal_objfns + if self.n_iter_ >= self.max_iter: + warnings.warn( + "ReHLine failed to converge, increase the number of iterations: `max_iter`.", + ConvergenceWarning, + ) + def decision_function(self, X): """The decision function evaluated on the given dataset