-
Notifications
You must be signed in to change notification settings - Fork 3
/
indices_pooling.py
198 lines (163 loc) · 7.07 KB
/
indices_pooling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#title :pooling.py
#description :implementation of max_pooling_with_argmax and unpooling for tensorflow and keras
#author :yselivonchyk
#date :20190405
#modeldetails :non-sequential model, parallel training as a multiple output model
#from github:https://github.com/yselivonchyk/Tensorflow_WhatWhereAutoencoder/blob/master/pooling.py
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.python.keras.engine.base_layer import Layer
# Tensorflow implementations of max_pooling and unpooling
def _normalize_window(value, label):
    """Return `value` as a 4-element-style list usable as ksize/strides.

    An int `k` is expanded to `[1, k, k, 1]`; a list or tuple is copied into a
    new list unchanged. Anything else raises TypeError.
    """
    if isinstance(value, int):
        return [1, value, value, 1]
    if isinstance(value, (list, tuple)):
        return list(value)
    raise TypeError(
        '{} must be an int or a list/tuple of ints, got {}'.format(
            label, type(value).__name__))


def max_pool_with_argmax(net, strides, name=None, ksize=(1, 2, 2, 1)):
    """
    Max pooling with 'SAME' padding that also returns the argmax indices.

    Args:
        net: input tensor passed to `tf.nn.max_pool_with_argmax`
        strides: int `s` (expanded to `[1, s, s, 1]`) or list/tuple of ints
        name: optional name forwarded to the TensorFlow op
        ksize: int `k` (expanded to `[1, k, k, 1]`) or list/tuple of ints
    Return:
        (pooled, mask): pooled tensor and the argmax indices produced by
        `tf.nn.max_pool_with_argmax`
    Raises:
        TypeError: if `ksize` or `strides` is neither an int nor a list/tuple
    """
    # Validate with explicit exceptions: `assert` is stripped under `python -O`.
    ksize = _normalize_window(ksize, 'ksize')
    strides = _normalize_window(strides, 'strides')
    with tf.name_scope('MaxPoolArgMax'):
        net, mask = tf.nn.max_pool_with_argmax(
            net,
            ksize=ksize,
            strides=strides,
            padding='SAME',
            name=name)
        return net, mask
