-
Notifications
You must be signed in to change notification settings - Fork 2
/
findLayersToReplace.m
49 lines (36 loc) · 1.58 KB
/
findLayersToReplace.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
% findLayersToReplace(lgraph) finds the classification layer of the layer graph
% and the preceding learnable (fully connected or convolutional) layer.
%
%   [learnableLayer, classLayer] = findLayersToReplace(lgraph)
%
%   Inputs:
%     lgraph         - an nnet.cnn.LayerGraph object.
%
%   Outputs:
%     learnableLayer - the nearest layer upstream of the classification layer
%                      whose class is exactly nnet.cnn.layer.FullyConnectedLayer
%                      or nnet.cnn.layer.Convolution2DLayer.
%     classLayer     - the single classification layer of lgraph
%                      (nnet.cnn.layer.ClassificationOutputLayer or a
%                      nnet.layer.ClassificationLayer).
%
%   Raises an error if lgraph is not a LayerGraph, if it does not contain
%   exactly one classification layer, or if walking backwards from the
%   classification layer does not find a single predecessor at some step
%   before a learnable layer is reached.
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)
if ~isa(lgraph,'nnet.cnn.LayerGraph')
    error('Argumentumnak muszáj egy LayerGraph object-nek lennie.')
end

% Connection endpoints and layer names as string arrays so they can be
% compared element-wise with ==.
src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

% Flag the classification layer (built-in output layer or custom layer).
isClassificationLayer = arrayfun(@(l) ...
    (isa(l,'nnet.cnn.layer.ClassificationOutputLayer') || isa(l,'nnet.layer.ClassificationLayer')), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error('A Layer Graph-nak muszáj rendelkeznie egy single classification layer-el.')
end
classLayer = lgraph.Layers(isClassificationLayer);

% Class names accepted as "learnable". This MUST be a cell array: with square
% brackets the two char vectors are concatenated into one char vector, and
% ismember would then test every *character* of the class name against that
% character set instead of comparing whole names (so e.g. a class name built
% only from those characters would be wrongly accepted).
learnableTypes = {'nnet.cnn.layer.FullyConnectedLayer', ...
                  'nnet.cnn.layer.Convolution2DLayer'};

% Start at the classification layer and walk backwards along the connections.
currentLayerIdx = find(isClassificationLayer);
while true
    % No predecessor found (start of the graph) or more than one: there is no
    % unique chain back to a learnable layer.
    if numel(currentLayerIdx) ~= 1
        error(' Nem megfelelő Learnable Layer')
    end

    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    if ismember(currentLayerType, learnableTypes)
        learnableLayer = lgraph.Layers(currentLayerIdx);
        return
    end

    % Step to the layer whose output feeds the current layer: find the
    % connection(s) ending at the current layer, then the layer(s) whose name
    % matches those connections' sources.
    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
end
end