-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
291bcff
commit 3fc8740
Showing
10 changed files
with
231 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
tests/unit/core/operation/decomposition/test_physic_dmd.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from fedot_ind.core.operation.optimization.dmd.physic_dmd import piDMD | ||
|
||
|
||
@pytest.fixture | ||
def feature_target(): | ||
return np.random.rand(10, 10), np.random.rand(10, 10) | ||
|
||
|
||
@pytest.mark.parametrize('method', ('exact', 'orthogonal')) | ||
def test_fit_exact(feature_target, method): | ||
decomposer = piDMD(method=method) | ||
features, target = feature_target | ||
|
||
fitted_linear_operator, eigenvals, eigenvectors = decomposer.fit(train_features=features, | ||
train_target=target) | ||
for i in [eigenvals, eigenvectors]: | ||
assert isinstance(i, np.ndarray) | ||
assert isinstance(fitted_linear_operator, Callable) |
36 changes: 36 additions & 0 deletions
36
tests/unit/core/operation/optimization/test_feature_space.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from fedot_ind.core.operation.optimization.FeatureSpace import VarianceSelector | ||
|
||
|
||
@pytest.fixture | ||
def model_data(): | ||
return dict(quantile=np.random.rand(10, 10), | ||
signal=np.random.rand(10, 10), | ||
topological=np.random.rand(10, 10)) | ||
|
||
|
||
def test_get_best_model(model_data): | ||
selector = VarianceSelector(models=model_data) | ||
best_model = selector.get_best_model() | ||
assert isinstance(best_model, str) | ||
|
||
|
||
def test_transform(model_data): | ||
selector = VarianceSelector(models=model_data) | ||
projected = selector.transform(model_data=model_data['quantile'], | ||
principal_components=np.random.rand(10, 2)) | ||
assert isinstance(projected, np.ndarray) | ||
|
||
|
||
def test_select_discriminative_features(model_data): | ||
selector = VarianceSelector(models=model_data) | ||
projected = selector.transform(model_data=model_data['quantile'], | ||
principal_components=np.random.rand(10, 2)) | ||
|
||
discriminative_feature = selector.select_discriminative_features(model_data=pd.DataFrame(model_data['quantile']), | ||
projected_data=projected) | ||
|
||
assert isinstance(discriminative_feature, dict) |
51 changes: 51 additions & 0 deletions
51
tests/unit/core/operation/optimization/test_structure_optimization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import pytest | ||
import torch | ||
from torch import nn | ||
from torch.utils.data import DataLoader, Dataset | ||
|
||
from fedot_ind.core.architecture.experiment.nn_experimenter import ClassificationExperimenter, FitParameters | ||
from fedot_ind.core.operation.optimization.structure_optimization import SVDOptimization, SFPOptimization | ||
|
||
NUM_SAMPLES = 100 | ||
INPUT_SIZE = 10 | ||
OUTPUT_SIZE = 5 | ||
BATCH_SIZE = 32 | ||
|
||
|
||
class DummyModel(nn.Module): | ||
def __init__(self, input_size, output_size): | ||
super(DummyModel, self).__init__() | ||
self.linear = nn.Linear(input_size, output_size) | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
|
||
class SimpleDataset(Dataset): | ||
def __init__(self, num_samples, input_size, output_size): | ||
self.inputs = torch.rand((num_samples, input_size)) | ||
self.targets = torch.randint(0, output_size, (num_samples,)) | ||
|
||
def __len__(self): | ||
return len(self.inputs) | ||
|
||
def __getitem__(self, index): | ||
return self.inputs[index], self.targets[index] | ||
|
||
|
||
@pytest.fixture | ||
def dummy_data_loader(): | ||
dataset = SimpleDataset(NUM_SAMPLES, INPUT_SIZE, OUTPUT_SIZE) | ||
shuffle = True | ||
return DataLoader(dataset, | ||
batch_size=BATCH_SIZE, | ||
shuffle=shuffle) | ||
|
||
|
||
@pytest.fixture() | ||
def solver(): | ||
model = DummyModel(INPUT_SIZE, OUTPUT_SIZE) | ||
experimenter = ClassificationExperimenter(model=model, | ||
metric='accuracy', | ||
device='cpu') | ||
return experimenter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import warnings | ||
|
||
from matplotlib import get_backend | ||
|
||
from fedot_ind.tools.synthetic.anomaly_generator import AnomalyGenerator | ||
import pytest | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
@pytest.fixture | ||
def config(): | ||
return {'dip': {'level': 20, | ||
'number': 5, | ||
'min_anomaly_length': 10, | ||
'max_anomaly_length': 20}, | ||
'peak': {'level': 2, | ||
'number': 5, | ||
'min_anomaly_length': 5, | ||
'max_anomaly_length': 10}, | ||
'decrease_dispersion': {'level': 70, | ||
'number': 2, | ||
'min_anomaly_length': 10, | ||
'max_anomaly_length': 15}, | ||
'increase_dispersion': {'level': 50, | ||
'number': 2, | ||
'min_anomaly_length': 10, | ||
'max_anomaly_length': 15}, | ||
'shift_trend_up': {'level': 10, | ||
'number': 2, | ||
'min_anomaly_length': 10, | ||
'max_anomaly_length': 20}, | ||
'add_noise': {'level': 80, | ||
'number': 2, | ||
'noise_type': 'uniform', | ||
'min_anomaly_length': 10, | ||
'max_anomaly_length': 20} | ||
} | ||
|
||
|
||
@pytest.fixture | ||
def synthetic_ts(): | ||
return {'ts_type': 'sin', | ||
'length': 1000, | ||
'amplitude': 10, | ||
'period': 500} | ||
|
||
|
||
def test_generate(config, synthetic_ts): | ||
# switch to non-Gui, preventing plots being displayed | ||
# suppress UserWarning that agg cannot show plots | ||
curr_backend = get_backend() | ||
plt.switch_backend("Agg") | ||
warnings.filterwarnings("ignore", "Matplotlib is currently using agg") | ||
|
||
generator = AnomalyGenerator(config=config) | ||
init_synth_ts, mod_synth_ts, synth_inters = generator.generate(time_series_data=synthetic_ts, | ||
plot=True, | ||
overlap=0.1) | ||
|
||
assert len(init_synth_ts) == len(mod_synth_ts) | ||
for anomaly_type in synth_inters: | ||
for interval in synth_inters[anomaly_type]: | ||
ts_range = range(len(init_synth_ts)) | ||
assert interval[0] in ts_range and interval[1] in ts_range |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from fedot_ind.tools.synthetic.ts_datasets_generator import TimeSeriesDatasetsGenerator | ||
|
||
|
||
def test_generate_data(): | ||
generator = TimeSeriesDatasetsGenerator(num_samples=80, | ||
max_ts_len=50, | ||
n_classes=3, | ||
test_size=0.5) | ||
(X_train, y_train), (X_test, y_test) = generator.generate_data() | ||
|
||
assert X_train.shape[0] == X_test.shape[0] | ||
assert X_train.shape[1] == X_test.shape[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from fedot_ind.tools.synthetic.ts_generator import TimeSeriesGenerator | ||
|
||
|
||
@pytest.fixture | ||
def config(): | ||
return dict(random_walk={'ts_type': 'random_walk', | ||
'length': 1000, | ||
'start_val': 36.6}, | ||
sin={'ts_type': 'sin', | ||
'length': 1000, | ||
'amplitude': 10, | ||
'period': 500}, | ||
auto_regression={'ts_type': 'auto_regression', | ||
'length': 1000, | ||
'ar_params': [0.5, -0.3, 0.2], | ||
'initial_values': None}, | ||
smooth_normal={'ts_type': 'smooth_normal', | ||
'length': 1000, | ||
'window_size': 300} | ||
|
||
) | ||
|
||
|
||
@pytest.mark.parametrize('kind', ('random_walk', 'sin', 'auto_regression', 'smooth_normal')) | ||
def test_get_ts(config, kind): | ||
specific_config = config[kind] | ||
generator = TimeSeriesGenerator(params=specific_config) | ||
ts = generator.get_ts() | ||
assert isinstance(ts, np.ndarray) | ||
assert len(ts) == specific_config['length'] |