Skip to content

Commit

Permalink
added shrinkage to riemann extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
technocreep committed Mar 21, 2024
1 parent 8981efc commit b68a328
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
43 changes: 41 additions & 2 deletions fedot_ind/core/models/manifold/riemann_embeding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from pyriemann.estimation import XdawnCovariances, Covariances
import pandas as pd
from pyriemann.estimation import XdawnCovariances, Covariances, Shrinkage
from pyriemann.tangentspace import TangentSpace
from typing import Optional
from sklearn.utils.extmath import softmax
Expand Down Expand Up @@ -66,13 +67,16 @@ def __init__(self, params: Optional[OperationParameters] = None):
def _init_spaces(self):
self.covarince_transformer = Covariances(estimator='scm')
self.tangent_projector = TangentSpace(metric=self.distance_metric)
self.shinkage = Shrinkage()

def extract_riemann_features(self, input_data: InputData) -> InputData:
if not self.fit_stage:
SPD = self.covarince_transformer.transform(input_data.features)
SPD = self.shinkage.transform(SPD)
ref_point = self.tangent_projector.transform(SPD)
else:
SPD = self.covarince_transformer.fit_transform(input_data.features, input_data.target)
SPD = self.shinkage.fit_transform(SPD)
ref_point = self.tangent_projector.fit_transform(SPD)
self.fit_stage = False
return ref_point
Expand All @@ -81,9 +85,11 @@ def extract_centroid_distance(self, input_data: InputData):
self.classes_ = np.unique(input_data.target)
if not self.fit_stage:
SPD = self.covarince_transformer.transform(input_data.features)
SPD = self.shinkage.transform(SPD)
else:
SPD = self.covarince_transformer.fit_transform(input_data.features, input_data.target)
self.covmeans_ = [mean_covariance(SPD[input_data.target == ll], metric=self.covariance_metric)
SPD = self.shinkage.fit_transform(SPD)
self.covmeans_ = [mean_covariance(SPD[input_data.target.flatten() == ll], metric=self.covariance_metric)
for ll in self.classes_]

n_centroids = len(self.covmeans_)
Expand All @@ -106,3 +112,36 @@ def _transform(self, input_data: InputData) -> np.array:
feature_matrix = self.extraction_func(input_data)
self.predict = self._clean_predict(feature_matrix)
return self.predict


if __name__ == "__main__":
from fedot.core.pipelines.pipeline_builder import PipelineBuilder
from sklearn.metrics import accuracy_score
from fedot_ind.api.utils.data import init_input_data
from fedot_ind.core.repository.initializer_industrial_models import IndustrialModels
from sklearn.model_selection import train_test_split

path_x = '/Users/technocreep/Desktop/Working-Folder/fedot-industrial/Fedot.Industrial/0_lavence/valence_data/pt1/X_eeg_4_24_old_resave.npy'
path_y = '/Users/technocreep/Desktop/Working-Folder/fedot-industrial/Fedot.Industrial/0_lavence/valence_data/pt1/y_old.npy'
X = np.load(path_x)
y = np.load(path_y)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)
x_train = pd.DataFrame([[pd.Series(i) for i in k] for k in X_train])
x_test = pd.DataFrame([[pd.Series(i) for i in k] for k in X_test])

init_train = init_input_data(x_train, y_train)
init_test = init_input_data(x_test, y_test)

with IndustrialModels():
# riemann_extractor
# pipeline = PipelineBuilder().add_node('eigen_basis').add_node('quantile_extractor').add_node('rf').build()
pipeline = PipelineBuilder().add_node('riemann_extractor', params={'n_filter': 3})\
.add_node('quantile_extractor')\
.add_node('mlp')\
.build()
pipeline.fit(init_train)
pred = pipeline.predict(init_test)
acc = accuracy_score(y_test, np.round(pred.predict))
print(acc)

5 changes: 3 additions & 2 deletions fedot_ind/core/repository/model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
*

# FastTopologicalFeaturesImplementation
from fedot.core.operations.evaluation.operation_implementations.data_operations.topological.topological_extractor import \
TopologicalFeaturesImplementation
# from fedot.core.operations.evaluation.operation_implementations.data_operations.topological.topological_extractor import \
# TopologicalFeaturesImplementation
from fedot.core.operations.evaluation.operation_implementations.data_operations.topological.fast_topological_extractor import TopologicalFeaturesImplementation
from fedot.core.operations.evaluation.operation_implementations.data_operations.ts_transformations import \
ExogDataTransformationImplementation, GaussianFilterImplementation, LaggedTransformationImplementation, \
SparseLaggedTransformationImplementation, TsSmoothingImplementation
Expand Down

0 comments on commit b68a328

Please sign in to comment.