% TunWeight_RL_DDPGNetworks.m
% Builds the critic and actor networks and their representations. The fully
% connected layers are initialised with the Weights/Bias of the criticNetwork
% and actorNetwork variables that already exist in the workspace.
%% Critic
% Rebuilds the critic layer graph. Every fully connected layer is seeded with
% the Weights/Bias of the matching layer of the criticNetwork variable that is
% already in the workspace; criticNetwork is then overwritten with the new
% layer graph.
% NOTE(review): the load call below is disabled; presumably it is what
% provides the existing criticNetwork/actorNetwork - confirm with the caller.
%load('NewTraied_600');

% Sizes of the observation and action inputs.
numObs = 11;
numAct = 4;

% Widths of the two fully connected layers on the state path.
criticLayerSizes = [100 100];

% State path: observation -> FC -> ReLU -> FC.
% Weights come from layers 2 and 4 of the existing criticNetwork.
statePath = [
    imageInputLayer([numObs 1 1], ...
        'Normalization', 'none', ...
        'Name', 'observation')
    fullyConnectedLayer(criticLayerSizes(1), ...
        'Name', 'CriticStateFC1', ...
        'Weights', criticNetwork.Layers(2,1).Weights, ...
        'Bias', criticNetwork.Layers(2,1).Bias)
    reluLayer('Name', 'CriticStateRelu1')
    fullyConnectedLayer(criticLayerSizes(2), ...
        'Name', 'CriticStateFC2', ...
        'Weights', criticNetwork.Layers(4,1).Weights, ...
        'Bias', criticNetwork.Layers(4,1).Bias)
    ];

% Action path: action -> FC. Its width equals the second state layer so that
% both paths can be summed. Weights come from layer 6 of the existing network.
actionPath = [
    imageInputLayer([numAct 1 1], ...
        'Normalization', 'none', ...
        'Name', 'action')
    fullyConnectedLayer(criticLayerSizes(2), ...
        'Name', 'CriticActionFC1', ...
        'Weights', criticNetwork.Layers(6,1).Weights, ...
        'Bias', criticNetwork.Layers(6,1).Bias)
    ];

% Common path: add -> ReLU -> single-output FC.
% Weights come from layer 9 of the existing network.
commonPath = [
    additionLayer(2, 'Name', 'add')
    reluLayer('Name', 'CriticCommonRelu1')
    fullyConnectedLayer(1, ...
        'Name', 'CriticOutput', ...
        'Weights', criticNetwork.Layers(9,1).Weights, ...
        'Bias', criticNetwork.Layers(9,1).Bias)
    ];

% Assemble the layer graph: both paths feed the two inputs of the 'add' layer.
criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork, actionPath);
criticNetwork = addLayers(criticNetwork, commonPath);
criticNetwork = connectLayers(criticNetwork, 'CriticStateFC2', 'add/in1');
criticNetwork = connectLayers(criticNetwork, 'CriticActionFC1', 'add/in2');

% Representation options for the critic.
criticOptions = rlRepresentationOptions( ...
    'Optimizer', 'adam', ...
    'LearnRate', 1e-3, ...
    'GradientThreshold', 1, ...
    'L2RegularizationFactor', 2e-4);
if useGPU
    criticOptions.UseDevice = 'gpu';
end

% Bind the 'observation' and 'action' input layers to the environment specs.
critic = rlRepresentation(criticNetwork, criticOptions, ...
    'Observation', {'observation'}, env.getObservationInfo, ...
    'Action', {'action'}, env.getActionInfo);

% Uncomment to inspect the assembled graph:
% figure
% plot(criticNetwork)
%% ACTOR
% Rebuilds the actor layer array. The three fully connected layers are seeded
% with the Weights/Bias of layers 2, 4 and 6 of the actorNetwork variable that
% is already in the workspace; actorNetwork is then overwritten with the new
% layer array. Relies on numObs, numAct, env and useGPU being defined earlier.

% Widths of the two hidden fully connected layers.
actorLayerSizes = [100 100];

% The output scaling needs the action limits. Use the caller's actionInfo if
% one exists (unchanged behaviour); otherwise take it from the environment,
% the same source the rlRepresentation call below uses, so the script does not
% fail on an undefined variable.
if ~exist('actionInfo', 'var')
    actionInfo = env.getActionInfo;
end

% observation -> FC -> ReLU -> FC -> ReLU -> FC(numAct) -> tanh -> scaling.
% The scaling layer multiplies the tanh output by the largest upper limit.
actorNetwork = [
    imageInputLayer([numObs 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(actorLayerSizes(1), 'Name', 'ActorFC1', ...
        'Weights',actorNetwork(2,1).Weights, ...
        'Bias',actorNetwork(2,1).Bias)
    reluLayer('Name', 'ActorRelu1')
    fullyConnectedLayer(actorLayerSizes(2), 'Name', 'ActorFC2', ...
        'Weights',actorNetwork(4,1).Weights, ...
        'Bias',actorNetwork(4,1).Bias)
    reluLayer('Name', 'ActorRelu2')
    fullyConnectedLayer(numAct, 'Name', 'ActorFC3', ...
        'Weights',actorNetwork(6,1).Weights, ...
        'Bias',actorNetwork(6,1).Bias)
    tanhLayer('Name','ActorTanh1')
    scalingLayer('Name','ActorScaling1','Scale',max(actionInfo.UpperLimit))];

% Create actor representation
actorOptions = rlRepresentationOptions('Optimizer','adam','LearnRate',1e-4, ...
    'GradientThreshold',1,'L2RegularizationFactor',1e-5);
if useGPU
    actorOptions.UseDevice = 'gpu';
end

% Bind the 'observation' input and the 'ActorScaling1' output to the
% environment's observation and action specs.
actor = rlRepresentation(actorNetwork,actorOptions, ...
    'Observation',{'observation'},env.getObservationInfo, ...
    'Action',{'ActorScaling1'},env.getActionInfo);