-
Notifications
You must be signed in to change notification settings - Fork 15
/
utils.py
121 lines (102 loc) · 3.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import time
import sys
import math
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import csv
import statistics as stat
term_width = int(30)
TOTAL_BAR_LENGTH = 86.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
# Reset for new bar
if current == 0:
begin_time = time.time()
cur_len = int(TOTAL_BAR_LENGTH * current / total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
sys.stdout.write(' [')
for i in range(cur_len):
sys.stdout.write('=')
sys.stdout.write('>')
for i in range(rest_len):
sys.stdout.write('.')
sys.stdout.write(']')
cur_time = time.time()
step_time = cur_time - last_time
last_time = cur_time
tot_time = cur_time - begin_time
L = []
#L.append(' Step: %s' % format_time(step_time))
L.append(' Time: %s' % format_time(tot_time))
if msg:
L.append(' | ' + msg)
msg = ''.join(L)
sys.stdout.write(msg)
for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
sys.stdout.write(' ')
# Go back to the center of the bar
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)):
sys.stdout.write('\b')
sys.stdout.write(' %d/%d ' % (current+1, total))
if current < total-1:
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
sys.stdout.flush()
def format_time(seconds):
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
def init_params(net):
# Init layer parameters
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight, mode='fan_out')
if m.bias:
init.constant(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant(m.weight, 1)
init.constant(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal(m.weight, std=1e-3)
if m.bias:
init.constant(m.bias, 0)
def acc_stat(accuracy):
length = 10
a = accuracy[-length]
a_stat = [(temp- stat.mean(a))/stat.stdev(a) for temp in a]
return sum(1 for temp in a_stat if abs(temp) > 1)/length < 0.5
def log_row(logname, row):
with open(logname, 'a') as logfile:
logwriter = csv.writer(logfile, delimiter=',')
logwriter.writerow(row)