Skip to content

Commit

Permalink
squashing more sac bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
will-maclean committed Sep 3, 2024
1 parent e9fc430 commit 590b858
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/sac_pendulum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn main() {

let offline_params = OfflineAlgParams::new()
.with_batch_size(3)
.with_memory_size(1000000)
.with_memory_size(20000)
.with_n_steps(20000)
.with_warmup_steps(1000)
.with_lr(1e-3)
Expand All @@ -73,7 +73,7 @@ fn main() {
None,
Some(0.995),
Box::new(BoxSpace::from(([0.0].to_vec(), [0.0].to_vec()))),
Box::new(BoxSpace::from(([0.0].to_vec(), [0.0].to_vec()))),
Box::new(BoxSpace::from(([-2.0].to_vec(), [2.0].to_vec()))),
);

let buffer = ReplayBuffer::new(offline_params.memory_size);
Expand Down
14 changes: 4 additions & 10 deletions src/common/distributions/action_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
}

/// Returns the mode of the wrapped distribution, obtained from `self.dist.mode()`.
fn mode(&mut self) -> Tensor<B, 2> {
// NOTE(review): this text is a rendered diff without +/- markers. Per the file's
// "4 additions & 10 deletions" summary, the first line below appears to be the removed
// version (tanh applied to the mode) and the second its replacement (no tanh).
self.dist.mode().tanh()
self.dist.mode()
}

/// Draws a sample from the wrapped distribution via `self.dist.rsample()`.
fn sample(&mut self) -> Tensor<B, 2> {
// NOTE(review): rendered diff without +/- markers. The first line below appears to be
// the removed version (tanh applied to the sample) and the second its replacement
// (the raw `rsample()` result).
self.dist.rsample().tanh()
self.dist.rsample()
}

fn actions_from_obs(&mut self, obs: Tensor<B, 2>, deterministic: bool) -> Tensor<B, 2> {
Expand Down Expand Up @@ -129,7 +129,6 @@ impl<B: Backend> Policy<B> for DiagGaussianDistribution<B> {
/// Action distribution that wraps a [`DiagGaussianDistribution`]; the `mode`/`sample`
/// methods shown elsewhere in this diff apply `tanh` to the wrapped distribution's output.
pub struct SquashedDiagGaussianDistribution<B: Backend> {
// The underlying Gaussian distribution that this type delegates to.
diag_gaus_dist: DiagGaussianDistribution<B>,
// Supplied at construction; presumably a small numerical-stability constant — its use is
// not visible in this view, TODO confirm.
epsilon: f32,
// NOTE(review): rendered diff without +/- markers. This field appears to be a removed
// line (the commit deletes every read and write of `gaus_actions`).
gaus_actions: Option<Tensor<B, 2>>,
}

impl<B: Backend> SquashedDiagGaussianDistribution<B> {
Expand All @@ -146,7 +145,6 @@ impl<B: Backend> SquashedDiagGaussianDistribution<B> {
device,
),
epsilon,
gaus_actions: None,
}
}
}
Expand Down Expand Up @@ -183,15 +181,11 @@ impl<B: Backend> ActionDistribution<B> for SquashedDiagGaussianDistribution<B> {
}

/// Returns `tanh` of the wrapped Gaussian distribution's mode.
fn mode(&mut self) -> Tensor<B, 2> {
// NOTE(review): rendered diff without +/- markers. The first three lines below appear to
// be the removed version, which cached the pre-tanh mode in `self.gaus_actions` before
// squashing; the final line is the replacement, which squashes directly with no caching.
self.gaus_actions = Some(self.diag_gaus_dist.mode());

self.gaus_actions.clone().unwrap().tanh()
self.diag_gaus_dist.mode().tanh()
}

/// Returns `tanh` of a sample drawn from the wrapped Gaussian distribution.
fn sample(&mut self) -> Tensor<B, 2> {
// NOTE(review): rendered diff without +/- markers. The first three lines below appear to
// be the removed version, which cached the pre-tanh sample in `self.gaus_actions` before
// squashing; the final line is the replacement, which squashes directly with no caching.
self.gaus_actions = Some(self.diag_gaus_dist.sample());

self.gaus_actions.clone().unwrap().tanh()
self.diag_gaus_dist.sample().tanh()
}

