Skip to content

Commit

Permalink
fix issues on the pytest version
Browse files Browse the repository at this point in the history
  • Loading branch information
pswpswpsw committed Feb 8, 2024
1 parent 3a7e869 commit 45bbfe0
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 189 deletions.
2 changes: 1 addition & 1 deletion src/pykoopman/koopman.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from .regression import DMDc
from .regression import EDMDc
from .regression import EnsembleBaseRegressor
from .regression import PyDMDRegressor
from .regression import NNDMD
from .regression import PyDMDRegressor


class Koopman(BaseEstimator):
Expand Down
18 changes: 3 additions & 15 deletions src/pykoopman/regression/_base_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""
from __future__ import annotations

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import clone
from sklearn.base import TransformerMixin
Expand Down Expand Up @@ -78,16 +77,10 @@ def fit(self, X, y, **fit_params):
functions.
"""

# if (
# isinstance(X, np.ndarray)
# and isinstance(y, np.ndarray)
# and X.ndim == 2
# and y.ndim == 2
# ):
# case 2: x, y are 2D np.ndarray, must be 1-step, no validation
self._training_dim = y.ndim
# transformers are designed to modify X which is 2d dimensional, we
# need to modify y accordingly.

self._training_dim = y.ndim
if y.ndim == 1:
y_2d = y.reshape(-1, 1)
else:
Expand All @@ -104,17 +97,12 @@ def fit(self, X, y, **fit_params):

if self.regressor is None:
from sklearn.linear_model import LinearRegression

self.regressor_ = LinearRegression()
else:
self.regressor_ = clone(self.regressor)

self.regressor_.fit(X, y_trans, **fit_params)
# elif isinstance(X, list) and isinstance(y, list):
# # case 4: x, y are two lists of trajectories, we have validation data
# for
#
# else:
# raise ValueError("check `x` and `y` for `self.fit`")

if hasattr(self.regressor_, "feature_names_in_"):
self.feature_names_in_ = self.regressor_.feature_names_in_
Expand Down
Loading

0 comments on commit 45bbfe0

Please sign in to comment.