Commit

Merge pull request #99 from omidsbhn/dev_eupg
Dev eupg
ffelten authored May 3, 2024
2 parents affb987 + 2868f29 commit bca7952
Showing 3 changed files with 35 additions and 10 deletions.
examples/eupg_fishwood.py (12 changes: 10 additions & 2 deletions)
@@ -1,5 +1,6 @@
import mo_gymnasium as mo_gym
import numpy as np
import torch as th
from mo_gymnasium.utils import MORecordEpisodeStatistics

from morl_baselines.common.evaluation import eval_mo_reward_conditioned
@@ -10,8 +11,15 @@
env = MORecordEpisodeStatistics(mo_gym.make("fishwood-v0"), gamma=0.99)
eval_env = mo_gym.make("fishwood-v0")

def scalarization(reward: np.ndarray, w):
    return min(reward[0], reward[1] // 2)
def scalarization(reward: np.ndarray, w=None):
    reward = th.tensor(reward) if not isinstance(reward, th.Tensor) else reward
    # Handle the case when reward is a single tensor of shape (2, )
    if reward.dim() == 1 and reward.size(0) == 2:
        return min(reward[0], reward[1] // 2).item()

    # Handle the case when reward is a tensor of shape (200, 2)
    elif reward.dim() == 2 and reward.size(1) == 2:
        return th.min(reward[:, 0], reward[:, 1] // 2)

agent = EUPG(env, scalarization=scalarization, weights=np.ones(2), gamma=0.99, log=True, learning_rate=0.001)
agent.train(total_timesteps=int(4e6), eval_env=eval_env, eval_freq=1000)
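
For context on the new branching: the updated scalarization accepts either a single multi-objective return of shape (2,) or a batch of per-step returns of shape (T, 2). The standalone sketch below mirrors the function from the hunk above and calls it on both shapes; the reward values and the batch are made up purely for illustration.

import torch as th


def scalarization(reward, w=None):
    reward = th.tensor(reward) if not isinstance(reward, th.Tensor) else reward
    # Single reward/return vector of shape (2,) -> Python scalar
    if reward.dim() == 1 and reward.size(0) == 2:
        return min(reward[0], reward[1] // 2).item()
    # Batch of per-step returns of shape (T, 2) -> tensor of shape (T,)
    elif reward.dim() == 2 and reward.size(1) == 2:
        return th.min(reward[:, 0], reward[:, 1] // 2)


print(scalarization([3.0, 5.0]))  # min(3, 5 // 2) -> 2.0

batch = th.tensor([[3.0, 5.0], [1.0, 4.0]])
print(scalarization(batch))  # tensor([2., 1.])
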
morl_baselines/single_policy/esr/eupg.py (21 changes: 15 additions & 6 deletions)
@@ -214,9 +214,7 @@ def eval(self, obs: np.ndarray, accrued_reward: Optional[np.ndarray]) -> Union[i
        else:
            obs = th.as_tensor(obs).to(self.device)
        accrued_reward = th.as_tensor(accrued_reward).float().to(self.device)
        probas = self.net(obs, accrued_reward)
        greedy_act = th.argmax(probas)
        return greedy_act.detach().item()
        return self.__choose_action(obs, accrued_reward)

    @th.no_grad()
    def __choose_action(self, obs: th.Tensor, accrued_reward: th.Tensor) -> int:
@@ -234,16 +232,18 @@ def update(self):
            next_obs,
            terminateds,
        ) = self.buffer.get_all_data(to_tensor=True, device=self.device)
        # Scalarized episodic reward, our target :-)

        episodic_return = th.sum(rewards, dim=0)
        scalarized_return = self.scalarization(episodic_return.cpu().numpy(), self.weights)
        scalarized_return = th.scalar_tensor(scalarized_return).to(self.device)

        discounted_forward_rewards = self._forward_cumulative_rewards(rewards)
        scalarized_values = self.scalarization(discounted_forward_rewards)
        # For each sample in the batch, get the distribution over actions
        current_distribution = self.net.distribution(obs, accrued_rewards)
        # Policy gradient
        log_probs = current_distribution.log_prob(actions)
        loss = -th.mean(log_probs * scalarized_return)
        log_probs = current_distribution.log_prob(actions.squeeze())
        loss = -th.mean(log_probs * scalarized_values)

        self.optimizer.zero_grad()
        loss.backward()
@@ -259,6 +259,15 @@ def update(self):
                },
            )

    def _forward_cumulative_rewards(self, rewards):
        flip_rewards = rewards.flip(dims=[0])
        cumulative_rewards = th.zeros(self.reward_dim).to(self.device)
        for i in range(len(rewards)):
            cumulative_rewards = self.gamma * cumulative_rewards + flip_rewards[i]
            flip_rewards[i] = cumulative_rewards
        forward_rewards = flip_rewards.flip(dims=[0])
        return forward_rewards

    def train(self, total_timesteps: int, eval_env: Optional[gym.Env] = None, eval_freq: int = 1000, start_time=None):
        """Train the agent.
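
To see what the new _forward_cumulative_rewards helper computes: for every timestep and every objective it accumulates the discounted reward-to-go G_t = r_t + gamma * G_{t+1} by sweeping over the time-reversed reward tensor. The scalarized per-step values then replace the single scalarized episodic return as the weight on the per-step log-probabilities in the loss; the actions.squeeze() presumably keeps both factors at shape (T,) so the product is element-wise. The standalone sketch below reproduces the same recurrence with made-up rewards and an illustrative gamma, and checks one entry by hand.

import torch as th

gamma = 0.5
# Three timesteps, two objectives; each row is one multi-objective reward.
rewards = th.tensor([[1.0, 0.0],
                     [0.0, 2.0],
                     [4.0, 0.0]])

# Same recurrence as _forward_cumulative_rewards: iterate over the
# time-reversed rewards and accumulate G_t = r_t + gamma * G_{t+1}.
flip_rewards = rewards.flip(dims=[0])
cumulative = th.zeros(rewards.size(1))
for i in range(len(rewards)):
    cumulative = gamma * cumulative + flip_rewards[i]
    flip_rewards[i] = cumulative
forward_rewards = flip_rewards.flip(dims=[0])

print(forward_rewards)
# tensor([[2., 1.],
#         [2., 2.],
#         [4., 0.]])
# e.g. objective 0 at t=0: 1.0 + 0.5 * (0.0 + 0.5 * 4.0) = 2.0
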
tests/test_algos.py (12 changes: 10 additions & 2 deletions)
@@ -3,6 +3,7 @@

import mo_gymnasium as mo_gym
import numpy as np
import torch as th
from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import CONCAVE_MAP

from morl_baselines.common.evaluation import eval_mo, eval_mo_reward_conditioned
@@ -54,8 +55,15 @@ def test_eupg():
    env = mo_gym.make("fishwood-v0")
    eval_env = mo_gym.make("fishwood-v0")

    def scalarization(reward: np.ndarray, w):
        return min(reward[0], (reward[1] // 2) + 1)
    def scalarization(reward: np.ndarray, w=None):
        reward = th.tensor(reward) if not isinstance(reward, th.Tensor) else reward
        # Handle the case when reward is a single tensor of shape (2, )
        if reward.dim() == 1 and reward.size(0) == 2:
            return min(reward[0], reward[1] // 2).item()

        # Handle the case when reward is a tensor of shape (200, 2)
        elif reward.dim() == 2 and reward.size(1) == 2:
            return th.min(reward[:, 0], reward[:, 1] // 2)

    agent = EUPG(env, scalarization=scalarization, gamma=0.99, log=False)
    agent.train(total_timesteps=10000, eval_env=eval_env, eval_freq=100)
