From 3d23dc079d765e41224bd3b24f031344858fc73c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 11 Dec 2024 19:27:23 -0800 Subject: [PATCH] feat: compat with new sklearn version --- pysr/sr.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index 3cae1735..c23805b4 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -58,6 +58,13 @@ _suggest_keywords, ) +try: + from sklearn.utils.validation import validate_data + + OLD_SKLEARN = False +except ImportError: + OLD_SKLEARN = True + ALREADY_RAN = False @@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params( ) def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]: - raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore + if OLD_SKLEARN: + raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore + else: + raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore return cast(tuple[ndarray, ndarray], raw_out) def _validate_data_X(self, X: Any) -> ndarray: - raw_out = self._validate_data(X=X, reset=False) # type: ignore + if OLD_SKLEARN: + raw_out = self._validate_data(X=X, reset=False) # type: ignore + else: + raw_out = validate_data(self, X=X, reset=False) # type: ignore return cast(ndarray, raw_out) def _get_precision_mapped_dtype(self, X: np.ndarray) -> type: