From 287502683f20ec371043bce316908e4777a68b69 Mon Sep 17 00:00:00 2001
From: v1docq
Date: Thu, 2 May 2024 15:09:38 +0300
Subject: [PATCH] minor fixes

---
 .../benchmark_example/time_series_uni_clf_benchmark.py | 2 +-
 fedot_ind/core/metrics/metrics_implementation.py       | 8 ++++----
 requirements.txt                                        | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/benchmark_example/time_series_uni_clf_benchmark.py b/examples/benchmark_example/time_series_uni_clf_benchmark.py
index 7f4c86a73..336152578 100644
--- a/examples/benchmark_example/time_series_uni_clf_benchmark.py
+++ b/examples/benchmark_example/time_series_uni_clf_benchmark.py
@@ -10,7 +10,7 @@
     'logging_level': 10,
     'n_jobs': 2,
     'early_stopping_iterations': 5,
-    'initial_assumption': PipelineBuilder().add_node('chronos_extractor').add_node('logit'),
+    'initial_assumption': PipelineBuilder().add_node('quantile_extractor').add_node('logit'),
     'early_stopping_timeout': 75}
 
 if __name__ == "__main__":
diff --git a/fedot_ind/core/metrics/metrics_implementation.py b/fedot_ind/core/metrics/metrics_implementation.py
index abb4d5dda..81ee4b05e 100644
--- a/fedot_ind/core/metrics/metrics_implementation.py
+++ b/fedot_ind/core/metrics/metrics_implementation.py
@@ -171,11 +171,11 @@ def metric(self) -> float:
         return accuracy_score(y_true=self.target, y_pred=self.predicted_labels)
 
 
-def MASE(A, F, y_train):
+def mase(A, F, y_train):
     return mean_absolute_scaled_error(A, F, y_train=y_train)
 
 
-def SMAPE(a, f, _=None):
+def smape(a, f, _=None):
     return 1 / len(a) * np.sum(2 * np.abs(f - a) / (np.abs(a) + np.abs(f)) * 100)
 
 
@@ -221,8 +221,8 @@ def rmse(y_true, y_pred):
     'rmse': rmse,
     'mae': mean_absolute_error,
     'median_absolute_error': median_absolute_error,
-    'smape': SMAPE,
-    'mase': MASE
+    'smape': smape,
+    'mase': mase
 }
 
 df = pd.DataFrame({name: func(target, labels) for name, func in metric_dict.items()
diff --git a/requirements.txt b/requirements.txt
index 5adb834fd..67583a559 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,7 +15,7 @@ fastai==2.7.14
 distributed
 datasetsforecast==0.0.8
 tensorly==0.8.1
-torch==2.2.0
+torch~=2.2.0
 torchvision==0.17.0
 statsforecast==1.5.0
 chardet==5.2.0
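
Note (not part of the patch): as a quick sanity check of the lowercase metric rename, the snippet below mirrors the smape helper exactly as it appears in fedot_ind/core/metrics/metrics_implementation.py and evaluates it on toy arrays. The sample target/labels values are illustrative assumptions, not taken from the repository.

import numpy as np

def smape(a, f, _=None):
    # Symmetric MAPE, same expression as in the patched module.
    return 1 / len(a) * np.sum(2 * np.abs(f - a) / (np.abs(a) + np.abs(f)) * 100)

target = np.array([10.0, 20.0, 30.0])   # hypothetical ground-truth values
labels = np.array([12.0, 18.0, 33.0])   # hypothetical predictions
print(round(smape(target, labels), 2))  # ~12.74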