Add support for continuous action spaces to PCN (#82)
* Allow continuous actions for PCN

* Log horizons with log_all_multi_policy_metrics

* Fix formatting

* Fix continuous action

* Undo change in log_all_multi_policy_metrics

* Update README

* pre-commit fix

---------

Co-authored-by: Lucas Alegre <lucasnale@gmail.com>
vaidas-sl and LucasAlegre authored Dec 19, 2023
1 parent 0907003 commit 7251514
Showing 2 changed files with 96 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -50,7 +50,7 @@ A tutorial on MO-Gymnasium and MORL-Baselines is also available: [![Open in Cola
| [Envelope Q-Learning](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/envelope/envelope.py) | Multi | SER | Continuous | Discrete | [Paper](https://arxiv.org/pdf/1908.08342.pdf) |
| [CAPQL](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/capql/capql.py) | Multi | SER | Continuous | Continuous | [Paper](https://openreview.net/pdf?id=TjEzIsyEsQ6) |
| [PGMORL](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/pgmorl/pgmorl.py) <sup>[1](#f1)</sup> | Multi | SER | Continuous | Continuous | [Paper](https://people.csail.mit.edu/jiex/papers/PGMORL/paper.pdf) / [Supplementary Materials](https://people.csail.mit.edu/jiex/papers/PGMORL/supp.pdf) |
| [Pareto Conditioned Networks (PCN)](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/pcn/pcn.py) | Multi | SER/ESR <sup>[2](#f2)</sup> | Continuous | Discrete | [Paper](https://www.ifaamas.org/Proceedings/aamas2022/pdfs/p1110.pdf) |
| [Pareto Conditioned Networks (PCN)](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/pcn/pcn.py) | Multi | SER/ESR <sup>[2](#f2)</sup> | Continuous | Discrete / Continuous | [Paper](https://www.ifaamas.org/Proceedings/aamas2022/pdfs/p1110.pdf) |
| [Pareto Q-Learning](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/pareto_q_learning/pql.py) | Multi | SER | Discrete | Discrete | [Paper](https://jmlr.org/papers/volume15/vanmoffaert14a/vanmoffaert14a.pdf) |
| [MO Q learning](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/single_policy/ser/mo_q_learning.py) | Single | SER | Discrete | Discrete | [Paper](https://www.researchgate.net/publication/235698665_Scalarized_Multi-Objective_Reinforcement_Learning_Novel_Design_Techniques) |
| [MPMOQLearning](https://github.com/LucasAlegre/morl-baselines/blob/main/morl_baselines/multi_policy/multi_policy_moqlearning/mp_mo_q_learning.py) (outer loop MOQL) | Multi | SER | Discrete | Discrete | [Paper](https://www.researchgate.net/publication/235698665_Scalarized_Multi-Objective_Reinforcement_Learning_Novel_Design_Techniques) |
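With this change, PCN infers its policy head from the action space: a `gym.spaces.Box` action space selects `ContinuousActionsDefaultModel`, anything else falls back to `DiscreteActionsDefaultModel`. Below is a minimal usage sketch of the continuous case; the environment id `mo-hopper-v4`, the `reward_space` attribute, and the `train()` keyword arguments (`eval_env`, `ref_point`, `max_return`) are assumptions about the surrounding MO-Gymnasium / MORL-Baselines API rather than something shown in this diff.

```python
import mo_gymnasium as mo_gym
import numpy as np

from morl_baselines.multi_policy.pcn.pcn import PCN

# Continuous-action multi-objective task (environment id is an assumption).
env = mo_gym.make("mo-hopper-v4")
eval_env = mo_gym.make("mo-hopper-v4")

# scaling_factor needs reward_dim + 1 entries: one per objective, plus one for the horizon.
reward_dim = env.unwrapped.reward_space.shape[0]
scaling_factor = np.ones(reward_dim + 1, dtype=np.float32)

agent = PCN(
    env,
    scaling_factor=scaling_factor,
    learning_rate=1e-3,
    batch_size=256,
    hidden_dim=64,
    noise=0.1,  # std of the Gaussian exploration noise added to continuous actions
    log=False,
)
# env.action_space is a Box, so ContinuousActionsDefaultModel is picked automatically;
# model_class only needs to be set for a custom architecture.

# Keyword names below are assumed from the current train() signature, not part of this diff.
agent.train(
    total_timesteps=100_000,
    eval_env=eval_env,
    ref_point=np.full(reward_dim, -100.0),
    max_return=np.full(reward_dim, 300.0),
    num_er_episodes=20,
)
```

At evaluation time the network output is used directly as the action; during training, zero-mean Gaussian noise with standard deviation `noise` is added for exploration, mirroring the `_act()` change in the diff below.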
130 changes: 95 additions & 35 deletions morl_baselines/multi_policy/pcn/pcn.py
@@ -1,8 +1,9 @@
"""Pareto Conditioned Network. Code adapted from https://github.com/mathieu-reymond/pareto-conditioned-networks ."""
import heapq
import os
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

import gymnasium as gym
import numpy as np
@@ -40,16 +41,16 @@ class Transition:
"""Transition dataclass."""

observation: np.ndarray
action: int
action: Union[float, int]
reward: np.ndarray
next_observation: np.ndarray
terminal: bool


class Model(nn.Module):
"""Model for the PCN."""
class BasePCNModel(nn.Module, ABC):
"""Base Model for the PCN."""

def __init__(self, state_dim: int, action_dim: int, reward_dim: int, scaling_factor: np.ndarray, hidden_dim: int = 64):
def __init__(self, state_dim: int, action_dim: int, reward_dim: int, scaling_factor: np.ndarray, hidden_dim: int):
"""Initialize the PCN model."""
super().__init__()
self.state_dim = state_dim
@@ -58,25 +59,47 @@ def __init__(self, state_dim: int, action_dim: int, reward_dim: int, scaling_fac
self.scaling_factor = nn.Parameter(th.tensor(scaling_factor).float(), requires_grad=False)
self.hidden_dim = hidden_dim

def forward(self, state, desired_return, desired_horizon):
"""Return log-probabilities of actions or return action directly in case of continuous action space."""
c = th.cat((desired_return, desired_horizon), dim=-1)
# commands are scaled by a fixed factor
c = c * self.scaling_factor
s = self.s_emb(state.float())
c = self.c_emb(c)
# element-wise multiplication of state-embedding and command
prediction = self.fc(s * c)
return prediction


class DiscreteActionsDefaultModel(BasePCNModel):
"""Model for the PCN with discrete actions."""

def __init__(self, state_dim: int, action_dim: int, reward_dim: int, scaling_factor: np.ndarray, hidden_dim: int):
"""Initialize the PCN model for discrete actions."""
super().__init__(state_dim, action_dim, reward_dim, scaling_factor, hidden_dim)
self.s_emb = nn.Sequential(nn.Linear(self.state_dim, self.hidden_dim), nn.Sigmoid())
self.c_emb = nn.Sequential(nn.Linear(self.reward_dim + 1, self.hidden_dim), nn.Sigmoid())
self.fc = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.action_dim),
nn.LogSoftmax(1),
nn.LogSoftmax(dim=1),
)

def forward(self, state, desired_return, desired_horizon):
"""Return log-probabilities of actions."""
c = th.cat((desired_return, desired_horizon), dim=-1)
# commands are scaled by a fixed factor
c = c * self.scaling_factor
s = self.s_emb(state.float())
c = self.c_emb(c)
# element-wise multiplication of state-embedding and command
log_prob = self.fc(s * c)
return log_prob

class ContinuousActionsDefaultModel(BasePCNModel):
"""Model for the PCN with continuous actions."""

def __init__(self, state_dim: int, action_dim: int, reward_dim: int, scaling_factor: np.ndarray, hidden_dim: int):
"""Initialize the PCN model for continuous actions."""
super().__init__(state_dim, action_dim, reward_dim, scaling_factor, hidden_dim)
self.s_emb = nn.Sequential(nn.Linear(self.state_dim, self.hidden_dim), nn.Sigmoid())
self.c_emb = nn.Sequential(nn.Linear(self.reward_dim + 1, self.hidden_dim), nn.Sigmoid())
self.fc = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.action_dim),
)


class PCN(MOAgent, MOPolicy):
@@ -101,12 +124,14 @@ def __init__(
gamma: float = 1.0,
batch_size: int = 256,
hidden_dim: int = 64,
noise: float = 0.1,
project_name: str = "MORL-Baselines",
experiment_name: str = "PCN",
wandb_entity: Optional[str] = None,
log: bool = True,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
model_class: Optional[Type[BasePCNModel]] = None,
) -> None:
"""Initialize PCN agent.
@@ -117,12 +142,14 @@ def __init__(
gamma (float, optional): Discount factor. Defaults to 1.0.
batch_size (int, optional): Batch size. Defaults to 32.
hidden_dim (int, optional): Hidden dimension. Defaults to 64.
noise (float, optional): Standard deviation of the noise to add to the action in the continuous action case. Defaults to 0.1.
project_name (str, optional): Name of the project for wandb. Defaults to "MORL-Baselines".
experiment_name (str, optional): Name of the experiment for wandb. Defaults to "PCN".
wandb_entity (Optional[str], optional): Entity for wandb. Defaults to None.
log (bool, optional): Whether to log to wandb. Defaults to True.
seed (Optional[int], optional): Seed for reproducibility. Defaults to None.
device (Union[th.device, str], optional): Device to use. Defaults to "auto".
model_class (Optional[Type[BasePCNModel]], optional): Model class to use. Defaults to None.
"""
MOAgent.__init__(self, env, device=device, seed=seed)
MOPolicy.__init__(self, device)
@@ -135,14 +162,26 @@ def __init__(
self.scaling_factor = scaling_factor
self.desired_return = None
self.desired_horizon = None
self.continuous_action = True if type(self.env.action_space) is gym.spaces.Box else False
self.noise = noise

if model_class and not issubclass(model_class, BasePCNModel):
raise ValueError("model_class must be a subclass of BasePCNModel")

self.model = Model(
if model_class is None:
if self.continuous_action:
model_class = ContinuousActionsDefaultModel
else:
model_class = DiscreteActionsDefaultModel

self.model = model_class(
self.observation_dim, self.action_dim, self.reward_dim, self.scaling_factor, hidden_dim=self.hidden_dim
).to(self.device)
self.opt = th.optim.Adam(self.model.parameters(), lr=self.learning_rate)

self.log = log
if log:
experiment_name += " continuous action" if self.continuous_action else ""
self.setup_wandb(project_name, experiment_name, wandb_entity)

def get_config(self) -> dict:
@@ -154,6 +193,8 @@ def get_config(self) -> dict:
"learning_rate": self.learning_rate,
"hidden_dim": self.hidden_dim,
"scaling_factor": self.scaling_factor,
"continuous_action": self.continuous_action,
"noise": self.noise,
"seed": self.seed,
}

@@ -173,22 +214,25 @@ def update(self):
batch.append((s_t, a_t, r_t, h_t))

obs, actions, desired_return, desired_horizon = zip(*batch)
log_prob = self.model(
prediction = self.model(
th.tensor(obs).to(self.device),
th.tensor(desired_return).to(self.device),
th.tensor(desired_horizon).unsqueeze(1).to(self.device),
)

self.opt.zero_grad()
# one-hot of action for CE loss
actions = F.one_hot(th.tensor(actions).long().to(self.device), len(log_prob[0]))
# cross-entropy loss
l = th.sum(-actions * log_prob, -1)
l = l.mean()
if self.continuous_action:
l = F.mse_loss(th.tensor(actions).float().to(self.device), prediction)
else:
# one-hot of action for CE loss
actions = F.one_hot(th.tensor(actions).long().to(self.device), len(prediction[0]))
# cross-entropy loss
l = th.sum(-actions * prediction, -1)
l = l.mean()
l.backward()
self.opt.step()

return l, log_prob
return l, prediction

def _add_episode(self, transitions: List[Transition], max_size: int, step: int) -> None:
# compute return
@@ -255,17 +299,26 @@ def _choose_commands(self, num_episodes: int):
return desired_return, desired_horizon

def _act(self, obs: np.ndarray, desired_return, desired_horizon, eval_mode=False) -> int:
log_probs = self.model(
prediction = self.model(
th.tensor([obs]).float().to(self.device),
th.tensor([desired_return]).float().to(self.device),
th.tensor([desired_horizon]).unsqueeze(1).float().to(self.device),
)
log_probs = log_probs.detach().cpu().numpy()[0]
if eval_mode:
action = np.argmax(log_probs)

if self.continuous_action:
action = prediction.detach().cpu().numpy()[0]
if not eval_mode:
# Add Gaussian noise: https://arxiv.org/pdf/2204.05027.pdf
action = action + np.random.normal(0.0, self.noise)
return action
else:
action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
return action
log_probs = prediction.detach().cpu().numpy()[0]

if eval_mode:
action = np.argmax(log_probs)
else:
action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
return action

def _run_episode(self, env, desired_return, desired_horizon, max_return, eval_mode=False):
transitions = []
@@ -394,9 +447,10 @@ def train(
for _ in range(num_model_updates):
l, lp = self.update()
loss.append(l.detach().cpu().numpy())
lp = lp.detach().cpu().numpy()
ent = np.sum(-np.exp(lp) * lp)
entropy.append(ent)
if not self.continuous_action:
lp = lp.detach().cpu().numpy()
ent = np.sum(-np.exp(lp) * lp)
entropy.append(ent)

desired_return, desired_horizon = self._choose_commands(num_er_episodes)

@@ -411,10 +465,16 @@ def train(
{
"train/hypervolume": hv_est,
"train/loss": np.mean(loss),
"train/entropy": np.mean(entropy),
"global_step": self.global_step,
},
)
if not self.continuous_action:
wandb.log(
{
"train/entropy": np.mean(entropy),
"global_step": self.global_step,
},
)

returns = []
horizons = []
@@ -448,7 +508,7 @@ def train(
},
)
print(
f"step {self.global_step} \t return {np.mean(returns, axis=0)}, ({np.std(returns, axis=0)}) \t loss {np.mean(loss):.3E}"
f"step {self.global_step} \t return {np.mean(returns, axis=0)}, ({np.std(returns, axis=0)}) \t loss {np.mean(loss):.3E} \t horizons {np.mean(horizons)}"
)

if self.global_step >= (n_checkpoints + 1) * total_timesteps / 1000:
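The new `model_class` argument also accepts a custom network, provided it subclasses `BasePCNModel` and defines the `s_emb`, `c_emb`, and `fc` modules that the shared `forward()` combines. A minimal sketch follows, reusing the continuous-action environment assumption from the earlier example; the wider `Tanh` architecture is an arbitrary illustration, not something prescribed by the library.

```python
import mo_gymnasium as mo_gym
import numpy as np
import torch.nn as nn

from morl_baselines.multi_policy.pcn.pcn import PCN, BasePCNModel


class WideContinuousModel(BasePCNModel):
    """Illustrative continuous-action PCN model with a wider hidden layer."""

    def __init__(self, state_dim, action_dim, reward_dim, scaling_factor, hidden_dim):
        super().__init__(state_dim, action_dim, reward_dim, scaling_factor, hidden_dim)
        # BasePCNModel.forward() multiplies s_emb(state) and c_emb(command)
        # element-wise, so both embeddings must output hidden_dim features.
        self.s_emb = nn.Sequential(nn.Linear(self.state_dim, self.hidden_dim), nn.Tanh())
        self.c_emb = nn.Sequential(nn.Linear(self.reward_dim + 1, self.hidden_dim), nn.Tanh())
        # No LogSoftmax head: the raw output is used directly as the continuous action.
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * self.hidden_dim, self.action_dim),
        )


env = mo_gym.make("mo-hopper-v4")  # continuous-action task; id is an assumption
reward_dim = env.unwrapped.reward_space.shape[0]
agent = PCN(
    env,
    scaling_factor=np.ones(reward_dim + 1, dtype=np.float32),
    model_class=WideContinuousModel,  # validated against BasePCNModel in __init__
    log=False,
)
```

Passing a class that does not inherit from `BasePCNModel` raises a `ValueError`, per the check added in `__init__`.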
