Commit
added explain method to main api module
technocreep committed Jan 18, 2024
1 parent 7ad72a3 commit 487bd7b
Showing 2 changed files with 50 additions and 33 deletions.
31 changes: 29 additions & 2 deletions fedot_ind/api/main.py
@@ -1,19 +1,22 @@
import logging
from pathlib import Path

from fedot.api.main import Fedot
from fedot.core.pipelines.node import PipelineNode
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.pipelines.tuning.tuner_builder import TunerBuilder
from fedot.core.repository.metrics_repository import ClassificationMetricsEnum
from golem.core.tuning.simultaneous import SimultaneousTuner

from fedot_ind.api.utils.checkers_collections import DataCheck
from fedot_ind.api.utils.path_lib import DEFAULT_PATH_RESULTS as default_path_to_save_results
from fedot_ind.core.architecture.settings.computational import BackendMethods
from fedot_ind.core.ensemble.random_automl_forest import RAFensembler
from fedot_ind.core.operation.transformation.splitter import TSTransformer
from fedot_ind.core.repository.constanst_repository import FEDOT_WORKER_NUM, BATCH_SIZE_FOR_FEDOT_WORKER, \
from fedot_ind.core.repository.constanst_repository import BATCH_SIZE_FOR_FEDOT_WORKER, FEDOT_WORKER_NUM, \
FEDOT_WORKER_TIMEOUT_PARTITION
from fedot_ind.core.repository.initializer_industrial_models import IndustrialModels
from fedot_ind.tools.explain.explain import PointExplainer
from fedot_ind.tools.synthetic.anomaly_generator import AnomalyGenerator
from fedot_ind.tools.synthetic.ts_generator import TimeSeriesGenerator

@@ -317,7 +320,31 @@ def plot_operation_distribution(self, mode: str = 'total'):
show_fitness=True, dpi=100)

def explain(self, **kwargs):
raise NotImplementedError()
""" Explain model's prediction via time series points perturbation
Args:
samples: int, ``default=1``. Number of samples to explain.
window: int, ``default=5``. Window size for perturbation.
metric: str ``default='rmse'``. Distance metric for perturbation impact assessment.
threshold: int, ``default=90``. Threshold for perturbation impact assessment.
name: str, ``default='test'``. Name of the dataset to be placed on plot.
"""
methods = {'point': PointExplainer,
'shap': NotImplementedError,
'lime': NotImplementedError}

explainer = methods[kwargs.get('method', 'point')](model=self.solver,
features=self.predict_data.features,
target=self.predict_data.target)
metric = kwargs.get('metric', 'rmse')
window = kwargs.get('window', 5)
samples = kwargs.get('samples', 1)
threshold = kwargs.get('threshold', 90)
name = kwargs.get('name', 'test')

explainer.explain(n_samples=samples, window=window, method=metric)
explainer.visual(threshold=threshold, name=name)

def generate_ts(self, ts_config: dict):
"""
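For context, a minimal usage sketch of the new explain call (not part of the commit). The class name FedotIndustrial, its constructor arguments, the fit/predict signatures, and the toy data are assumptions about the surrounding API; only the explain(**kwargs) interface itself comes from the diff above.

import numpy as np

from fedot_ind.api.main import FedotIndustrial  # assumed name of the main API class

# Hypothetical toy data: 20 univariate series of length 100 with binary labels.
train_features, train_target = np.random.rand(20, 100), np.random.randint(0, 2, 20)
test_features, test_target = np.random.rand(5, 100), np.random.randint(0, 2, 5)

industrial = FedotIndustrial(problem='classification', timeout=1)  # hypothetical constructor arguments
industrial.fit(input_data=(train_features, train_target))          # hypothetical fit signature
industrial.predict(predict_data=(test_features, test_target))      # populates the predict_data used by explain()

# The new call: perturb 3 samples with a 10-point window and highlight
# time points whose impact exceeds the 90th-percentile threshold.
industrial.explain(method='point', samples=3, window=10, metric='rmse', threshold=90, name='toy_dataset')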
52 changes: 21 additions & 31 deletions fedot_ind/tools/explain/explain.py
@@ -1,26 +1,41 @@
import math

# import lime
# import lime.lime_tabular
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import shap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from tqdm import tqdm

from fedot_ind.tools.explain.distances import DistanceTypes


class PointExplainer:
class Explainer:
def __init__(self, model, features, target):
self.picked_target = None
self.picked_feature = None
self.model = model
self.features = features
self.target = target

def explain(self, **kwargs):
pass

@staticmethod
def predict_proba(model, features, target):
if hasattr(model, 'solver'):
model.solver.test_features = None
base_proba_ = model.predict_proba(
features=pd.DataFrame(features), target=target)
else:
base_proba_ = model.predict_proba(X=features)
return base_proba_


class PointExplainer(Explainer):
def __init__(self, model, features, target):
super().__init__(model, features, target)
self.picked_target = None
self.picked_feature = None

self.scaled_vector = None
self.window_length = None
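The refactoring above extracts a shared Explainer base class that owns the model, features, and target along with the predict_proba helper, so future explainers (the 'shap' and 'lime' placeholders in main.py) only need to implement explain(). A hypothetical subclass sketch, not part of the commit, assuming features is a 2-D numpy array:

import numpy as np

from fedot_ind.tools.explain.explain import Explainer


class MeanAblationExplainer(Explainer):  # hypothetical name, for illustration only
    """Scores each time point by how much replacing it with its mean shifts the predicted probabilities."""

    def explain(self, **kwargs):
        base_proba = self.predict_proba(self.model, self.features, self.target)
        scores = []
        for i in range(self.features.shape[1]):
            perturbed = np.array(self.features, dtype=float)
            perturbed[:, i] = perturbed[:, i].mean()  # ablate one time point across all samples
            proba = self.predict_proba(self.model, perturbed, self.target)
            scores.append(float(np.abs(base_proba - proba).mean()))
        self.importance_ = np.array(scores)
        return self.importance_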

@@ -89,16 +104,6 @@ def replace_values(features: np.ndarray, window_len: int, i: int):
features[idx, i] = mean_ts
return features

@staticmethod
def predict_proba(model, features, target):
if hasattr(model, 'solver'):
model.solver.test_features = None
base_proba_ = model.predict_proba(
features=pd.DataFrame(features), target=target)
else:
base_proba_ = model.predict_proba(X=features)
return base_proba_

@staticmethod
def select(features_, target_, n_samples_: int = 3):
selected_df = pd.DataFrame()
@@ -174,18 +179,3 @@ def plot_importance(self, thr=90, name='dataset'):
plt.colorbar(scal_map, cax=cbar_ax)
plt.tight_layout()
plt.show()


# class ShapExplainer:
# def __init__(self, model, features, target, prediction):
# self.model = model
# self.features = features
# self.target = target
# self.prediction = prediction
#
# def explain(self, n_samples: int = 5):
# X_test = self.features
#
# explainer = shap.KernelExplainer(self.model.predict, X_test, n_samples=n_samples)
# shap_values = explainer.shap_values(X_test.iloc[:n_samples, :])
# shap.summary_plot(shap_values, X_test.iloc[:n_samples, :], plot_type="bar")
