-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconv_blocks.py
56 lines (44 loc) · 2.01 KB
/
conv_blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from keras.layers import Conv2D, MaxPooling2D, Concatenate, Add
from keras import layers
from keras import backend as K
K.set_image_data_format('channels_last')
import sys
import os
sys.path.append(os.path.abspath('../'))
from models import caps_layers
from models.caps_layers import BilinearUpsampling
from experiment.config import h, w, capsnet_config, fusion_config
NO_CONV_BLOCK = 'NO_CONV_BLOCK'
def _resolve_layer(op):
    """Return the layer class named ``op``.

    The project's ``caps_layers`` module is searched first, then
    ``keras.layers``, so a project layer shadows a Keras layer of the same
    name (same precedence as before).

    Kept at module level rather than inside ``ConvBlocks`` so it is not
    picked up as a conv block by the ``ApiConvBlocks`` catalog, which is
    built from ``dir(ConvBlocks)``.

    Raises:
        AttributeError: if neither module defines ``op``.
    """
    for module in (caps_layers, layers):
        if hasattr(module, op):
            return getattr(module, op)
    raise AttributeError(
        "Unknown layer op {!r}: not found in caps_layers or keras.layers".format(op))


class ConvBlocks(object):
    """Builders for the convolutional parts of the network.

    The public methods of this class are exposed by name through the
    module-level ``ApiConvBlocks`` catalog.
    """

    # Branches or Capsnet
    def cnn_generic_branch(self, input):
        """Build one input branch from ``capsnet_config['branch']``.

        Each entry of ``capsnet_config['branch']['blocks']`` provides an
        ``'op'`` (layer class name, resolved via ``caps_layers`` then
        ``keras.layers``) and ``'params'`` (constructor kwargs). If the
        params contain ``'name'``, it is prefixed with the branch name so the
        same block list can be instantiated once per branch without
        layer-name clashes.

        ``capsnet_config['branch']['shortcuts']`` maps a source block index
        to a destination block index. The output of the source block is
        added (``Add``) to the output of the destination block.

        Args:
            input: tensor feeding this branch. The part of its name before
                the first ``'_'`` (and before any ``':'``) is used as the
                branch name.

        Returns:
            The output tensor of the last block, including any shortcut add.

        Raises:
            AttributeError: if a block's ``'op'`` is not a known layer.
        """
        x = input
        # NOTE(review): assumes tensor names look like '<branch>_...:0' — confirm.
        branch_name = input.name.split('_')[0].split(':')[0]
        branch_config = capsnet_config['branch']
        shortcuts = branch_config['shortcuts']
        # Destination block index -> tensor waiting to be added there.
        pending_shortcuts = {}
        for i, item in enumerate(branch_config['blocks']):
            # Copy so the shared config dict is not mutated by the renaming.
            params = item['params'].copy()
            if 'name' in params:
                params['name'] = '{}-{}'.format(branch_name, params['name'])
            x = _resolve_layer(item['op'])(**params)(x)
            # Shortcut: get. The saved tensor is this block's output *before*
            # any shortcut add that lands on the same index (see below).
            if i in shortcuts:
                pending_shortcuts[shortcuts[i]] = x
            # Shortcut: put. The layer-name spelling ('shorcuts') is kept
            # unchanged so existing model layer names stay the same.
            if i in pending_shortcuts:
                x = Add(name='{}-branch_shorcuts_position_{}'.format(branch_name, i))(
                    [x, pending_shortcuts[i]])
        return x

    # Fusion
    def cnn_generic_fusion(self, inputs):
        """Build one branch per input, then fuse them with ``fusion_config``.

        Args:
            inputs: iterable of input tensors, one per branch.

        Returns:
            The output of the last op in ``fusion_config``. The first op
            receives the *list* of branch outputs, so it must accept a list
            input. If ``fusion_config`` is empty, that list is returned as-is.

        Raises:
            AttributeError: if a fusion ``'op'`` is not a known layer.
        """
        x = [self.cnn_generic_branch(branch_input) for branch_input in inputs]
        for item in fusion_config:
            x = _resolve_layer(item['op'])(**item['params'])(x)
        return x
###########################
# CONV BLOCKS API CATALOG #
###########################
# Maps a block name to the callable that builds it.
#
# Every public (non-underscore) callable attribute of ConvBlocks is exposed.
# Excluding all '_'-prefixed names, not just dunders, keeps private helpers
# out of the public catalog.
#
# ConvBlocks defines no __init__ and stores no instance state, so a single
# shared instance backs every bound method.
_conv_blocks = ConvBlocks()
ApiConvBlocks = {
    block_name: getattr(_conv_blocks, block_name)
    for block_name in dir(ConvBlocks)
    if not block_name.startswith('_') and callable(getattr(_conv_blocks, block_name))
}
ApiConvBlocks[NO_CONV_BLOCK] = lambda x: x  # Add BYPASS function (identity)