Skip to content

Commit

Permalink
Api fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
v1docq committed Jan 26, 2024
1 parent af707e1 commit 66ee030
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions fedot_ind/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, **kwargs):
self.preprocessing = kwargs.get('industrial_preprocessing', False)
self.backend_method = kwargs.get('backend', 'cpu')
self.RAF_workers = kwargs.get('RAF_workers', None)

if self.output_folder is None:
self.output_folder = default_path_to_save_results
Path(self.output_folder).mkdir(parents=True, exist_ok=True)
Expand All @@ -89,9 +90,11 @@ def __init__(self, **kwargs):
self.predicted_labels = None
self.predicted_probs = None
self.predict_data = None
self.config_dict = None
self.ensemble_solver = None
self.config_dict = kwargs
self.config_dict['available_operations'] = kwargs.get('available_operations',
default_industrial_availiable_operation(
self.config_dict['problem']))
self.config_dict['optimizer'] = kwargs.get('optimizer', IndustrialEvoOptimizer)
self.__init_experiment_setup()

def __init_experiment_setup(self):
Expand All @@ -101,9 +104,6 @@ def __init_experiment_setup(self):
backend_method_current, backend_scipy_current = BackendMethods(self.backend_method).backend
globals()['backend_methods'] = backend_method_current
globals()['backend_scipy'] = backend_scipy_current
self.config_dict['available_operations'] = default_industrial_availiable_operation(self.config_dict['problem'])

self.config_dict['optimizer'] = IndustrialEvoOptimizer

def __init_solver(self):
self.logger.info('Initialising Industrial Repository')
Expand Down Expand Up @@ -132,13 +132,9 @@ def _preprocessing_strategy(self, input_data):
batch_timeout = round(self.config_dict['timeout'] / FEDOT_WORKER_TIMEOUT_PARTITION)
self.config_dict['timeout'] = batch_timeout
self.logger.info(f'Batch_size - {batch_size}. Number of batches - {self.RAF_workers}')
self.ensemble_solver = RAFensembler(composing_params=self.config_dict, n_splits=self.RAF_workers,
batch_size=batch_size)
self.logger.info(f'Number of AutoMl models in ensemble - {self.ensemble_solver.n_splits}')
self.ensemble_solver.fit(input_data)
self.solver = self.ensemble_solver
else:
self.preprocessing = False
self.solver = RAFensembler(composing_params=self.config_dict, n_splits=self.RAF_workers,
batch_size=batch_size)
self.logger.info(f'Number of AutoMl models in ensemble - {self.solver.n_splits}')

def fit(self,
input_data,
Expand All @@ -161,10 +157,7 @@ def fit(self,
self.solver = self.__init_solver()
if self.preprocessing:
self._preprocessing_strategy(input_data)
fitted_pipeline = self.ensemble_solver
else:
fitted_pipeline = self.solver.fit(input_data)
return fitted_pipeline
return self.solver.fit(input_data)

def predict(self,
predict_data,
Expand Down

0 comments on commit 66ee030

Please sign in to comment.