-
Notifications
You must be signed in to change notification settings - Fork 2
/
logging_a.py
119 lines (93 loc) · 3.13 KB
/
logging_a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import abc
import tqdm
from torch.utils.tensorboard import SummaryWriter
class ProgressMeter(object):
    """Report the state of a collection of meters once per batch.

    A report line is ``prefix`` + a ``[batch/num_batches]`` counter, followed
    by ``str(meter)`` for every meter, all separated by tab characters.
    """

    def __init__(self, num_batches, meters, cfg, prefix=""):
        """
        Args:
            num_batches: total number of batches; fixes the counter's width.
            meters: meters to report. They are iterated on every ``display``
                and ``write_to_tensorboard`` call, so this should be a
                re-iterable collection.
            cfg: config object; ``display`` calls ``cfg.logger.info`` on it
                (presumably a ``logging.Logger`` -- confirm with callers).
            prefix: text placed in front of the batch counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        # Tuple assignment binds left to right, matching attribute order.
        self.meters, self.prefix, self.cfg = meters, prefix, cfg

    def display(self, batch, tqdm_writer=False):
        """Emit one tab-separated progress line for ``batch``.

        When ``tqdm_writer`` is truthy the line goes through
        ``tqdm.tqdm.write`` (presumably so an active tqdm bar is not
        clobbered); otherwise it goes to ``self.cfg.logger.info``.
        """
        counter = self.prefix + self.batch_fmtstr.format(batch)
        fields = [counter]
        fields.extend(str(meter) for meter in self.meters)
        message = "\t".join(fields)
        if tqdm_writer:
            tqdm.tqdm.write(message)
        else:
            self.cfg.logger.info(message)

    def write_to_tensorboard(
        self, writer: SummaryWriter, prefix="train", global_step=None
    ):
        """Log each meter's latest and/or average value as scalars.

        Tags are ``"<prefix>/<meter.name>_val"`` and
        ``"<prefix>/<meter.name>_avg"``; each is written only when the meter's
        ``write_val`` / ``write_avg`` flag is truthy.
        """
        for meter in self.meters:
            # Read avg before val, as the statistics are fetched for every meter.
            running, current = meter.avg, meter.val
            if meter.write_val:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_val", current, global_step=global_step
                )
            if meter.write_avg:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_avg", running, global_step=global_step
                )

    def _get_batch_fmtstr(self, num_batches):
        """Build the counter template, e.g. ``"[{:3d}/100]"`` for 100 batches.

        The single ``{}`` slot takes the current batch index and is
        right-aligned to the number of digits in ``num_batches``.
        """
        digits = str(len(str(num_batches // 1)))
        field = "{:" + digits + "d}"
        return "".join(["[", field, "/", field.format(num_batches), "]"])
class Meter(abc.ABC):
    """Abstract interface for a named, resettable running statistic.

    Inheriting from ``abc.ABC`` is what makes the ``@abc.abstractmethod``
    markers below take effect. With a plain ``object`` base they are inert,
    and an incomplete subclass (or ``Meter`` itself) would instantiate
    silently.

    NOTE(review): ``ProgressMeter`` in this file also reads ``name``, ``val``,
    ``avg``, ``write_val`` and ``write_avg`` from each meter. Those attributes
    are not declared by this interface, so concrete meters must provide them.
    """

    @abc.abstractmethod
    def __init__(self, name, fmt=":f"):
        """Initialise the meter.

        Args:
            name: label used when the meter is printed or logged.
            fmt: format-spec suffix (e.g. ``":f"``). Subclasses in this file
                splice it into a ``str.format`` replacement field.
        """
        pass

    @abc.abstractmethod
    def reset(self):
        """Discard all accumulated state."""
        pass

    @abc.abstractmethod
    def update(self, val, n=1):
        """Accumulate ``val``, weighted by ``n`` samples."""
        pass

    @abc.abstractmethod
    def __str__(self):
        """Return a human-readable summary of the current statistics."""
        pass
class AverageMeter(Meter):
    """Keep the latest value plus a count-weighted running mean.

    ``write_val`` / ``write_avg`` are flags read by
    ``ProgressMeter.write_to_tensorboard`` to decide which statistics to log.
    """

    def __init__(self, name, fmt=":f", write_val=True, write_avg=True):
        self.name = name
        # Format-spec suffix spliced into both replacement fields of __str__.
        self.fmt = fmt
        self.reset()
        self.write_val = write_val
        self.write_avg = write_avg

    def reset(self):
        """Zero the latest value, mean, weighted sum and sample count."""
        # Chained assignment binds left to right: val, avg, sum, count.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        # Augmented assignment kept on purpose: for array/tensor values it
        # accumulates in place once ``sum`` is no longer the int 0.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render ``"<name> <val> (<avg>)"`` using ``self.fmt`` for both numbers."""
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**vars(self))
class VarianceMeter(Meter):
    """Track the running variance of a stream of values as E[x^2] - E[x]^2.

    Two internal ``AverageMeter`` instances hold count-weighted means of
    ``val ** 2`` and ``val``. ``avg`` combines their running means into the
    variance. ``val`` applies the same formula to the latest update alone, so
    both of its terms are the latest value squared and it evaluates to zero
    for ordinary numeric inputs.

    NOTE(review): E[x^2] - E[x]^2 is prone to cancellation, so ``avg`` may
    come out marginally negative for near-constant streams.

    Args:
        name: label used in ``__str__`` and TensorBoard tags.
        fmt: format-spec suffix applied to the variance in ``__str__``.
        write_val: whether ``ProgressMeter.write_to_tensorboard`` logs ``val``.
            This argument was previously ignored (always False); it is now
            honoured.
        write_avg: whether ``ProgressMeter.write_to_tensorboard`` logs ``avg``.
            Defaults to True, as before.
    """

    def __init__(self, name, fmt=":f", write_val=False, write_avg=True):
        self.name = name
        # Running mean of val ** 2, i.e. E[x^2].
        self._ex_sq = AverageMeter(name="_subvariance_1", fmt=":.02f")
        # Running mean of val, i.e. E[x]; squared in val/avg to give E[x]^2.
        self._sq_ex = AverageMeter(name="_subvariance_2", fmt=":.02f")
        self.fmt = fmt
        self.reset()
        self.write_val = write_val
        self.write_avg = write_avg

    @property
    def val(self):
        """Variance formula applied to the latest update only (zero in practice)."""
        return self._ex_sq.val - self._sq_ex.val ** 2

    @property
    def avg(self):
        """Running variance: mean of squares minus square of the mean."""
        return self._ex_sq.avg - self._sq_ex.avg ** 2

    def reset(self):
        """Clear both internal accumulators."""
        self._ex_sq.reset()
        self._sq_ex.reset()

    def update(self, val, n=1):
        """Fold ``val`` (weighted by ``n``) into both accumulators."""
        self._ex_sq.update(val ** 2, n=n)
        self._sq_ex.update(val, n=n)

    def __str__(self):
        """Render ``"<name> (var <avg>)"`` using ``self.fmt``."""
        return ("{name} (var {avg" + self.fmt + "})").format(
            name=self.name, avg=self.avg
        )