Skip to content

Commit

Permalink
add convergence warning
Browse files Browse the repository at this point in the history
  • Loading branch information
statmlben committed Aug 23, 2024
1 parent bf71128 commit 0bb03f5
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions rehline/_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

# License: MIT License

import warnings

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import (_check_sample_weight, check_array,
check_is_fitted, check_X_y)

Expand Down Expand Up @@ -167,6 +170,12 @@ def fit(self, X, sample_weight=None):
self.dual_obj_ = result.dual_objfns
self.primal_obj_ = result.primal_objfns

if self.n_iter_ >= self.max_iter:
warnings.warn(
"ReHLine failed to converge, increase the number of iterations: `max_iter`.",
ConvergenceWarning,
)

def decision_function(self, X):
"""The decision function evaluated on the given dataset
Expand Down Expand Up @@ -363,6 +372,12 @@ def fit(self, X, y, sample_weight=None):
self.dual_obj_ = result.dual_objfns
self.primal_obj_ = result.primal_objfns

if self.n_iter_ >= self.max_iter:
warnings.warn(
"ReHLine failed to converge, increase the number of iterations: `max_iter`.",
ConvergenceWarning,
)

def decision_function(self, X):
"""The decision function evaluated on the given dataset
Expand Down

0 comments on commit 0bb03f5

Please sign in to comment.