-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathextract_tda_dataset.py
73 lines (55 loc) · 2.69 KB
/
extract_tda_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
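# Before running the script, download and unpack the dataset archive with the
# commands below (the leading "!" is notebook shell-escape syntax; omit it in a regular shell):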
# !wget -nc "http://www.timeseriesclassification.com/Downloads/Archives/Univariate2018_arff.zip"
# !unzip -q -n "Univariate2018_arff.zip"
import numpy as np
import pandas as pd
import sys
from pathlib import Path
from src.utils import (get_data_from_directory, get_files_directory_list,
one_hot_encoding, TimeSeriesDataset,get_device, train_clf)
from src.TFE import *
def extract_dataset(dataset_index):
    """Extract topological features for one dataset and save them to disk.

    The dataset is selected by position in the sorted list returned by
    ``get_files_directory_list()``. Its train and test series are loaded with
    ``get_data_from_directory``, squeezed, passed through a
    ``TopologicalFeaturesExtractor`` and the two resulting arrays are written
    with ``np.save`` under ``./TDA-Datasets/<dataset_name>/`` as
    ``<dataset_name>_TRAIN`` and ``<dataset_name>_TEST``.

    Args:
        dataset_index: Index into the sorted list of dataset names.

    Raises:
        IndexError: If ``dataset_index`` is outside the dataset list.
    """
    directory_list = sorted(get_files_directory_list())
    dataset_name = directory_list[dataset_index]
    print(f"Processing dataset {dataset_index}: {dataset_name}...")

    # Only the series are transformed and saved; the labels returned by the
    # loader are not used by this function.
    X_train, X_test, _, _ = get_data_from_directory(dataset_name)
    X_train = X_train.squeeze()
    X_test = X_test.squeeze()

    feature_extractor = TopologicalFeaturesExtractor(
        persistence_diagram_extractor=PersistenceDiagramsExtractor(
            tokens_embedding_dim=10,
            tokens_embedding_delay=5,
            homology_dimensions=(0, 1, 2),
            parallel=True,
        ),
        persistence_diagram_features=[
            HolesNumberFeature(),
            MaxHoleLifeTimeFeature(),
            RelevantHolesNumber(),
            AverageHoleLifetimeFeature(),
            SumHoleLifetimeFeature(),
            PersistenceEntropyFeature(),
            SimultaneousAliveHolesFeatue(),  # spelling as used by the original code
            AveragePersistenceLandscapeFeature(),
            BettiNumbersSumFeature(),
            RadiusAtMaxBNFeature(),
        ],
    )

    X_train_transformed = feature_extractor.fit_transform(X_train)
    # NOTE(review): the extractor is fitted again on the test split instead of
    # reusing the train fit. Harmless if the extractor is stateless -- confirm
    # against src.TFE before switching to transform().
    X_test_transformed = feature_extractor.fit_transform(X_test)

    # exist_ok=True: a re-run must not throw away the freshly computed features
    # with FileExistsError just because the output directory is already there.
    dataset_path = Path("./TDA-Datasets/") / dataset_name
    dataset_path.mkdir(parents=True, exist_ok=True)

    np.save(dataset_path / (dataset_name + "_TRAIN"), X_train_transformed)
    np.save(dataset_path / (dataset_name + "_TEST"), X_test_transformed)
    print(f"Dataset {dataset_index} finished")
def main():
    """Extract TDA features for an inclusive range of dataset indices.

    Reads ``START_INDEX`` and ``END_INDEX`` from ``sys.argv[1]`` and
    ``sys.argv[2]`` and calls ``extract_dataset`` for every index from
    ``START_INDEX`` through ``END_INDEX``. A failure on one dataset is reported
    and processing continues with the next index.

    Raises:
        SystemExit: If the two indices are missing or are not integers.
    """
    import traceback  # local import: only needed for failure reporting

    script_name = sys.argv[0] if sys.argv else "extract_tda_dataset.py"
    usage = f"usage: {script_name} START_INDEX END_INDEX"
    if len(sys.argv) < 3:
        sys.exit(usage)
    try:
        start_index = int(sys.argv[1])
        end_index = int(sys.argv[2])
    except ValueError:
        sys.exit(f"{usage}\nerror: both indices must be integers")

    for i in range(start_index, end_index + 1):
        try:
            extract_dataset(i)
        except Exception as e:
            # Best-effort batch: keep going, but say which dataset failed and
            # why, so the run can be diagnosed afterwards.
            print(f"Dataset {i} failed: {type(e).__name__}: {e}")
            traceback.print_exc()
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()