diff --git a/fedot_ind/api/main.py b/fedot_ind/api/main.py index 95fc2f5c9..676d24793 100644 --- a/fedot_ind/api/main.py +++ b/fedot_ind/api/main.py @@ -175,10 +175,9 @@ def __abstract_predict(self, predict_mode: str): custom_predict = self.manager.solver.predict if self.manager.industrial_config.is_default_fedot_context \ else self.manager.industrial_config.strategy.predict - predict_function = Either(value=custom_predict, - monoid=['prob', predict_mode == 'labels']).either( - left_function=lambda prob_func: self.manager.solver.predict_proba if have_proba_output else self.manager.solver.predict, - right_function=lambda label_func: label_func) + predict_function = Either(value=custom_predict, monoid=['prob', predict_mode == 'labels']).either( + left_function=lambda prob_func: self.manager.solver.predict_proba + if have_proba_output else self.manager.solver.predict, right_function=lambda label_func: label_func) def _inverse_encoder_transform(predict): predicted_labels = self.target_encoder.inverse_transform(predict) @@ -425,14 +424,20 @@ def save_optimization_history(self, return_history: bool = False): f"optimization_history.json") def save_best_model(self): - Either(value=self.manager.solver, - monoid=[self.manager.solver, self.manager.condition_check.solver_is_fedot_class(self.manager.solver)]).either( - left_function=lambda pipeline: pipeline.save(path=self.manager.output_folder, - create_subdir=True, - is_datetime_in_path=True), - right_function=lambda solver: solver.current_pipeline.save(path=self.manager.output_folder, - create_subdir=True, - is_datetime_in_path=True)) + Either( + value=self.manager.solver, + monoid=[ + self.manager.solver, + self.manager.condition_check.solver_is_fedot_class( + self.manager.solver)]).either( + left_function=lambda pipeline: pipeline.save( + path=self.manager.output_folder, + create_subdir=True, + is_datetime_in_path=True), + right_function=lambda solver: solver.current_pipeline.save( + path=self.manager.output_folder, + create_subdir=True, + is_datetime_in_path=True)) def explain(self, explaing_config: dict = {}): """Explain model's prediction via time series points perturbation