Commit
quick env scaling for pendulum
will-maclean committed Sep 8, 2024
1 parent ab26dda commit 49b5dce
Showing 2 changed files with 133 additions and 86 deletions.
2 changes: 1 addition & 1 deletion examples/sac_pendulum.rs
@@ -73,7 +73,7 @@ fn main() {
None,
Some(0.995),
Box::new(BoxSpace::from(([0.0].to_vec(), [0.0].to_vec()))),
Box::new(BoxSpace::from(([-2.0].to_vec(), [2.0].to_vec()))),
Box::new(BoxSpace::from(([-1.0].to_vec(), [1.0].to_vec()))),
);

let buffer = ReplayBuffer::new(offline_params.memory_size);
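For context: the one-line change above shrinks the agent's action box from [-2.0, 2.0] to [-1.0, 1.0], which matches the commit title's "env scaling" — the policy now works in a normalized range, and the environment side is expected to map actions back to Pendulum's native torque limits of [-2, 2]. The sketch below is illustrative only (rescale_action is not a function in this repository); it just shows the affine mapping such a wrapper would apply.

fn rescale_action(action: &[f32], low: &[f32], high: &[f32]) -> Vec<f32> {
    // Map each component from the agent's [-1, 1] box to [low, high].
    action
        .iter()
        .zip(low.iter().zip(high.iter()))
        .map(|(a, (lo, hi))| lo + (a.clamp(-1.0, 1.0) + 1.0) * 0.5 * (hi - lo))
        .collect()
}

fn main() {
    // 0.5 in the normalized box corresponds to a torque of 1.0 in [-2, 2].
    assert_eq!(rescale_action(&[0.5], &[-2.0], &[2.0]), vec![1.0]);
}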
217 changes: 132 additions & 85 deletions src/sac/agent.rs
@@ -6,8 +6,7 @@ use burn::{
},
optim::{adaptor::OptimizerAdaptor, Adam, AdamConfig, GradientsParams, Optimizer},
tensor::{
backend::{AutodiffBackend, Backend},
ElementConversion, Shape, Tensor,
backend::{AutodiffBackend, Backend}, Bool, ElementConversion, Shape, Tensor
},
};

@@ -177,85 +176,20 @@ impl<B: AutodiffBackend> SACAgent<B> {
update_every: 1,
}
}
}

impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
fn act(
&mut self,
_global_step: usize,
_global_frac: f32,
obs: &Vec<f32>,
greedy: bool,
inference_device: &<B>::Device,
) -> (Vec<f32>, LogItem) {
// don't judge me
let a: Vec<f32> = self
.pi
.act(&obs.clone().to_tensor(inference_device), greedy)
.detach()
.into_data()
.to_vec()
.unwrap();

(a, LogItem::default())
}

fn train_step(
&mut self,
global_step: usize,
replay_buffer: &ReplayBuffer<Vec<f32>, Vec<f32>>,
offline_params: &crate::common::algorithm::OfflineAlgParams,
train_device: &<B as Backend>::Device,
) -> (Option<f32>, LogItem) {

let log_dict = LogItem::default();

let sample_data = replay_buffer.batch_sample(offline_params.batch_size);

let states = sample_data.states.to_tensor(train_device);
let actions = sample_data.actions.to_tensor(train_device);
let next_states = sample_data.next_states.to_tensor(train_device);
let rewards = sample_data.rewards.to_tensor(train_device).unsqueeze_dim(1);
let terminated = sample_data
.terminated
.to_tensor(train_device)
.unsqueeze_dim(1);
let truncated = sample_data
.truncated
.to_tensor(train_device)
.unsqueeze_dim(1);
let dones = (terminated.float() + truncated.float()).bool();

disp_tensorf("states", &states);
disp_tensorf("actions", &actions);
disp_tensorf("next_states", &next_states);
disp_tensorf("rewards", &rewards);
disp_tensorb("dones", &dones);

let (actions_pi, log_prob) = self.pi.act_log_prob(states.clone());

disp_tensorf("actions_pi", &actions_pi);
disp_tensorf("log_prob", &log_prob);

// train entropy coefficient if required to do so
let (ent_coef, ent_coef_loss) = self.ent_coef.train_step(
log_prob.clone().flatten(0, 1),
offline_params.lr,
train_device,
);

println!("ent_cof: {}\n", ent_coef);

let log_dict = log_dict.push("ent_coef".to_string(), LogData::Float(ent_coef));

let log_dict = if let Some(l) = ent_coef_loss {
log_dict.push("ent_coef_loss".to_string(), LogData::Float(l))
} else {
log_dict
};

fn train_critics(&mut self,
states: Tensor<B, 2>,
actions: Tensor<B, 2>,
next_states: Tensor<B, 2>,
rewards: Tensor<B, 2>,
dones: Tensor<B, 2, Bool>,
gamma: f32,
train_device: &B::Device,
ent_coef: f32,
lr: f64,
log_dict: LogItem,
) -> LogItem {
// select action according to policy

let (next_action_sampled, next_action_log_prob) = self.pi.act_log_prob(next_states.clone());

disp_tensorf("next_action_sampled", &next_action_sampled);
@@ -282,20 +216,21 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
.bool_not()
.float()
.mul(next_q_vals)
.mul_scalar(offline_params.gamma);
.mul_scalar(gamma);

disp_tensorf("target_q_vals", &target_q_vals);

let target_q_vals = target_q_vals.detach();

// calculate the critic loss
let q_vals: Vec<Tensor<B, 2>> = self.qs.q_from_actions(states.clone(), actions.clone());
let q_vals: Vec<Tensor<B, 2>> = self.qs.q_from_actions(states, actions);

let loss_fn = MseLoss::new();
let mut critic_loss: Tensor<B, 1> = Tensor::zeros(Shape::new([1]), train_device);
for q in q_vals {
disp_tensorf("q", &q);
critic_loss =
critic_loss + MseLoss::new().forward(q, target_q_vals.clone(), Reduction::Mean);
critic_loss + loss_fn.forward(q, target_q_vals.clone(), Reduction::Mean);
}

disp_tensorf("critic_loss", &critic_loss);
@@ -315,12 +250,23 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
let critic_grads = GradientsParams::from_grads(critic_loss_grads, &self.qs);
self.qs = self
.q_optim
.step(offline_params.lr, self.qs.clone(), critic_grads);
.step(lr, self.qs.clone(), critic_grads);

log_dict

}

fn train_policy(&mut self,
states: Tensor<B, 2>,
ent_coef: f32,
lr: f64,
actions_pi: Tensor<B, 2>,
log_prob: Tensor<B, 2>,
log_dict: LogItem,
) -> LogItem {
// Policy loss
// recalculate q values with new critics
let q_vals = self.qs.q_from_actions(states.clone(), actions_pi);
let q_vals = self.qs.q_from_actions(states, actions_pi);
let q_vals = Tensor::cat(q_vals, 1).detach();
disp_tensorf("q_vals", &q_vals);
let min_q = q_vals.min_dim(1);
@@ -339,7 +285,108 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
let actor_grads = GradientsParams::from_grads(actor_loss_back, &self.pi);
self.pi = self
.pi_optim
.step(offline_params.lr, self.pi.clone(), actor_grads);
.step(lr, self.pi.clone(), actor_grads);

log_dict
}
}

impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
fn act(
&mut self,
_global_step: usize,
_global_frac: f32,
obs: &Vec<f32>,
greedy: bool,
inference_device: &<B>::Device,
) -> (Vec<f32>, LogItem) {
// don't judge me
let a: Vec<f32> = self
.pi
.act(&obs.clone().to_tensor(inference_device), greedy)
.detach()
.into_data()
.to_vec()
.unwrap();

(a, LogItem::default())
}

fn train_step(
&mut self,
global_step: usize,
replay_buffer: &ReplayBuffer<Vec<f32>, Vec<f32>>,
offline_params: &crate::common::algorithm::OfflineAlgParams,
train_device: &<B as Backend>::Device,
) -> (Option<f32>, LogItem) {

let log_dict = LogItem::default();

let sample_data = replay_buffer.batch_sample(offline_params.batch_size);

let states = sample_data.states.to_tensor(train_device);
let actions = sample_data.actions.to_tensor(train_device);
let next_states = sample_data.next_states.to_tensor(train_device);
let rewards = sample_data.rewards.to_tensor(train_device).unsqueeze_dim(1);
let terminated = sample_data
.terminated
.to_tensor(train_device)
.unsqueeze_dim(1);
let truncated = sample_data
.truncated
.to_tensor(train_device)
.unsqueeze_dim(1);
let dones = (terminated.float() + truncated.float()).bool();

disp_tensorf("states", &states);
disp_tensorf("actions", &actions);
disp_tensorf("next_states", &next_states);
disp_tensorf("rewards", &rewards);
disp_tensorb("dones", &dones);

let (actions_pi, log_prob) = self.pi.act_log_prob(states.clone());

disp_tensorf("actions_pi", &actions_pi);
disp_tensorf("log_prob", &log_prob);

// train entropy coefficient if required to do so
let (ent_coef, ent_coef_loss) = self.ent_coef.train_step(
log_prob.clone().flatten(0, 1),
offline_params.lr,
train_device,
);

println!("ent_coef: {}\n", ent_coef);

let log_dict = log_dict.push("ent_coef".to_string(), LogData::Float(ent_coef));

let log_dict = if let Some(l) = ent_coef_loss {
log_dict.push("ent_coef_loss".to_string(), LogData::Float(l))
} else {
log_dict
};

let log_dict = self.train_critics(
states.clone(),
actions,
next_states,
rewards,
dones,
offline_params.gamma,
train_device,
ent_coef,
offline_params.lr,
log_dict,
);

let log_dict = self.train_policy(
states,
ent_coef,
offline_params.lr,
actions_pi,
log_prob,
log_dict
);

// target critic updates
if global_step > (self.last_update + self.update_every) {
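For reference, the refactored train_critics / train_policy helpers appear to implement the standard SAC updates; a few lines are hidden by the truncated hunks, so the summary below follows the visible code plus the usual SAC formulation rather than this file verbatim:

y = r + \gamma\,(1 - d)\,\bigl(\min_i Q_i^{\text{target}}(s', \tilde a') - \alpha \log \pi(\tilde a' \mid s')\bigr), \quad \tilde a' \sim \pi(\cdot \mid s')
L_Q = \sum_i \mathrm{MSE}\bigl(Q_i(s, a),\, y\bigr)
L_\pi = \mathbb{E}\bigl[\alpha \log \pi(\tilde a \mid s) - \min_i Q_i(s, \tilde a)\bigr], \quad \tilde a \sim \pi(\cdot \mid s)

Here \alpha is ent_coef, \gamma is offline_params.gamma, and d is the combined terminated/truncated done flag built in train_step.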
