Skip to content

Commit

Permalink
Add initial support to permute
Browse files Browse the repository at this point in the history
  • Loading branch information
MarsTechHAN committed Aug 14, 2021
1 parent c6db96e commit ad6ab22
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion keras2ncnn/graph_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ class GraphOptimization:

@staticmethod
def removing_unused_nodes(graph):
UNUSED_NODES = ['Dropout', 'Lambda']
UNUSED_NODES = ['Dropout', 'Lambda', 'TimeDistributed']
nodes_to_remove = []

for target_node_name in graph.get_graph().keys():
Expand Down
44 changes: 44 additions & 0 deletions keras2ncnn/keras_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,50 @@ def Reshape_helper(
layer['layer']['name'], {
'type': 'Reshape', 'param': ncnn_graph_attr, 'binary': []})

def Permute_helper(
self,
layer,
keras_graph_helper,
ncnn_graph_helper,
ncnn_helper):
PERMUTE_LUT = {
# Dim=2
'1,2': (2, 0),
'2,1': (2, 1),

# Dim=3
'1,2,3': (3, 0),
'1,3,2': (3, 5),
'2,1,3': (3, 1),
'2,3,1': (3, 4),
'3,1,2': (3, 3),
'3,2,1': (3, 2),
}

dims = layer['layer']['config']['dims']
if(len(dims) in [2, 3]):
order_type = PERMUTE_LUT[','.join(list(map(str, dims)))]
else:
print(
'[ERROR] Permute Layer Dim [%s] is not supported.' %
str(dims))
frameinfo = inspect.getframeinfo(
inspect.currentframe())
print(
'Failed to convert at %s:%d %s()' %
(frameinfo.filename, frameinfo.lineno, frameinfo.function))
sys.exit(-1)

ncnn_graph_attr = ncnn_helper.dump_args(
'Permute', order_type=order_type)
ncnn_graph_helper.node(
layer['layer']['name'],
keras_graph_helper.get_node_inbounds(
layer['layer']['name']))
ncnn_graph_helper.set_node_attr(
layer['layer']['name'], {
'type': 'Permute', 'param': ncnn_graph_attr, 'binary': []})

def Cropping2D_helper(
self,
layer,
Expand Down
6 changes: 5 additions & 1 deletion keras2ncnn/ncnn_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ class NcnnParamDispatcher:
3: {'flag': 1}
},

'Permute': {
0: {'order_type': 0}
},

'Sigmoid': {
},

Expand Down Expand Up @@ -204,7 +208,7 @@ def dump_args(self, operator, **kwargs):
ncnn_args_phrase = ncnn_args_phrase + \
'%d=%e ' % (arg, params_arg)

elif isinstance(params_arg, list):
elif isinstance(params_arg, (list, tuple)):
ncnn_args_phrase = ncnn_args_phrase + \
'%d=%d,%s ' % (-23300 - arg, len(params_arg),
','.join(list(map(str, params_arg))))
Expand Down

0 comments on commit ad6ab22

Please sign in to comment.