Skip to content

Commit

Permalink
fix: add new api test_api_main, parametrize and update explain method
Browse files Browse the repository at this point in the history
  • Loading branch information
Lopa10ko committed Jan 29, 2025
1 parent f0733ac commit 4734fec
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 70 deletions.
13 changes: 2 additions & 11 deletions fedot_ind/tools/explain/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def explain(self, **kwargs):

@staticmethod
def predict_proba(model, features, target):
if hasattr(model, 'solver'):
model.solver.test_features = None
if hasattr(model, 'manager') and hasattr(model.manager, 'solver'):
model.manager.solver.test_features = None
base_proba_ = model.predict_proba(predict_data=(features, target))
else:
base_proba_ = model.predict_proba(X=features)
Expand Down Expand Up @@ -59,15 +59,6 @@ def visual(self, metric: str = 'mean', name: str = 'test', threshold: float = No
plt.savefig(f'recurrence_matrix_for_{name}_dataset_cls_{classes}.png')
plt.close()

@staticmethod
def predict_proba(model, features, target):
if hasattr(model, 'solver'):
model.solver.test_features = None
base_proba_ = model.predict_proba(predict_data=(features, target))
else:
base_proba_ = model.predict_proba(X=features)
return base_proba_


class PointExplainer(Explainer):
def __init__(self, model, features, target):
Expand Down
119 changes: 60 additions & 59 deletions tests/unit/api/main/test_api_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from matplotlib import get_backend, pyplot as plt

from fedot_ind.api.main import FedotIndustrial
from fedot_ind.api.utils.data import SynthTimeSeriesData
from fedot_ind.core.repository.config_repository import DEFAULT_CLF_AUTOML_CONFIG, DEFAULT_COMPUTE_CONFIG, \
DEFAULT_REG_AUTOML_CONFIG
from fedot_ind.tools.synthetic.synth_ts_data import SynthTimeSeriesData
from fedot_ind.tools.synthetic.ts_datasets_generator import TimeSeriesDatasetsGenerator


Expand Down Expand Up @@ -45,58 +47,54 @@ def multivariate_regression_data():
return train_data


@pytest.fixture
def fedot_industrial_classification():
return FedotIndustrial(problem='classification', timeout=0.1, logging_level=50)
AUTOML_LEARNING_STRATEGY = dict(timeout=0.1,
logging_level=50)
LEARNING_CONFIG = {'learning_strategy': 'from_scratch',
'learning_strategy_params': AUTOML_LEARNING_STRATEGY,
'optimisation_loss': {'quality_loss': 'f1'}}
INDUSTRIAL_CONFIG = {'problem': 'classification'}
API_CONFIG = {'industrial_config': INDUSTRIAL_CONFIG,
'automl_config': DEFAULT_CLF_AUTOML_CONFIG,
'learning_config': LEARNING_CONFIG,
'compute_config': DEFAULT_COMPUTE_CONFIG}
return FedotIndustrial(**API_CONFIG)


@pytest.fixture
def fedot_industrial_regression():
return FedotIndustrial(problem='regression', timeout=0.1, logging_level=50)


def test_fit_predict_classification_multi(fedot_industrial_classification):
data = multivariate_clf_data()
fedot_industrial_classification.fit(data)
predict = fedot_industrial_classification.predict(data)
predict_proba = fedot_industrial_classification.predict_proba(data)
metrics = fedot_industrial_classification.get_metrics(target=data[1])
np.unique(data[1])
AUTOML_LEARNING_STRATEGY = dict(timeout=0.1,
logging_level=50)
LEARNING_CONFIG = {'learning_strategy': 'from_scratch',
'learning_strategy_params': AUTOML_LEARNING_STRATEGY,
'optimisation_loss': {'quality_loss': 'rmse'}}
INDUSTRIAL_CONFIG = {'problem': 'regression'}
API_CONFIG = {'industrial_config': INDUSTRIAL_CONFIG,
'automl_config': DEFAULT_REG_AUTOML_CONFIG,
'learning_config': LEARNING_CONFIG,
'compute_config': DEFAULT_COMPUTE_CONFIG}
return FedotIndustrial(**API_CONFIG)


@pytest.mark.parametrize('metric_names, data_func, fedot_func', (
[('f1', 'accuracy'), univariate_clf_data, fedot_industrial_classification],
[('f1', 'accuracy'), multivariate_clf_data, fedot_industrial_classification],
[('rmse', 'mae'), univariate_regression_data, fedot_industrial_regression],
[('rmse', 'mae'), multivariate_regression_data, fedot_industrial_regression],
), ids=['clf_uni', 'clf_multi', 'reg_uni', 'reg_multi'])
def test_fit_predict_fedot_industrial(metric_names, data_func, fedot_func):
data = data_func()
fedot_industrial = fedot_func()
fedot_industrial.fit(data)
predict = fedot_industrial.predict(data)
predict_proba = fedot_industrial.predict_proba(data)
metrics = fedot_industrial.get_metrics(predict,
predict_proba,
target=data[1],
metric_names=metric_names)

assert predict.shape[0] == data[1].shape[0]
assert predict_proba.shape[0] == data[1].shape[0]
assert metrics is not None


def test_fit_predict_classification_uni(fedot_industrial_classification):
data = univariate_clf_data()
fedot_industrial_classification.fit(data)
predict = fedot_industrial_classification.predict(data)
predict_proba = fedot_industrial_classification.predict_proba(data)
metrics = fedot_industrial_classification.get_metrics(target=data[1])
np.unique(data[1])

assert predict.shape[0] == data[1].shape[0]
assert predict_proba.shape[0] == data[1].shape[0]
assert metrics is not None


def test_fit_predict_regression_uni(fedot_industrial_regression):
data = univariate_regression_data()
fedot_industrial_regression.fit(data)
predict = fedot_industrial_regression.predict(data)

assert predict.shape[0] == data[1].shape[0]
if len(data[1].shape) > 1:
assert predict.shape[1] == data[1].shape[1]


def test_fit_predict_regression_multi(fedot_industrial_regression):
data = multivariate_regression_data()
fedot_industrial_regression.fit(data)
predict = fedot_industrial_regression.predict(data)

assert predict.shape[0] == data[1].shape[0]
if len(data[1].shape) > 1:
assert predict.shape[1] == data[1].shape[1]

Expand All @@ -108,7 +106,7 @@ def ts_config():
'start_val': 36.6})


