diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py
index 91d2dc6c..2739ba67 100644
--- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py
+++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py
@@ -57,10 +57,10 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
         )

     def fit(self, dataset):
-        state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch')
+        state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')

         v, adv = compute_advantage_montecarlo(self._V, state, next_state,
-                                              reward, absorbing,
+                                              reward, absorbing, last,
                                               self.mdp_info.gamma)
         self._V.fit(state, v, **self._critic_fit_params)

diff --git a/mushroom_rl/core/array_backend.py b/mushroom_rl/core/array_backend.py
index bd0a8c48..cf752730 100644
--- a/mushroom_rl/core/array_backend.py
+++ b/mushroom_rl/core/array_backend.py
@@ -171,7 +171,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         raise NotImplementedError
-
+
+    @staticmethod
+    def nonzero(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def repeat(array, repeats):
+        raise NotImplementedError

 class NumpyBackend(ArrayBackend):
     @staticmethod
@@ -188,7 +195,12 @@ def to_numpy(array):

     @staticmethod
     def to_torch(array):
-        return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device())
+        if array is None:
+            return None
+        else:
+            if array.dtype == np.float64:
+                array = array.astype(np.float32)
+            return torch.from_numpy(array).to(TorchUtils.get_device())

     @staticmethod
     def convert_to_backend(cls, array):
@@ -303,6 +315,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return np.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return np.flatnonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return np.repeat(array, repeats)


 class TorchBackend(ArrayBackend):
@@ -443,6 +463,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return torch.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return torch.nonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return torch.repeat_interleave(array, repeats)


 class ListBackend(ArrayBackend):
diff --git a/mushroom_rl/core/dataset.py b/mushroom_rl/core/dataset.py
index 367f7383..88102acb 100644
--- a/mushroom_rl/core/dataset.py
+++ b/mushroom_rl/core/dataset.py
@@ -11,6 +11,7 @@

 from ._impl import *
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes


 class DatasetInfo(Serializable):
     def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
@@ -473,22 +474,19 @@ def compute_J(self, gamma=1.):
             The cumulative discounted reward of each episode in the dataset.

         """
-        js = list()
-
-        j = 0.
-        episode_steps = 0
-        for i in range(len(self)):
-            j += gamma ** episode_steps * self.reward[i]
-            episode_steps += 1
-            if self.last[i] or i == len(self) - 1:
-                js.append(j)
-                j = 0.
-                episode_steps = 0
-
-        if len(js) == 0:
-            js = [0.]
-
-        return self._array_backend.from_list(js)
+        r_ep = split_episodes(self.last, self.reward)
+
+        if len(r_ep.shape) == 1:
+            r_ep = r_ep.unsqueeze(0)
+        if hasattr(r_ep, 'device'):
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
+        else:
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype)
+
+        for k in range(r_ep.shape[1]):
+            js += gamma ** k * r_ep[..., k]
+
+        return js

     def compute_metrics(self, gamma=1.):
         """
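
Note on the compute_J rewrite above: once the rewards are split into a zero-padded (n_episodes, max_episode_steps) matrix, the discounted return of every episode can be accumulated column by column, and the padding contributes nothing to the sums. A minimal standalone sketch of that idea, using plain NumPy and made-up toy rewards rather than the Dataset API:

    import numpy as np

    gamma = 0.9

    # Two episodes of lengths 3 and 2, padded the way split_episodes lays them out.
    r_ep = np.array([[1., 1., 1.],
                     [2., 2., 0.]])

    # Column-wise accumulation, mirroring the loop in the new compute_J.
    js = np.zeros(r_ep.shape[0])
    for k in range(r_ep.shape[1]):
        js += gamma ** k * r_ep[:, k]

    print(js)  # [2.71 3.8] -> 1 + 0.9 + 0.81 and 2 + 0.9 * 2
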
diff --git a/mushroom_rl/rl_utils/value_functions.py b/mushroom_rl/rl_utils/value_functions.py
index 3b6b338d..b82611e1 100644
--- a/mushroom_rl/rl_utils/value_functions.py
+++ b/mushroom_rl/rl_utils/value_functions.py
@@ -1,7 +1,7 @@
 import torch
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

-
-def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
+def compute_advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
     """
     Function to estimate the advantage and new value function target
     over a dataset. The value function is estimated using rollouts
@@ -24,18 +24,21 @@
     """
     with torch.no_grad():
         r = r.squeeze()
-        q = torch.zeros(len(r))
         v = V(s).squeeze()
-        q_next = V(ss[-1]).squeeze().item()
-        for rev_k in range(len(r)):
-            k = len(r) - rev_k - 1
-            q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
-            q[k] = q_next
+        r_ep, absorbing_ep, ss_ep = split_episodes(last, r, absorbing, ss)
+        q_ep = torch.zeros_like(r_ep, dtype=torch.float32)
+        q_next_ep = V(ss_ep[..., -1, :]).squeeze()
+
+        for rev_k in range(r_ep.shape[-1]):
+            k = r_ep.shape[-1] - rev_k - 1
+            q_next_ep = r_ep[..., k] + gamma * q_next_ep * (1 - absorbing_ep[..., k].int())
+            q_ep[..., k] = q_next_ep

+        q = unsplit_episodes(last, q_ep)
         adv = q - v

-        return q[:, None], adv[:, None]
+        return q[:, None], adv[:, None]


 def compute_advantage(V, s, ss, r, absorbing, gamma):
     """
@@ -97,13 +100,16 @@
     with torch.no_grad():
         v = V(s)
         v_next = V(ss)
-        gen_adv = torch.empty_like(v)
-        for rev_k in range(len(v)):
-            k = len(v) - rev_k - 1
-            if last[k] or rev_k == 0:
-                gen_adv[k] = r[k] - v[k]
-                if not absorbing[k]:
-                    gen_adv[k] += gamma * v_next[k]
+
+        v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing)
+        gen_adv_ep = torch.zeros_like(v_ep)
+        for rev_k in range(v_ep.shape[-1]):
+            k = v_ep.shape[-1] - rev_k - 1
+            if rev_k == 0:
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k]
             else:
-                gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1]
+
+        gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1)
+
         return gen_adv + v, gen_adv
\ No newline at end of file
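
Note on the two rewrites above: after split_episodes, every episode occupies one row of the padded tensors, so the backward recursions (q_k = r_k + gamma * (1 - absorbing_k) * q_{k+1} for the Monte Carlo targets, and the analogous lambda-weighted recursion for GAE) sweep the time axis once while updating all episodes in parallel, instead of scanning the flat dataset step by step. A minimal sketch of the Monte Carlo recursion, with made-up tensors standing in for the split_episodes output and for V:

    import torch

    gamma = 0.99

    # Hypothetical padded batch: 2 episodes, 3 steps, second episode zero-padded.
    r_ep = torch.tensor([[1., 1., 1.],
                         [2., 2., 0.]])
    absorbing_ep = torch.zeros_like(r_ep)
    q_next = torch.tensor([0.5, 0.3])  # stand-in for V(ss_ep[..., -1, :]).squeeze()

    q_ep = torch.zeros_like(r_ep)
    for rev_k in range(r_ep.shape[-1]):
        k = r_ep.shape[-1] - rev_k - 1
        q_next = r_ep[..., k] + gamma * q_next * (1 - absorbing_ep[..., k])
        q_ep[..., k] = q_next  # one bootstrapped target per episode at step k
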
+ """ + backend = ArrayBackend.get_array_backend_from(last) + + if last.sum().item() <= 1: + return arrays if len(arrays) > 1 else arrays[0] + + row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend) + episodes_arrays = [] + + for array in arrays: + array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None) + + array_ep[row_idx, colum_idx] = array + episodes_arrays.append(array_ep) + + return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0] + +def unsplit_episodes(last, *episodes_arrays): + """ + Unsplit a array from shape (n_episodes, max_episode_steps) to (n_steps). + """ + + if last.sum().item() <= 1: + return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0] + + row_idx, colum_idx, _, _ = _get_episode_idx(last) + arrays = [] + + for episode_array in episodes_arrays: + array = episode_array[row_idx, colum_idx] + arrays.append(array) + + return arrays if len(arrays) > 1 else arrays[0] + +def _get_episode_idx(last, backend=None): + if backend is None: + backend = ArrayBackend.get_array_backend_from(last) + + n_episodes = last.sum() + last_idx = backend.nonzero(last).squeeze() + first_steps = backend.from_list([last_idx[0] + 1]) + if hasattr(last, 'device'): + first_steps = first_steps.to(last.device) + episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]]) + max_episode_steps = episode_steps.max() + + start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1]) + range_n_episodes = backend.arange(0, n_episodes, dtype=int) + range_len = backend.arange(0, last.shape[0], dtype=int) + if hasattr(last, 'device'): + range_n_episodes = range_n_episodes.to(last.device) + range_len = range_len.to(last.device) + row_idx = backend.repeat(range_n_episodes, episode_steps) + colum_idx = range_len - start_idx[row_idx] + + return row_idx, colum_idx, n_episodes, max_episode_steps \ No newline at end of file diff --git a/tests/algorithms/test_a2c.py b/tests/algorithms/test_a2c.py index 6b3babaa..d3e7b95b 100644 --- a/tests/algorithms/test_a2c.py +++ b/tests/algorithms/test_a2c.py @@ -75,7 +75,7 @@ def test_a2c(): agent = learn_a2c() w = agent.policy.get_weights() - w_test = np.array([0.9382279 , -1.8847059 , -0.13790752, -0.00786441]) + w_test = np.array([ 0.9389272 ,-1.8838323 ,-0.13710725,-0.00668973]) assert np.allclose(w, w_test) @@ -95,3 +95,5 @@ def test_a2c_save(tmpdir): print(save_attr, load_attr) tu.assert_eq(save_attr, load_attr) + +test_a2c() \ No newline at end of file diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index c99c6bfa..847dedfe 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -128,6 +128,4 @@ def test_dataset_loading(tmpdir): assert len(dataset.info) == len(new_dataset.info) for key in dataset.info: - assert np.array_equal(dataset.info[key], new_dataset.info[key]) - - + assert np.array_equal(dataset.info[key], new_dataset.info[key]) \ No newline at end of file diff --git a/tests/rl_utils/test_value_functions.py b/tests/rl_utils/test_value_functions.py new file mode 100644 index 00000000..b213c452 --- /dev/null +++ b/tests/rl_utils/test_value_functions.py @@ -0,0 +1,92 @@ +import torch +from mushroom_rl.policy import DeterministicPolicy +from mushroom_rl.environments.segway import Segway +from mushroom_rl.core import Core, Agent +from mushroom_rl.approximators import Regressor 
diff --git a/tests/algorithms/test_a2c.py b/tests/algorithms/test_a2c.py
index 6b3babaa..d3e7b95b 100644
--- a/tests/algorithms/test_a2c.py
+++ b/tests/algorithms/test_a2c.py
@@ -75,7 +75,7 @@
     agent = learn_a2c()

     w = agent.policy.get_weights()
-    w_test = np.array([0.9382279, -1.8847059, -0.13790752, -0.00786441])
+    w_test = np.array([0.9389272, -1.8838323, -0.13710725, -0.00668973])

     assert np.allclose(w, w_test)

@@ -95,3 +95,5 @@
         print(save_attr, load_attr)

         tu.assert_eq(save_attr, load_attr)
+
+test_a2c()
\ No newline at end of file
diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py
index c99c6bfa..847dedfe 100644
--- a/tests/core/test_dataset.py
+++ b/tests/core/test_dataset.py
@@ -128,6 +128,4 @@ def test_dataset_loading(tmpdir):
     assert len(dataset.info) == len(new_dataset.info)

     for key in dataset.info:
-        assert np.array_equal(dataset.info[key], new_dataset.info[key])
-
-
+        assert np.array_equal(dataset.info[key], new_dataset.info[key])
\ No newline at end of file
diff --git a/tests/rl_utils/test_value_functions.py b/tests/rl_utils/test_value_functions.py
new file mode 100644
index 00000000..b213c452
--- /dev/null
+++ b/tests/rl_utils/test_value_functions.py
@@ -0,0 +1,92 @@
+import torch
+from mushroom_rl.policy import DeterministicPolicy
+from mushroom_rl.environments.segway import Segway
+from mushroom_rl.core import Core, Agent
+from mushroom_rl.approximators import Regressor
+from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator
+from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo
+
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
+
+def test_compute_advantage_montecarlo():
+    def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
+        with torch.no_grad():
+            r = r.squeeze()
+            q = torch.zeros(len(r))
+            v = V(s).squeeze()
+
+            for rev_k in range(len(r)):
+                k = len(r) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    q_next = V(ss[k]).squeeze().item()
+                q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
+                q[k] = q_next
+
+            adv = q - v
+            return q[:, None], adv[:, None]
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99)
+
+def test_compute_gae():
+    def gae(V, s, ss, r, absorbing, last, gamma, lam):
+        with torch.no_grad():
+            v = V(s)
+            v_next = V(ss)
+            gen_adv = torch.empty_like(v)
+            for rev_k in range(len(v)):
+                k = len(v) - rev_k - 1
+                if last[k] or rev_k == 0:
+                    gen_adv[k] = r[k] - v[k]
+                    if not absorbing[k]:
+                        gen_adv[k] += gamma * v_next[k]
+                else:
+                    gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1]
+            return gen_adv + v, gen_adv
+
+    torch.manual_seed(42)
+    _value_functions_tester(compute_gae, gae, 0.99, 0.95)
+
+def _value_functions_tester(test_fun, correct_fun, *args):
+    mdp = Segway()
+    V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}})
+
+    state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+    V.fit(state, correct_v)
+
+    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
+    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)
+
+    assert torch.allclose(v, correct_v)
+    assert torch.allclose(adv, correct_adv)
+
+def _get_episodes(mdp, n_episodes=100):
+    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)
+
+    approximator = Regressor(LinearApproximator,
+                             input_shape=mdp.info.observation_space.shape,
+                             output_shape=mdp.info.action_space.shape,
+                             weights=mu)
+
+    policy = DeterministicPolicy(approximator)
+
+    agent = Agent(mdp.info, policy)
+    core = Core(agent, mdp)
+    dataset = core.evaluate(n_episodes=n_episodes)
+
+    return dataset.parse(to='torch')
+
+class Net(torch.nn.Module):
+    def __init__(self, input_shape, output_shape, **kwargs):
+        super().__init__()
+        self._q = torch.nn.Linear(input_shape[0], output_shape[0])
+
+    def forward(self, x):
+        return self._q(x.float())
diff --git a/tests/utils/test_episodes.py b/tests/utils/test_episodes.py
new file mode 100644
index 00000000..8f6f895d
--- /dev/null
+++ b/tests/utils/test_episodes.py
@@ -0,0 +1,60 @@
+import torch
+import numpy as np
+
+from mushroom_rl.core import Core, Agent
+from mushroom_rl.approximators import Regressor
+from mushroom_rl.policy import DeterministicPolicy
+from mushroom_rl.approximators.parametric import LinearApproximator
+from mushroom_rl.environments import Segway
+
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
+
+def test_torch_split():
+    torch.manual_seed(42)
+    mdp = Segway()
+    state, action, reward, next_state, absorbing, last = get_episodes(mdp)
+
+    ep_arrays = split_episodes(last, state, action, reward, next_state, absorbing, last)
+    un_state, un_action, un_reward, un_next_state, un_absorbing, un_last = unsplit_episodes(last, *ep_arrays)
+
+    assert torch.allclose(state, un_state)
+    assert torch.allclose(action, un_action)
+    assert torch.allclose(reward, un_reward)
+    assert torch.allclose(next_state, un_next_state)
+    assert torch.allclose(absorbing, un_absorbing)
+    assert torch.allclose(last, un_last)
+
+def test_numpy_split():
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    mdp = Segway()
+    state, action, reward, next_state, absorbing, last = get_episodes(mdp)
+
+    state, action, reward, next_state, absorbing, last = state.numpy(), action.numpy(), reward.numpy(), next_state.numpy(), absorbing.numpy(), last.numpy()
+
+    ep_arrays = split_episodes(last, state, action, reward, next_state, absorbing, last)
+    un_state, un_action, un_reward, un_next_state, un_absorbing, un_last = unsplit_episodes(last, *ep_arrays)
+
+    assert np.allclose(state, un_state)
+    assert np.allclose(action, un_action)
+    assert np.allclose(reward, un_reward)
+    assert np.allclose(next_state, un_next_state)
+    assert np.allclose(absorbing, un_absorbing)
+    assert np.allclose(last, un_last)
+
+def get_episodes(mdp, n_episodes=100):
+    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)
+
+    approximator = Regressor(LinearApproximator,
+                             input_shape=mdp.info.observation_space.shape,
+                             output_shape=mdp.info.action_space.shape,
+                             weights=mu)
+
+    policy = DeterministicPolicy(approximator)
+
+    agent = Agent(mdp.info, policy)
+    core = Core(agent, mdp)
+    dataset = core.evaluate(n_episodes=n_episodes)
+
+    return dataset.parse(to='torch')