fix: 🐛 TRPO
Phoenix-Shen committed Dec 29, 2021
1 parent 8481384 commit 705138f
Showing 3 changed files with 17 additions and 20 deletions.
TrustRegionPolicyOptimization(TRPO)/arguments.py (3 changes: 1 addition, 2 deletions)
@@ -4,7 +4,7 @@ def __init__(self) -> None:
         self.env_name = "LunarLanderContinuous-v2"
         self.seed = 123
         self.save_dir = "saved_models/"
-        self.total_timesteps = 1e6
+        self.total_timesteps = 1000000
         self.nsteps = 1024
         self.lr = 3e-4
         self.batch_size = 64
@@ -13,5 +13,4 @@ def __init__(self) -> None:
         self.damping = 0.1
         self.max_kl = 0.01
         self.cuda = True
-        self.env_type = "mojuco"
         self.log_dir = "logs"
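
The change from `1e6` to `1000000` is more than cosmetic: `1e6` is a Python float. One plausible reason for the fix (my assumption; the commit itself does not explain it) is that `total_timesteps` is floor-divided by `nsteps` to produce an update count that is then passed to `range()`, which rejects floats. A minimal sketch of that failure mode:

```python
# Hypothetical reproduction; variable names mirror arguments.py, the loop is illustrative.
total_timesteps_float = 1e6      # float, as in the old arguments.py
total_timesteps_int = 1000000    # int, as in the new arguments.py
nsteps = 1024

num_updates = total_timesteps_float // nsteps   # 976.0, still a float
try:
    range(num_updates)
except TypeError as e:
    print("float config breaks the update loop:", e)
    # TypeError: 'float' object cannot be interpreted as an integer

print(len(range(total_timesteps_int // nsteps)))  # 976 updates, no cast needed
```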
TrustRegionPolicyOptimization(TRPO)/model.py (26 changes: 14 additions, 12 deletions)
@@ -3,11 +3,11 @@
 import torch.nn as nn
 import torch.optim as optim
 import numpy as np
-import torch.nn.functional as F
 import gym
 import arguments
 import running_filter
 import os
+from datetime import datetime
 from torch.distributions.normal import Normal
 # The actor has to output two values, sigma and mu, so the layers cannot be a single Sequential;
 # we still need two separate classes, each with its own forward call.
@@ -39,7 +39,7 @@ class Critic(nn.Module):
     def __init__(self, n_features) -> None:
         super().__init__()
         self.value_net = nn.Sequential(
-            nn.Linear(n_features),
+            nn.Linear(n_features, 256),
             nn.Tanh(),
             nn.Linear(256, 256),
             nn.Tanh(),
@@ -77,7 +77,7 @@ def __init__(self, env: gym.Env, args: arguments.ARGS) -> None:
         self.old_net.load_state_dict(self.net.state_dict())
         # define the optimizer; the actor update does not go through an optimizer,
         # in other words we use t.autograd to compute the gradients of the actor
-        self.optimzer = optim.Adam(
+        self.optimizer = optim.Adam(
             self.net.critic.parameters(), lr=self.args.lr)
         # define the running mean filter
         self.running_state = running_filter.ZFilter(
@@ -111,7 +111,7 @@ def learn(self):
                 mb_obs.append(np.copy(obs))
                 mb_actions.append(actions)
                 mb_dones.append(dones)
-                mb_values.append(value.detach().numpy.squeeze())
+                mb_values.append(value.detach().numpy().squeeze())
                 # start to execute actions in the environment
                 obs_, reward, done, _ = self.env.step(actions)
                 dones = done
@@ -136,7 +136,7 @@ def learn(self):

             # compute the last state value
             with t.no_grad():
-                obs_tensor = t.tensor(mb_obs, dtype=t.float32).unsqueeze(0)
+                obs_tensor = t.tensor(obs, dtype=t.float32).unsqueeze(0)
                 last_value, _ = self.net.forward(obs_tensor)
                 last_value = last_value.detach().numpy().squeeze()

@@ -145,16 +145,16 @@ def learn(self):
             mb_advs = np.zeros_like(mb_rewards)
             lastgarlam = 0

-            for t in reversed(range(self.args.nsteps)):
-                if t == self.args.nsteps-1:
+            for ts in reversed(range(self.args.nsteps)):
+                if ts == self.args.nsteps-1:
                     nextnonterminal = 1.0-dones
                     nextvalues = last_value
                 else:
-                    nextnonterminal = 1.0-mb_dones[t+1]
-                    nextvalues = mb_values[t+1]
-                delta = mb_rewards[t]+self.args.gamma * \
-                    nextvalues*nextnonterminal-mb_values[t]
-                mb_advs[t] = lastgarlam = delta+self.args.gamma * \
+                    nextnonterminal = 1.0-mb_dones[ts+1]
+                    nextvalues = mb_values[ts+1]
+                delta = mb_rewards[ts]+self.args.gamma * \
+                    nextvalues*nextnonterminal-mb_values[ts]
+                mb_advs[ts] = lastgarlam = delta+self.args.gamma * \
                     self.args.tau*nextnonterminal*lastgarlam
             mb_returns = mb_advs+mb_values
             # normalize the advantages
@@ -164,6 +164,8 @@ def learn(self):
             # start to update the network
             policy_loss, value_loss = self._update_network(
                 mb_obs, mb_actions, mb_returns, mb_advs)
+            print('[{}] Update: {} / {}, Frames: {}, Reward: {:.3f}, VL: {:.3f}, PL: {}'.format(datetime.now(), update,
+                  num_updates, (update + 1)*self.args.nsteps, final_reward, value_loss, policy_loss))

     def _update_network(self, mb_obs, mb_actions, mb_returns, mb_advs):
         # convert ndarrays to FloatTensor
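The key fix in model.py renames the GAE loop variable from `t` to `ts`. The module imports torch under the alias `t` (the hunks above call `t.no_grad()` and `t.tensor(...)`), so reusing `t` as the loop index shadowed the library inside `learn()`. The sketch below is my own self-contained reconstruction of that advantage computation, not the repository's code; the function name, the dummy inputs, and the `gamma`/`tau` defaults are placeholders:

```python
import numpy as np
import torch as t  # the alias that the old loop variable `t` used to shadow


def compute_gae(mb_rewards, mb_values, mb_dones, last_value, dones, gamma=0.99, tau=0.95):
    """Generalized advantage estimation, mirroring the loop in the diff."""
    nsteps = len(mb_rewards)
    mb_advs = np.zeros_like(mb_rewards)
    lastgaelam = 0.0
    for ts in reversed(range(nsteps)):  # `ts`, not `t`, so the torch alias stays usable
        if ts == nsteps - 1:
            nextnonterminal = 1.0 - dones
            nextvalues = last_value
        else:
            nextnonterminal = 1.0 - mb_dones[ts + 1]
            nextvalues = mb_values[ts + 1]
        delta = mb_rewards[ts] + gamma * nextvalues * nextnonterminal - mb_values[ts]
        mb_advs[ts] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
    mb_returns = mb_advs + mb_values
    return mb_advs, mb_returns


# Dummy rollout just to show the call shape; real inputs come from learn().
advs, rets = compute_gae(np.ones(5, dtype=np.float32), np.zeros(5, dtype=np.float32),
                         np.zeros(5, dtype=np.float32), last_value=0.0, dones=0.0)
assert t.tensor(advs).shape == (5,)  # torch is still reachable inside this module
```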
TrustRegionPolicyOptimization(TRPO)/readme.md (8 changes: 2 additions, 6 deletions)
@@ -17,13 +17,9 @@

 ## Off-the-shelf code

-`the code in this folder is not functional! the code on GitHub at https://github.com/Khrylx/PyTorch-RL is recommended.`
+- [Ready-made GitHub code 1](https://github.com/TianhongDai/reinforcement-learning-algorithms)

-What is Running_state.py actually for?

-Line 124 of agent.py throws an error, and I really cannot figure it out.

-[Ready-made GitHub code](https://github.com/Khrylx/PyTorch-RL)
+- [Ready-made GitHub code 2](https://github.com/Khrylx/PyTorch-RL)

 ## Trust Region Algorithm