def test_generate_ts(fedot_industrial_classification, ts_config):
def test_generate_ts(ts_config):
ts = SynthTimeSeriesData(ts_config).generate_ts()

assert isinstance(ts, np.ndarray)
Expand All @@ -124,10 +122,7 @@ def anomaly_config():
}


def test_generate_anomaly_ts(
fedot_industrial_classification,
ts_config,
anomaly_config):
def test_generate_anomaly_ts(ts_config, anomaly_config):
init_synth_ts, mod_synth_ts, synth_inters = SynthTimeSeriesData(anomaly_config).generate_anomaly_ts(ts_config)
assert len(init_synth_ts) == len(mod_synth_ts)
for anomaly_type in synth_inters:
Expand All @@ -136,16 +131,22 @@ def test_generate_anomaly_ts(
assert interval[0] in ts_range and interval[1] in ts_range


def test_finetune(fedot_industrial_classification):
industrial = fedot_industrial_classification
data = univariate_clf_data()
industrial.fit(data)
industrial.finetune(train_data=data, tuning_params={'tuning_timeout': 0.1})
assert industrial.solver is not None
@pytest.mark.parametrize('data_func, fedot_func', (
[univariate_clf_data, fedot_industrial_classification],
[multivariate_clf_data, fedot_industrial_classification],
[univariate_regression_data, fedot_industrial_regression],
[multivariate_regression_data, fedot_industrial_regression],
), ids=['clf_uni', 'clf_multi', 'reg_uni', 'reg_multi'])
def test_finetune(data_func, fedot_func):
data = data_func()
fedot_industrial = fedot_func()
fedot_industrial.fit(data)
fedot_industrial.finetune(train_data=data, tuning_params={'tuning_timeout': 0.1})
assert fedot_industrial.solver is not None


def test_plot_methods(fedot_industrial_classification):
industrial = fedot_industrial_classification
def test_plot_methods():
industrial = fedot_industrial_classification()
data = univariate_clf_data()
industrial.fit(data)
industrial.predict(data)
Expand All @@ -156,4 +157,4 @@ def test_plot_methods(fedot_industrial_classification):
get_backend()
plt.switch_backend("Agg")
warnings.filterwarnings("ignore", "Matplotlib is currently using agg")
fedot_industrial_classification.explain()
industrial.explain()

0 comments on commit 4734fec

Please sign in to comment.