model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as V
from torch import autograd
import numpy as np

class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        #### SELF ARGS ####
        self.dropout = args.dropout
        # Model optimizer
        self.optimizer = None
        #### MODEL PARAMS ####
        self.fc1 = nn.Linear(784, 400)
        self.fc1_drop = nn.Dropout(0.5) if self.dropout else nn.Dropout(0)
        self.fc2 = nn.Linear(400, 400)
        self.fc2_drop = nn.Dropout(0.5) if self.dropout else nn.Dropout(0)
        # self.fc3 = nn.Linear(400, 400)
        # self.fc3_drop = nn.Dropout(0.5) if self.dropout else nn.Dropout(0)
        self.fc_final = nn.Linear(400, 10)
        # Dict that will hold the estimated Fisher information matrix
        self.Fisher = {}
        # Keep a handle on the model parameters
        self.params = [param for param in self.parameters()]

    def forward(self, x):
        # Flatten input
        x = x.view(-1, 784)
        # FIRST FC
        x_relu = F.relu(self.fc1(x))
        x = self.fc1_drop(x_relu)
        # SECOND FC
        x_relu = F.relu(self.fc2(x))
        x = self.fc2_drop(x_relu)
        # # THIRD FC
        # x_relu = F.relu(self.fc3(x))
        # x = self.fc3_drop(x_relu)
        # Classification
        x = self.fc_final(x)
        return x

    def estimate_fisher(self, dataset, sample_size, batch_size=64):
        # One accumulator per parameter tensor for the diagonal Fisher estimate
        self.F_accum = []
        for v, _ in enumerate(self.params):
            self.F_accum.append(np.zeros(list(self.params[v].size())))
        data_loader = dataset
        loglikelihoods = []
        for x, y in data_loader:
            x = x.view(x.size(0), -1)
            x = V(x).cuda() if self._is_on_cuda() else V(x)
            y = V(y).cuda() if self._is_on_cuda() else V(y)
            # Log-likelihood of the correct class for each sample in the batch
            loglikelihoods.append(F.log_softmax(self(x), dim=1)[range(x.size(0)), y.data])
            if len(loglikelihoods) >= sample_size // batch_size:
                break
        loglikelihood = torch.cat(loglikelihoods).mean(0)
        loglikelihood_grads = autograd.grad(loglikelihood, self.parameters(), retain_graph=True)
        for v in range(len(self.F_accum)):
            # Accumulate the squared gradients; the result must be stored,
            # otherwise the Fisher estimate stays at zero
            self.F_accum[v] += torch.pow(loglikelihood_grads[v], 2).data.cpu().numpy()
        for v in range(len(self.F_accum)):
            self.F_accum[v] /= sample_size
        parameter_names = [
            n.replace('.', '__') for n, p in self.named_parameters()
        ]
        return {n: g for n, g in zip(parameter_names, self.F_accum)}

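    # Note: estimate_fisher targets only the diagonal of the Fisher information
    # matrix, F_i ~= E[(d log p(y | x, theta) / d theta_i)^2]; here it is
    # approximated from the squared gradient of the mean log-likelihood over
    # roughly sample_size examples. consolidate then stores these values,
    # together with a snapshot of the current parameters, as buffers that
    # ewc_loss reads back while the next task is trained.
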
    def consolidate(self, fisher):
        for n, p in self.named_parameters():
            n = n.replace('.', '__')
            self.register_buffer('{}_estimated_mean'.format(n), p.data.clone())
            # fisher[n] is a numpy array produced by estimate_fisher; store it
            # as a tensor buffer so it moves with the model (cpu/cuda)
            self.register_buffer('{}_estimated_fisher'.format(n),
                                 torch.from_numpy(fisher[n]).float())

    def ewc_loss(self, lamda, cuda=False):
        try:
            losses = []
            for n, p in self.named_parameters():
                # Retrieve the consolidated mean and Fisher information
                n = n.replace('.', '__')
                mean = getattr(self, '{}_estimated_mean'.format(n))
                fisher = getattr(self, '{}_estimated_fisher'.format(n))
                # Wrap mean and fisher in Variables
                mean = V(mean)
                fisher = V(fisher.data)
                # EWC penalty: treat the previous task's parameters as a
                # Gaussian prior with the estimated mean and the estimated
                # Cramer-Rao lower bound variance, which is equivalent to the
                # inverse of the Fisher information
                losses.append((fisher * (p - mean) ** 2).sum())
            return (lamda / 2) * sum(losses)
        except AttributeError:
            # EWC loss is 0 if there are no consolidated parameters yet
            return (
                V(torch.zeros(1)).cuda() if cuda else
                V(torch.zeros(1))
            )

    def _is_on_cuda(self):
        return next(self.parameters()).is_cuda
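
A minimal usage sketch follows, building on the definitions above. It assumes an args namespace with a dropout flag, two hypothetical DataLoaders train_loader_a and train_loader_b, and an EWC weight lamda=40; none of these names are defined in model.py.

# Hypothetical usage sketch (not part of model.py): train_loader_a,
# train_loader_b, and lamda=40 are assumptions for illustration only.
import argparse

args = argparse.Namespace(dropout=True)
model = Net(args)
model.optimizer = optim.SGD(model.parameters(), lr=0.1)

def train_task(loader, lamda=0.0):
    model.train()
    for x, y in loader:
        model.optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        # Add the EWC penalty; it is zero until consolidate() has been called
        loss = loss + model.ewc_loss(lamda, cuda=model._is_on_cuda())
        loss.backward()
        model.optimizer.step()

# Task A: plain training, then estimate and consolidate the Fisher information
train_task(train_loader_a)
fisher = model.estimate_fisher(train_loader_a, sample_size=1024, batch_size=64)
model.consolidate(fisher)

# Task B: the quadratic EWC penalty now anchors weights important for task A
train_task(train_loader_b, lamda=40)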