From 4734fece999a876bb815742e692aa3c02f6470aa Mon Sep 17 00:00:00 2001 From: Lopa10ko Date: Wed, 29 Jan 2025 15:50:23 +0300 Subject: [PATCH] fix: add new api test_api_main, parametrize and update explain method --- fedot_ind/tools/explain/explain.py | 13 +-- tests/unit/api/main/test_api_main.py | 119 ++++++++++++++------------- 2 files changed, 62 insertions(+), 70 deletions(-) diff --git a/fedot_ind/tools/explain/explain.py b/fedot_ind/tools/explain/explain.py index 16e7e39e7..dd3ca1f4c 100644 --- a/fedot_ind/tools/explain/explain.py +++ b/fedot_ind/tools/explain/explain.py @@ -23,8 +23,8 @@ def explain(self, **kwargs): @staticmethod def predict_proba(model, features, target): - if hasattr(model, 'solver'): - model.solver.test_features = None + if hasattr(model, 'manager') and hasattr(model.manager, 'solver'): + model.manager.solver.test_features = None base_proba_ = model.predict_proba(predict_data=(features, target)) else: base_proba_ = model.predict_proba(X=features) @@ -59,15 +59,6 @@ def visual(self, metric: str = 'mean', name: str = 'test', threshold: float = No plt.savefig(f'recurrence_matrix_for_{name}_dataset_cls_{classes}.png') plt.close() - @staticmethod - def predict_proba(model, features, target): - if hasattr(model, 'solver'): - model.solver.test_features = None - base_proba_ = model.predict_proba(predict_data=(features, target)) - else: - base_proba_ = model.predict_proba(X=features) - return base_proba_ - class PointExplainer(Explainer): def __init__(self, model, features, target): diff --git a/tests/unit/api/main/test_api_main.py b/tests/unit/api/main/test_api_main.py index 8ad2b4412..23320d1a5 100644 --- a/tests/unit/api/main/test_api_main.py +++ b/tests/unit/api/main/test_api_main.py @@ -5,7 +5,9 @@ from matplotlib import get_backend, pyplot as plt from fedot_ind.api.main import FedotIndustrial -from fedot_ind.api.utils.data import SynthTimeSeriesData +from fedot_ind.core.repository.config_repository import DEFAULT_CLF_AUTOML_CONFIG, DEFAULT_COMPUTE_CONFIG, \ + DEFAULT_REG_AUTOML_CONFIG +from fedot_ind.tools.synthetic.synth_ts_data import SynthTimeSeriesData from fedot_ind.tools.synthetic.ts_datasets_generator import TimeSeriesDatasetsGenerator @@ -45,58 +47,54 @@ def multivariate_regression_data(): return train_data -@pytest.fixture def fedot_industrial_classification(): - return FedotIndustrial(problem='classification', timeout=0.1, logging_level=50) + AUTOML_LEARNING_STRATEGY = dict(timeout=0.1, + logging_level=50) + LEARNING_CONFIG = {'learning_strategy': 'from_scratch', + 'learning_strategy_params': AUTOML_LEARNING_STRATEGY, + 'optimisation_loss': {'quality_loss': 'f1'}} + INDUSTRIAL_CONFIG = {'problem': 'classification'} + API_CONFIG = {'industrial_config': INDUSTRIAL_CONFIG, + 'automl_config': DEFAULT_CLF_AUTOML_CONFIG, + 'learning_config': LEARNING_CONFIG, + 'compute_config': DEFAULT_COMPUTE_CONFIG} + return FedotIndustrial(**API_CONFIG) -@pytest.fixture def fedot_industrial_regression(): - return FedotIndustrial(problem='regression', timeout=0.1, logging_level=50) - - -def test_fit_predict_classification_multi(fedot_industrial_classification): - data = multivariate_clf_data() - fedot_industrial_classification.fit(data) - predict = fedot_industrial_classification.predict(data) - predict_proba = fedot_industrial_classification.predict_proba(data) - metrics = fedot_industrial_classification.get_metrics(target=data[1]) - np.unique(data[1]) + AUTOML_LEARNING_STRATEGY = dict(timeout=0.1, + logging_level=50) + LEARNING_CONFIG = {'learning_strategy': 'from_scratch', + 'learning_strategy_params': AUTOML_LEARNING_STRATEGY, + 'optimisation_loss': {'quality_loss': 'rmse'}} + INDUSTRIAL_CONFIG = {'problem': 'regression'} + API_CONFIG = {'industrial_config': INDUSTRIAL_CONFIG, + 'automl_config': DEFAULT_REG_AUTOML_CONFIG, + 'learning_config': LEARNING_CONFIG, + 'compute_config': DEFAULT_COMPUTE_CONFIG} + return FedotIndustrial(**API_CONFIG) + + +@pytest.mark.parametrize('metric_names, data_func, fedot_func', ( + [('f1', 'accuracy'), univariate_clf_data, fedot_industrial_classification], + [('f1', 'accuracy'), multivariate_clf_data, fedot_industrial_classification], + [('rmse', 'mae'), univariate_regression_data, fedot_industrial_regression], + [('rmse', 'mae'), multivariate_regression_data, fedot_industrial_regression], +), ids=['clf_uni', 'clf_multi', 'reg_uni', 'reg_multi']) +def test_fit_predict_fedot_industrial(metric_names, data_func, fedot_func): + data = data_func() + fedot_industrial = fedot_func() + fedot_industrial.fit(data) + predict = fedot_industrial.predict(data) + predict_proba = fedot_industrial.predict_proba(data) + metrics = fedot_industrial.get_metrics(predict, + predict_proba, + target=data[1], + metric_names=metric_names) assert predict.shape[0] == data[1].shape[0] assert predict_proba.shape[0] == data[1].shape[0] assert metrics is not None - - -def test_fit_predict_classification_uni(fedot_industrial_classification): - data = univariate_clf_data() - fedot_industrial_classification.fit(data) - predict = fedot_industrial_classification.predict(data) - predict_proba = fedot_industrial_classification.predict_proba(data) - metrics = fedot_industrial_classification.get_metrics(target=data[1]) - np.unique(data[1]) - - assert predict.shape[0] == data[1].shape[0] - assert predict_proba.shape[0] == data[1].shape[0] - assert metrics is not None - - -def test_fit_predict_regression_uni(fedot_industrial_regression): - data = univariate_regression_data() - fedot_industrial_regression.fit(data) - predict = fedot_industrial_regression.predict(data) - - assert predict.shape[0] == data[1].shape[0] - if len(data[1].shape) > 1: - assert predict.shape[1] == data[1].shape[1] - - -def test_fit_predict_regression_multi(fedot_industrial_regression): - data = multivariate_regression_data() - fedot_industrial_regression.fit(data) - predict = fedot_industrial_regression.predict(data) - - assert predict.shape[0] == data[1].shape[0] if len(data[1].shape) > 1: assert predict.shape[1] == data[1].shape[1] @@ -108,7 +106,7 @@ def ts_config(): 'start_val': 36.6}) -def test_generate_ts(fedot_industrial_classification, ts_config): +def test_generate_ts(ts_config): ts = SynthTimeSeriesData(ts_config).generate_ts() assert isinstance(ts, np.ndarray) @@ -124,10 +122,7 @@ def anomaly_config(): } -def test_generate_anomaly_ts( - fedot_industrial_classification, - ts_config, - anomaly_config): +def test_generate_anomaly_ts(ts_config, anomaly_config): init_synth_ts, mod_synth_ts, synth_inters = SynthTimeSeriesData(anomaly_config).generate_anomaly_ts(ts_config) assert len(init_synth_ts) == len(mod_synth_ts) for anomaly_type in synth_inters: @@ -136,16 +131,22 @@ def test_generate_anomaly_ts( assert interval[0] in ts_range and interval[1] in ts_range -def test_finetune(fedot_industrial_classification): - industrial = fedot_industrial_classification - data = univariate_clf_data() - industrial.fit(data) - industrial.finetune(train_data=data, tuning_params={'tuning_timeout': 0.1}) - assert industrial.solver is not None +@pytest.mark.parametrize('data_func, fedot_func', ( + [univariate_clf_data, fedot_industrial_classification], + [multivariate_clf_data, fedot_industrial_classification], + [univariate_regression_data, fedot_industrial_regression], + [multivariate_regression_data, fedot_industrial_regression], +), ids=['clf_uni', 'clf_multi', 'reg_uni', 'reg_multi']) +def test_finetune(data_func, fedot_func): + data = data_func() + fedot_industrial = fedot_func() + fedot_industrial.fit(data) + fedot_industrial.finetune(train_data=data, tuning_params={'tuning_timeout': 0.1}) + assert fedot_industrial.solver is not None -def test_plot_methods(fedot_industrial_classification): - industrial = fedot_industrial_classification +def test_plot_methods(): + industrial = fedot_industrial_classification() data = univariate_clf_data() industrial.fit(data) industrial.predict(data) @@ -156,4 +157,4 @@ def test_plot_methods(fedot_industrial_classification): get_backend() plt.switch_backend("Agg") warnings.filterwarnings("ignore", "Matplotlib is currently using agg") - fedot_industrial_classification.explain() + industrial.explain()