def unpool(pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
    """
    Unpooling layer after max_pool_with_argmax.

    Every pooled value is scattered back to the position recorded in `ind`;
    all other positions of the enlarged tensor are zero.

    Args:
        pool: max pooled output tensor (4D)
        ind: argmax indices, same shape as `pool`. The code prepends the batch
            index itself, so it assumes each index is flattened within a
            single sample (no batch offset included).
        ksize: ksize is the same as for the pool; only ksize[1] and ksize[2]
            are used, as the two spatial enlargement factors
        scope: name of the name scope wrapping the created ops
    Return:
        unpool: unpooling tensor of shape
            [batch, dim1 * ksize[1], dim2 * ksize[2], channels]
    """
    # No variables are created here, so a plain name scope is enough; it also
    # matches the other functions in this file.
    with tf.name_scope(scope):
        input_shape = tf.shape(pool)
        output_shape = [
            input_shape[0],
            input_shape[1] * ksize[1],
            input_shape[2] * ksize[2],
            input_shape[3],
        ]

        flat_input_size = tf.reduce_prod(input_shape)
        flat_output_shape = [
            output_shape[0],
            output_shape[1] * output_shape[2] * output_shape[3],
        ]

        pool_ = tf.reshape(pool, [flat_input_size])
        # Per-element batch index, broadcast to the shape of `ind`.
        batch_range = tf.reshape(
            tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
            shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b1 = tf.reshape(b, [flat_input_size, 1])
        ind_ = tf.reshape(ind, [flat_input_size, 1])
        # Rows of (batch index, flat index within the sample).
        ind_ = tf.concat([b1, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, output_shape)

        # Restore static shape information lost by the dynamic reshape.
        # Unknown dimensions stay unknown instead of being multiplied.
        static_in = pool.get_shape().with_rank(4).as_list()
        static_out = [
            static_in[0],
            None if static_in[1] is None else static_in[1] * ksize[1],
            None if static_in[2] is None else static_in[2] * ksize[2],
            static_in[3],
        ]
        ret.set_shape(static_out)
        return ret
def _upsample_along_axis(volume, axis, stride, mode='ZEROS'):
    """
    Enlarge `volume` by a factor of `stride` along `axis`.

    `stride - 1` filler copies are concatenated to `volume` along the axis
    after `axis` (or along the last axis when `axis` is the last one), then the
    result is reshaped so that dimension `axis` becomes `stride` times larger.

    Args:
        volume: input tensor with a known rank
        axis: index of the axis to enlarge, 0 <= axis < rank
        stride: integer enlargement factor
        mode: 'ZEROS' to fill with zeros, 'COPY' to fill with copies of `volume`
    Return:
        the enlarged tensor
    Raises:
        ValueError: if `mode` is unknown or `axis` is out of range
    """
    if mode not in ('COPY', 'ZEROS'):
        raise ValueError(
            "mode must be 'COPY' or 'ZEROS', got {!r}".format(mode))
    static_shape = volume.get_shape().as_list()
    rank = len(static_shape)
    if not 0 <= axis < rank:
        raise ValueError(
            'axis must be in [0, {}), got {!r}'.format(rank, axis))

    # Use static sizes where known and fall back to the dynamic size for
    # unknown dimensions (e.g. a `None` batch size).
    dynamic_shape = tf.shape(volume)
    target_shape = [
        dynamic_shape[idx] if dim is None else dim
        for idx, dim in enumerate(static_shape)
    ]
    target_shape[axis] = target_shape[axis] * stride

    # zeros_like does not need a fully defined static shape.
    padding = tf.zeros_like(volume) if mode == 'ZEROS' else volume
    parts = [volume] + [padding] * (stride - 1)
    volume = tf.concat(parts, min(axis + 1, rank - 1))
    volume = tf.reshape(volume, target_shape)
    return volume
def upsample(net, stride, mode='ZEROS'):
    """
    Imitate reverse operation of Max-Pooling by either placing original max values
    into a fixed position of upsampled cell:
    [0.9] =>[[.9,   0],   (stride=2)
             [ 0,   0]]
    or copying the value into each cell:
    [0.9] =>[[.9, .9],   (stride=2)
             [ .9, .9]]

    :param net: 4D input tensor with [batch_size, width, heights, channels] axis
    :param stride: integer enlargement factor applied to axes 1 and 2
    :param mode: string 'ZEROS' or 'COPY' indicating which value to use for undefined cells
    :return: 4D tensor of size [batch_size, width*stride, heights*stride, channels]
    :raises ValueError: if `mode` is neither 'COPY' nor 'ZEROS'
    """
    # Explicit check instead of `assert`: under `python -O` an unknown mode
    # would otherwise silently be treated as 'COPY'.
    if mode not in ('COPY', 'ZEROS'):
        raise ValueError(
            "mode must be 'COPY' or 'ZEROS', got {!r}".format(mode))
    with tf.name_scope('Upsampling'):
        # Enlarge axis 2 first, then axis 1.
        net = _upsample_along_axis(net, 2, stride, mode=mode)
        net = _upsample_along_axis(net, 1, stride, mode=mode)
        return net
# Keras layers for pooling and unpooling
# https://github.com/ykamikawa/tf-keras-SegNet/blob/master/layers.py
class MaxPoolingWithArgmax2D(Layer):
    """Keras layer wrapping `tf.nn.max_pool_with_argmax`.

    Calling the layer returns `[pooled, argmax]`. The argmax indices are cast
    to the Keras float type so they can flow through the model like any other
    tensor.
    NOTE(review): if floatx is float32, indices above 2**24 cannot be
    represented exactly after this cast - confirm tensor sizes stay below that.

    Args:
        pool_size: (rows, cols) size of the pooling window
        strides: (rows, cols) strides of the pooling window
        padding: 'same' or 'valid' (case-insensitive)
    """

    def __init__(
            self,
            pool_size=(2, 2),
            strides=(2, 2),
            padding='same',
            **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Return `[pooled, argmax]` for a 4D input tensor."""
        if tf.keras.backend.backend() != 'tensorflow':
            errmsg = '{} backend is not supported for layer {}'.format(
                tf.keras.backend.backend(), type(self).__name__)
            raise NotImplementedError(errmsg)

        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs,
            ksize=ksize,
            strides=strides,
            padding=self.padding.upper())
        argmax = tf.keras.backend.cast(argmax, tf.keras.backend.floatx())
        return [output, argmax]

    def compute_output_shape(self, input_shape):
        """Return the (identical) shapes of the pooled output and the argmax.

        The result follows the configured strides, window and padding:
        ceil(dim / stride) for 'same' padding and
        floor((dim - window) / stride) + 1 for 'valid' padding.
        Unknown dimensions stay `None`.
        """
        dims = tf.TensorShape(input_shape).as_list()
        same_padding = self.padding.upper() == 'SAME'
        windows = (1, self.pool_size[0], self.pool_size[1], 1)
        steps = (1, self.strides[0], self.strides[1], 1)

        pooled = []
        for dim, window, step in zip(dims, windows, steps):
            if dim is None:
                pooled.append(None)
            elif same_padding:
                pooled.append(-(-dim // step))  # ceiling division
            else:
                pooled.append((dim - window) // step + 1)
        output_shape = tuple(pooled)
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        """Neither of the two outputs carries a Keras mask."""
        return 2 * [None]

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update({
            'pool_size': self.pool_size,
            'strides': self.strides,
            'padding': self.padding,
        })
        return config
class MaxUnpooling2D(Layer):
    """Keras layer reversing a max pooling step using its argmax indices.

    The layer is called on `[pooled, argmax]` and delegates to `unpool`.

    Args:
        up_size: (rows, cols) enlargement factors for the two spatial axes
    """

    def __init__(self, up_size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.up_size = up_size

    def call(self, inputs, output_shape=None):
        """Scatter `inputs[0]` to the positions given by `inputs[1]`.

        `output_shape` is accepted but not used; the output size is derived
        from the input shape and `up_size`.
        """
        updates = inputs[0]
        # The indices arrive as floats from the pooling layer; scatter needs ints.
        mask = tf.cast(inputs[1], dtype=tf.int64)
        ksize = [1, self.up_size[0], self.up_size[1], 1]
        return unpool(updates, mask, ksize)

    def compute_output_shape(self, input_shape):
        """Return the mask shape with both spatial axes scaled by `up_size`.

        Unknown dimensions stay `None` instead of being multiplied.
        """
        mask_shape = tf.TensorShape(input_shape[1]).as_list()
        rows, cols = mask_shape[1], mask_shape[2]
        return (
            mask_shape[0],
            None if rows is None else rows * self.up_size[0],
            None if cols is None else cols * self.up_size[1],
            mask_shape[3]
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({'up_size': self.up_size})
        return config
# Test implementation
if __name__ == '__main__':
    # Smoke test: pool a random batch, unpool it again and print one input
    # sample next to its reconstruction.
    import numpy as np

    # Distinct names: avoid shadowing the builtin `input` and reusing one
    # variable for both the argmax tensor and the prediction.
    model_input = keras.layers.Input((4, 4, 3))
    (pooled, argmax) = MaxPoolingWithArgmax2D()(model_input)
    unpooled = MaxUnpooling2D()([pooled, argmax])

    model = keras.Model(inputs=model_input, outputs=unpooled)
    model.compile(optimizer="adam", loss='categorical_crossentropy')
    model.summary()

    x = np.random.randint(0, 100, (3, 4, 4, 3))
    reconstruction = model.predict(x)
    print(x[0])
    print(reconstruction[0])