stat_unit.py
from utils.ArticlesHandler import ArticlesHandler
from utils import solve, embedding_matrix_2_kNN, get_rate, accuracy, precision, recall, f1_score
from utils.Trainer_graph import TrainerGraph
from utils import Config, accuracy_sentence_based
import time
import numpy as np
# from utils.postprocessing.SelectLabelsPostprocessor import SelectLabelsPostprocessor
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
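
# Overview: each iteration embeds the articles, builds a k-nearest-neighbour graph on the
# embeddings, infers the missing labels (FaBP, SVM / random forest, or TrainerGraph, depending
# on the config) and records the resulting accuracy; the mean and std over all iterations are
# printed and saved at the end.
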
config = Config('config/')
all_accuracys = []
for repeat in range(config.stats.iteration_stat):
    debut = time.time()
    handler = ArticlesHandler(config)
    # Save in a pickle file. To open, use the pickle dataloader.
    # handler.articles.save("../Dataset/train_fake.pkl")
    # Only recompute labels:
    # handler.articles.compute_labels()
    C = handler.get_tensor()
    # select_labels = SelectLabelsPostprocessor(config, handler.articles)
    # handler.add_postprocessing(select_labels, "label-selection")
    # handler.postprocess()
    labels = np.array(handler.articles.labels)
    all_labels = np.array(handler.articles.labels_untouched)
    print(labels, all_labels)
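    # Label convention used below: 0 marks an article whose label has been hidden (to be
    # predicted), while 1 and 2 are the two classes; `all_labels` keeps the untouched ground truth.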
    if config.learning.method_learning == "FaBP":
        assert max(labels) == 2, "FaBP can only be used for binary classification."
    print(len(all_labels), "Articles")
    if config.graph.node_features == config.embedding.method_decomposition_embedding:
        C_nodes = C.copy()
    else:
        config.embedding.set("method_decomposition_embedding", config.graph.method_create_graph)
        C_nodes = handler.get_tensor()
    fin = time.time()
    print("get tensor and decomposition done", fin - debut)
    sentence_to_articles = None if not config.graph.sentence_based else handler.articles.sentence_to_article
    graph = embedding_matrix_2_kNN(C, k=config.graph.num_nearest_neighbours,
                                   sentence_to_articles=sentence_to_articles).toarray()
    fin3 = time.time()
    print("KNN done", fin3 - fin)
    if config.learning.method_learning == "FaBP":
        # class b(i) {> 0, < 0} means i ∈ {"+", "-"}
        beliefs = solve(graph, labels[:])
        fin4 = time.time()
        print("FaBP done", fin4 - fin3)
    elif config.learning.method_learning in ["SVM", "RF"]:
        training_mask = labels > 0
        test_mask = labels == 0
        training_set = C[training_mask, :]
        train_labels = labels[training_mask]
        train_labels[train_labels == 2] = -1  # encode class 2 as -1 for the binary classifier
        print("Fitting")
        if config.learning.method_learning == "SVM":
            clf = svm.SVC(gamma='scale')
        else:  # Random forest
            clf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
        clf.fit(training_set, train_labels)
        beliefs = labels.copy()  # keep the known labels, predict only the unlabeled articles
        beliefs[test_mask] = clf.predict(C[test_mask, :])
        beliefs[beliefs == -1] = 2  # decode -1 back to class 2
    else:
        trainer = TrainerGraph(C_nodes, graph, all_labels, labels)
        beliefs, acc_test = trainer.train()
        print(acc_test)
    fin4 = time.time()
    print("Learning done", fin4 - fin3)
    # Compute hit rate
    # TODO: adapt this for the multiclass case...
    # beliefs[beliefs >= 0] = 1
    # beliefs[beliefs < 0] = 2
    if config.graph.sentence_based:
        acc = accuracy_sentence_based(handler, beliefs)
    else:
        # print(all_labels, beliefs)
        acc = accuracy_score(all_labels, beliefs)
    # print("Accuracy", acc)
    all_accuracys.append(acc)

print(np.mean(all_accuracys))
print(np.std(all_accuracys))

stats_dir = "/media/benamira/19793564030D4273/MCsBackup/3A/OMA/Projet/Stats/27_avril/"
file_suffix = str(config.graph.num_nearest_neighbours) + str(config.stats.ratio_labeled) + ".npy"
np.save(stats_dir + str(config.embedding.method_decomposition_embedding) + "_mean" + file_suffix,
        np.mean(all_accuracys))
np.save(stats_dir + str(config.embedding.method_decomposition_embedding) + "_std" + file_suffix,
        np.std(all_accuracys))