#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#
# Forked from timomernick/pytorch-capsule.
#
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.utils as vutils

from capsule_conv_layer import CapsuleConvLayer
from capsule_layer import CapsuleLayer

class CapsuleNetwork(nn.Module):
    def __init__(self,
                 image_width,
                 image_height,
                 image_channels,
                 conv_inputs,
                 conv_outputs,
                 num_primary_units,
                 primary_unit_size,
                 num_output_units,
                 output_unit_size):
        super(CapsuleNetwork, self).__init__()

        self.reconstructed_image_count = 0

        self.image_channels = image_channels
        self.image_width = image_width
        self.image_height = image_height

        # Initial convolution: extracts local features for the primary capsules.
        self.conv1 = CapsuleConvLayer(in_channels=conv_inputs,
                                      out_channels=conv_outputs)

        # Primary capsules: convolutional capsules, no routing.
        self.primary = CapsuleLayer(in_units=0,
                                    in_channels=conv_outputs,
                                    num_units=num_primary_units,
                                    unit_size=primary_unit_size,
                                    use_routing=False)

        # Digit capsules: one capsule per output class, with dynamic routing.
        self.digits = CapsuleLayer(in_units=num_primary_units,
                                   in_channels=primary_unit_size,
                                   num_units=num_output_units,
                                   unit_size=output_unit_size,
                                   use_routing=True)

        # Decoder MLP that reconstructs the input image from the digit capsules.
        reconstruction_size = image_width * image_height * image_channels
        self.reconstruct0 = nn.Linear(num_output_units * output_unit_size, int((reconstruction_size * 2) / 3))
        self.reconstruct1 = nn.Linear(int((reconstruction_size * 2) / 3), int((reconstruction_size * 3) / 2))
        self.reconstruct2 = nn.Linear(int((reconstruction_size * 3) / 2), reconstruction_size)
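        # Note: the paper's MNIST decoder uses fixed hidden layer sizes of 512 and
        # 1024; this implementation instead scales the hidden sizes with the image
        # (roughly 2/3 and 3/2 of the flattened reconstruction size).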

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
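
    # Note (an assumption based on capsule_layer.py in this repo): the digit
    # capsule output is shaped [batch, num_output_units, output_unit_size, 1].
    # Capsule lengths encode class probabilities; orientations encode pose.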
    def forward(self, x):
        return self.digits(self.primary(self.conv1(x)))

    # Total loss = margin (classification) loss + scaled-down reconstruction loss.
    def loss(self, images, input, target, size_average=True):
        return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, size_average)

    def margin_loss(self, input, target, size_average=True):
        batch_size = input.size(0)

        # ||v_c|| from the paper: the length of each capsule's output vector.
        v_mag = torch.sqrt((input**2).sum(dim=2, keepdim=True))

        # Calculate the left and right max() terms from equation 4 in the paper.
        zero = Variable(torch.zeros(1)).cuda()
        m_plus = 0.9
        m_minus = 0.1
        max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1)**2
        max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1)**2

        # This is equation 4 from the paper:
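        #   L_c = T_c * max(0, m_plus - ||v_c||)^2
        #         + lambda * (1 - T_c) * max(0, ||v_c|| - m_minus)^2
        # where T_c = 1 iff class c is present, and lambda = 0.5 down-weights
        # the loss for absent classes.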
        loss_lambda = 0.5
        T_c = target
        L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
        L_c = L_c.sum(dim=1)

        if size_average:
            L_c = L_c.mean()

        return L_c

    def reconstruction_loss(self, images, input, size_average=True):
        # Get the lengths of the capsule outputs.
        v_mag = torch.sqrt((input**2).sum(dim=2))

        # Get the index of the longest capsule output per sample.
        _, v_max_index = v_mag.max(dim=1)
        v_max_index = v_max_index.data

        # Use only the winning capsule's representation (and zeros for the other
        # capsules) to reconstruct the input image.
        batch_size = input.size(0)
        all_masked = [None] * batch_size
        for batch_idx in range(batch_size):
            # Get one sample from the batch.
            input_batch = input[batch_idx]

            # Copy only the maximum capsule index from this batch sample.
            # This masks out (leaves as zero) the other capsules in this sample.
            batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
            batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
            all_masked[batch_idx] = batch_masked
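
        # A vectorized alternative to the loop above (a sketch, assuming a recent
        # PyTorch with torch.nn.functional.one_hot and input shaped
        # [batch, num_caps, unit_size, 1]):
        #   mask = torch.nn.functional.one_hot(v_max_index, input.size(1)).float()
        #   masked = input * mask.view(batch_size, -1, 1, 1)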

        # Stack the masked capsules over the batch dimension.
        masked = torch.stack(all_masked, dim=0)

        # Reconstruct the input image from the masked capsules.
        masked = masked.view(input.size(0), -1)
        output = self.relu(self.reconstruct0(masked))
        output = self.relu(self.reconstruct1(output))
        output = self.sigmoid(self.reconstruct2(output))
        output = output.view(-1, self.image_channels, self.image_height, self.image_width)

        # Save the reconstructed images occasionally.
        if self.reconstructed_image_count % 10 == 0:
            if output.size(1) == 2:
                # Handle two-channel images by padding with an empty third channel.
                zeros = torch.zeros(output.size(0), 1, output.size(2), output.size(3))
                output_image = torch.cat([zeros, output.data.cpu()], dim=1)
            else:
                # Assume RGB or grayscale.
                output_image = output.data.cpu()
            vutils.save_image(output_image, "reconstruction.png")
        self.reconstructed_image_count += 1

        # The reconstruction loss is the sum squared difference between the input
        # image and the reconstructed image, scaled by 0.0005 (as in the paper) so
        # it does not dominate the margin (classification) loss.
        error = (output - images).view(output.size(0), -1)
        error = error**2
        error = torch.sum(error, dim=1) * 0.0005

        # Average over the batch.
        if size_average:
            error = error.mean()

        return error
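

# A minimal usage sketch (the hyperparameters are assumptions, chosen to match
# the paper's MNIST setup; CapsuleConvLayer and CapsuleLayer are defined
# elsewhere in this repo):
#
#   network = CapsuleNetwork(image_width=28,
#                            image_height=28,
#                            image_channels=1,
#                            conv_inputs=1,
#                            conv_outputs=256,
#                            num_primary_units=8,
#                            primary_unit_size=32 * 6 * 6,  # 32 channels x 6x6 spatial
#                            num_output_units=10,           # one capsule per digit
#                            output_unit_size=16).cuda()
#   output = network(images)                     # images: [batch, 1, 28, 28], on GPU
#   loss = network.loss(images, output, target)  # target: one-hot labels, [batch, 10]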