forked from Stanpie3/importance_sampling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon_utils.py
91 lines (62 loc) · 2.39 KB
/
common_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def _has_len(obj):
return hasattr(obj, '__len__')
class Average():
    """Running weighted mean of a stream of values.

    ``update(val, n)`` adds ``val * n`` to the running sum and ``n`` to the
    running count, so ``val`` is treated as the mean of ``n`` observations.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the running sum and count back to zero."""
        self.count = 0
        self.sum = 0

    def avg(self):
        """Return ``sum / count`` as a float division.

        Raises:
            ZeroDivisionError: if nothing has been recorded since the last reset.
        """
        return self.sum / float(self.count)

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` observations.

        Args:
            val: value (or per-item mean) to accumulate.
            n: positive weight / number of observations ``val`` stands for.

        Raises:
            ValueError: if ``n`` is not strictly positive.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``;
        # written as ``not n > 0`` so NaN is rejected as well.
        if not n > 0:
            raise ValueError(f"n must be positive, got {n!r}")
        self.sum += val * n
        self.count += n
class Accumulator():
    """Collect named running averages, one ``Average`` per keyword.

    Call with keyword arguments, e.g. ``acc(loss=0.3, accuracy=(12, 20))``.
    A value whose length is greater than 1 is unpacked as ``(value, n)`` and
    forwarded as ``Average.update(value, n)``; anything else is recorded
    with ``n=1``.
    """

    def __init__(self):
        # Maps each keyword name to its Average.
        self.items = dict()

    @staticmethod
    def _is_pair(value):
        """Return True if *value* should be unpacked as a ``(value, n)`` pair."""
        try:
            return len(value) > 1
        except TypeError:
            # No __len__ at all, or an unsized object whose __len__ raises
            # (e.g. a 0-d numpy array or tensor): treat it as a scalar.
            return False

    def __call__(self, **kwargs):
        for key, value in kwargs.items():
            if key not in self.items:
                self.items[key] = Average()
            if self._is_pair(value):
                # A sequence longer than 2 still raises ValueError here.
                val, n = value
            else:
                val, n = value, 1
            self.items[key].update(val, n)

    def __str__(self):
        parts = [f"{key}: {average.avg()}" for key, average in self.items.items()]
        return "[" + ", ".join(parts) + "]"
class CallBack:
    """Record per-epoch training metrics and run validation after each epoch.

    Args:
        eval_fn: callable invoked as ``eval_fn(model, val_dataloader, loss_fn)``;
            its result is unpacked as ``(loss_val, acc_val)``.
        name: optional label, kept as ``self.name``.
    """

    def __init__(self, eval_fn, name=None):
        self.eval_fn = eval_fn
        self.name = name
        # Training-side history, one entry appended per __call__.
        self.train_losses = []
        self.train_accs = []
        self.train_w_losses = []
        self.train_max_p_i = []
        self.train_num_unique_points = []
        # Validation history, filled from eval_fn's result.
        self.val_losses = []
        self.val_accs = []
        self.n_un = []

    @staticmethod
    def _fmt(value):
        """Format a metric with 3 decimals, or ``'n/a'`` when it was not supplied."""
        # __call__ defaults every training metric to None, and
        # format(None, '.3f') raises TypeError, so None needs a placeholder.
        return 'n/a' if value is None else f'{value:.3f}'

    def last_info(self):
        """Return the most recent epoch's metrics as formatted strings.

        Raises:
            IndexError: if called before the first ``__call__``.
        """
        fmt = self._fmt
        return {'loss_train': fmt(self.train_losses[-1]),
                'acc_train': fmt(self.train_accs[-1]),
                'w_loss_train': fmt(self.train_w_losses[-1]),
                'loss_val': fmt(self.val_losses[-1]),
                'acc_val': fmt(self.val_accs[-1]),
                'n_un': fmt(self.n_un[-1]),
                }

    def __call__(self, model, val_dataloader, loss_fn,
                 epoch_loss=None, epoch_acc=None, epoch_weighted_loss=None,
                 epoch_max_p_i_s=None, epoch_num_unique_points_s=None, n_un=None):
        """Store this epoch's training metrics, run ``eval_fn`` and return ``last_info()``."""
        self.train_losses.append(epoch_loss)
        self.train_accs.append(epoch_acc)
        self.train_w_losses.append(epoch_weighted_loss)
        self.train_max_p_i.append(epoch_max_p_i_s)
        self.train_num_unique_points.append(epoch_num_unique_points_s)
        loss_val, acc_val = self.eval_fn(model, val_dataloader, loss_fn)
        self.val_losses.append(loss_val)
        self.val_accs.append(acc_val)
        self.n_un.append(n_un)
        return self.last_info()
if __name__ == "__main__":
    # Small smoke demo of Accumulator. Guarded so that importing this
    # utility module does not run it or print to stdout.
    acc = Accumulator()
    acc(accuracy=(10, 20), data=10, newo=6.6)
    acc(accuracy=(10, 10), data=5, newo=6.6)
    acc(accuracy=(10, 10), data=5, newo=6.6)
    print(acc)