fn actions_from_obs(&mut self, obs: Tensor<B, 2>, deterministic: bool) -> Tensor<B, 2> {
Expand Down
2 changes: 1 addition & 1 deletion src/common/distributions/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<B: Backend, const D: usize> BaseDistribution<B, D> for Normal<B, D> {
/// Draws a sample as `scale * s + loc`, where `s` is standard-normal noise with the same
/// shape as `self.loc`.
fn rsample(&self) -> Tensor<B, D> {
// Noise drawn from Normal(0, 1), shaped like `loc`.
let s = Tensor::random_like(&self.loc, Distribution::Normal(0.0, 1.0));

// NOTE(review): rendered diff without +/- markers. The first line below appears to be the
// removed version and the second its replacement; they differ only in which operand is
// the receiver of `mul`. The reason for the swap is not visible here — TODO confirm
// (possibly shape/broadcast related).
s.mul(self.scale.clone()) + self.loc.clone()
self.scale.clone().mul(s) + self.loc.clone()
}

fn log_prob(&self, value: Tensor<B, D>) -> Tensor<B, D> {
Expand Down
4 changes: 2 additions & 2 deletions src/common/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ pub fn angle_normalise(f: f32) -> f32 {
}

/// Debug helper: prints `name` followed by the tensor `t` to stdout.
pub fn disp_tensorf<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D>) {
// NOTE(review): rendered diff without +/- markers. The commented-out line appears to be
// the removed version (printing disabled) and the live `println!` its replacement, so
// this commit turns the debug output back on for every caller.
// println!("{name}. {t}\n");
println!("{name}. {t}\n");
}

/// Debug helper: prints `name` followed by the boolean tensor `t` to stdout.
pub fn disp_tensorb<B: Backend, const D: usize>(name: &str, t: &Tensor<B, D, Bool>) {
// NOTE(review): rendered diff without +/- markers. The commented-out line appears to be
// the removed version (printing disabled) and the live `println!` its replacement, so
// this commit turns the debug output back on for every caller.
// println!("{name}. {t}\n");
println!("{name}. {t}\n");
}

#[cfg(test)]
Expand Down
9 changes: 6 additions & 3 deletions src/sac/agent.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::time;

use burn::{
module::Module,
nn::{
Expand Down Expand Up @@ -246,6 +244,7 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
train_device,
);

println!("ent_cof: {}\n", ent_coef);

let log_dict = log_dict.push("ent_coef".to_string(), LogData::Float(ent_coef));

Expand Down Expand Up @@ -290,13 +289,15 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
let target_q_vals = target_q_vals.detach();

// calculate the critic loss
let q_vals = self.qs.q_from_actions(states.clone(), actions.clone());
let q_vals: Vec<Tensor<B, 2>> = self.qs.q_from_actions(states.clone(), actions.clone());

let mut critic_loss: Tensor<B, 1> = Tensor::zeros(Shape::new([1]), train_device);
for q in q_vals {
disp_tensorf("q", &q);
critic_loss =
critic_loss + MseLoss::new().forward(q, target_q_vals.clone(), Reduction::Mean);
}

disp_tensorf("critic_loss", &critic_loss);

// Confirmed with sb3 community that the 0.5 scaling has nothing to do with the number
Expand Down Expand Up @@ -352,6 +353,8 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
self.last_update = global_step;
}

// panic!();

(None, log_dict)
}

Expand Down
6 changes: 3 additions & 3 deletions src/sac/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ pub struct PiModel<B: Backend> {
impl<B: Backend> PiModel<B> {
/// Builds a `PiModel` from an `MLP` and a `SquashedDiagGaussianDistribution`, both created
/// on `device`.
///
/// `obs_size` is the first entry of the size list passed to `MLP::new`; `n_actions` is
/// forwarded to the distribution constructor.
pub fn new(obs_size: usize, n_actions: usize, device: &B::Device) -> Self {
Self {
// NOTE(review): rendered diff without +/- markers. Per the file's "3 additions &
// 3 deletions" summary, the first `mlp`/`dist` pair (32) appears to be the removed
// version and the second pair (256) its replacement. The distribution's first argument
// matches the last MLP size in both versions.
// The trailing `1e-6` is presumably the distribution's `epsilon` — TODO confirm against
// the constructor signature.
mlp: MLP::new(&[obs_size, 32, 32].to_vec(), device),
dist: SquashedDiagGaussianDistribution::new(32, n_actions, device, 1e-6),
mlp: MLP::new(&[obs_size, 256, 256].to_vec(), device),
dist: SquashedDiagGaussianDistribution::new(256, n_actions, device, 1e-6),
}
}
}
Expand Down Expand Up @@ -46,7 +46,7 @@ pub struct QModel<B: Backend> {
impl<B: Backend> QModel<B> {
/// Builds a `QModel` whose `MLP` is created on `device` with a size list that starts at
/// `obs_size + n_actions` and ends at `n_actions`.
pub fn new(obs_size: usize, n_actions: usize, device: &B::Device) -> Self {
Self {
// NOTE(review): rendered diff without +/- markers. The first `mlp` line (one 32 entry)
// appears to be the removed version and the second (two 256 entries) its replacement.
// NOTE(review): the final size is `n_actions` in both versions; confirm whether a single
// scalar Q-value output (size 1) was intended instead.
mlp: MLP::new(&[obs_size + n_actions, 32, n_actions].to_vec(), device),
mlp: MLP::new(&[obs_size + n_actions, 256, 256, n_actions].to_vec(), device),
}
}
}
Expand Down

0 comments on commit 590b858

Please sign in to comment.