Skip to content

Commit

Permalink
added flag to add activation on top
Browse files Browse the repository at this point in the history
  • Loading branch information
bonlime committed Mar 19, 2019
1 parent 9207331 commit ce09324
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, ski
return x


def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3), classes=21, backbone='mobilenetv2', OS=16, alpha=1.):
def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3), classes=21, backbone='mobilenetv2', OS=16, alpha=1., activation=None):
""" Instantiates the Deeplabv3+ architecture
Optionally loads weights pre-trained
Expand All @@ -239,6 +239,8 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
classes: number of desired classes. If classes != 21,
last layer is initialized randomly
backbone: backbone to use. one of {'xception','mobilenetv2'}
activation: optional activation to add to the top of the network.
One of 'softmax', 'sigmoid' or None
OS: determines input_shape/feature_extractor_output ratio. One of {8,16}.
Used only for xception backbone.
alpha: controls the width of the MobileNetV2 network. This is known as the
Expand All @@ -249,7 +251,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
of filters in each layer.
- If `alpha` = 1, default number of filters from the paper
are used at each layer.
Used only for mobilenetv2 backbone
Used only for mobilenetv2 backbone. Pretrained is only available for alpha=1.
# Returns
A Keras model instance.
Expand Down Expand Up @@ -451,6 +453,9 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
else:
inputs = img_input

if activation in {'softmax','sigmoid'}:
x = tf.keras.layers.Activation(activation)(x)

model = Model(inputs, x, name='deeplabv3plus')

# load weights
Expand Down

0 comments on commit ce09324

Please sign in to comment.