SAC (#34) #37

Merged
merged 15 commits on Aug 4, 2024
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -12,3 +12,4 @@
indicatif = "0.17.8"
ndarray = "0.15.6"
plotters = "0.3.6"
rand = "0.8.5"
serde = { version = "1.0.203", features = ["std", "derive"] }
107 changes: 107 additions & 0 deletions examples/sac_pendulum.rs
@@ -0,0 +1,107 @@
use std::path::PathBuf;

use burn::{
    backend::{
        libtorch::LibTorchDevice,
        Autodiff, LibTorch,
    },
    grad_clipping::GradientClippingConfig,
    optim::{Adam, AdamConfig},
};
use sb3_burn::{
    common::{
        algorithm::{OfflineAlgParams, OfflineTrainer},
        buffer::ReplayBuffer,
        eval::EvalConfig,
        logger::{CsvLogger, Logger},
        spaces::BoxSpace,
    },
    env::classic_control::pendulum::make_pendulum,
    simple_sac::{
        agent::SACAgent,
        models::{PiModel, QModelSet},
    },
};

const N_CRITICS: usize = 2;

fn main() {
    // Using parameters from:
    // https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml

    type TrainingBacked = Autodiff<LibTorch>;

    let train_device = LibTorchDevice::Cuda(0);

    let env = make_pendulum(None);

    let config_optimizer =
        AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));

    let pi_optim = config_optimizer.init();

    let qs: QModelSet<TrainingBacked> = QModelSet::new(
        env.observation_space().shape().len(),
        env.action_space().shape().len(),
        &train_device,
        N_CRITICS,
    );

    let q_optim = config_optimizer.init();

    let pi = PiModel::new(
        env.observation_space().shape().len(),
        env.action_space().shape().len(),
        &train_device,
    );

    let offline_params = OfflineAlgParams::new()
        .with_batch_size(256)
        .with_memory_size(1000000)
        .with_n_steps(20000)
        .with_warmup_steps(1000)
        .with_lr(1e-3)
        .with_eval_at_start_of_training(true)
        .with_eval_at_end_of_training(true)
        .with_evaluate_during_training(false);

    let agent = SACAgent::new(
        pi,
        qs.clone(),
        qs,
        pi_optim,
        q_optim,
        None,
        true,
        None,
        Some(0.995),
        Box::new(BoxSpace::from(([0.0].to_vec(), [0.0].to_vec()))),
        Box::new(BoxSpace::from(([0.0].to_vec(), [0.0].to_vec()))),
    );

    let buffer = ReplayBuffer::new(offline_params.memory_size);

    let logger = CsvLogger::new(
        PathBuf::from("logs/sac_pendulum/log_sac_pendulum.csv"),
        false,
    );

    match logger.check_can_log(false) {
        Ok(_) => {}
        Err(err) => panic!("Error setting up logger: {err}"),
    }

    let mut trainer: OfflineTrainer<_, Adam<LibTorch>, _, _, _> = OfflineTrainer::new(
        offline_params,
        env,
        make_pendulum(None),
        agent,
        buffer,
        Box::new(logger),
        None,
        EvalConfig::new().with_n_eval_episodes(20),
        &train_device,
    );

    trainer.train();
}
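The example above selects LibTorchDevice::Cuda(0). On a machine without a CUDA GPU, the same training loop should run on the CPU by swapping only the device; a minimal sketch (not part of this diff, and it assumes the LibTorch backend was built with CPU support):

    // Sketch: run the pendulum example on CPU instead of CUDA device 0.
    let train_device = LibTorchDevice::Cpu;
    // Models, optimizers, and the trainer are then constructed against
    // &train_device exactly as in the example above; nothing else changes.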
107 changes: 107 additions & 0 deletions examples/sac_probe.rs
@@ -0,0 +1,107 @@
use std::path::PathBuf;

use burn::{
    backend::{
        libtorch::LibTorchDevice,
        Autodiff, LibTorch,
    },
    grad_clipping::GradientClippingConfig,
    optim::{Adam, AdamConfig},
};
use sb3_burn::{
    common::{
        algorithm::{OfflineAlgParams, OfflineTrainer},
        buffer::ReplayBuffer,
        eval::EvalConfig,
        logger::{CsvLogger, Logger},
        spaces::BoxSpace,
    },
    env::{base::Env, probe::ProbeEnvContinuousActions},
    simple_sac::{
        agent::SACAgent,
        models::{PiModel, QModelSet},
    },
};

const N_CRITICS: usize = 2;

fn main() {
    // Using parameters from:
    // https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml

    type TrainingBacked = Autodiff<LibTorch>;

    let train_device = LibTorchDevice::Cuda(0);

    let env = ProbeEnvContinuousActions::default();

    let config_optimizer =
        AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));

    let pi_optim = config_optimizer.init();

    let qs: QModelSet<TrainingBacked> = QModelSet::new(
        env.observation_space().shape().len(),
        env.action_space().shape().len(),
        &train_device,
        N_CRITICS,
    );

    let q_optim = config_optimizer.init();

    let pi = PiModel::new(
        env.observation_space().shape().len(),
        env.action_space().shape().len(),
        &train_device,
    );

    let offline_params = OfflineAlgParams::new()
        .with_batch_size(256)
        .with_memory_size(1000000)
        .with_n_steps(2000)
        .with_warmup_steps(256)
        .with_lr(5e-3)
        .with_eval_at_start_of_training(true)
        .with_eval_at_end_of_training(true)
        .with_evaluate_during_training(false);

    let agent = SACAgent::new(
        pi,
        qs.clone(),
        qs,
        pi_optim,
        q_optim,
        None,
        true,
        None,
        Some(0.995),
        Box::new(BoxSpace::from(([0.0].to_vec(), [1.0].to_vec()))),
        Box::new(BoxSpace::from(([0.0].to_vec(), [1.0].to_vec()))),
    );

    let buffer = ReplayBuffer::new(offline_params.memory_size);

    let logger = CsvLogger::new(
        PathBuf::from("logs/sac_probe/log_sac_probe.csv"),
        false,
    );

    match logger.check_can_log(false) {
        Ok(_) => {}
        Err(err) => panic!("Error setting up logger: {err}"),
    }

    let mut trainer: OfflineTrainer<_, Adam<LibTorch>, _, _, _> = OfflineTrainer::new(
        offline_params,
        Box::new(env),
        Box::new(ProbeEnvContinuousActions::default()),
        agent,
        buffer,
        Box::new(logger),
        None,
        EvalConfig::new().with_n_eval_episodes(20),
        &train_device,
    );

    trainer.train();
}
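For quick local iteration on the probe example, the OfflineAlgParams builder shown above can be shrunk to a smoke-test configuration. A sketch with illustrative values only, not part of this diff; it uses the same builder methods as the example:

    // Sketch: a shorter smoke-test configuration for fast local runs.
    let quick_params = OfflineAlgParams::new()
        .with_batch_size(64)
        .with_memory_size(10_000)
        .with_n_steps(500)
        .with_warmup_steps(64)
        .with_lr(5e-3)
        .with_eval_at_start_of_training(false)
        .with_eval_at_end_of_training(true)
        .with_evaluate_during_training(false);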
70 changes: 45 additions & 25 deletions src/common/distributions/action_distribution.rs
@@ -1,13 +1,16 @@
use burn::{module::{Module, Param}, nn::{Linear, LinearConfig}, tensor::{backend::Backend, Shape, Tensor}};
use burn::{
module::{Module, Param},
nn::{Linear, LinearConfig},
tensor::{backend::Backend, Shape, Tensor},
};

use crate::common::{agent::Policy, utils::module_update::update_linear};

use super::{distribution::BaseDistribution, normal::Normal};

pub trait ActionDistribution<B> : Policy<B>
pub trait ActionDistribution<B>: Policy<B>
where
B: Backend,
// SD: BaseDistribution<B, 1>
{
/// takes in a batched input and returns the
/// batched log prob
@@ -22,8 +25,7 @@ where
/// returns an unbatched sample from the distribution
fn sample(&mut self) -> Tensor<B, 2>;


fn get_actions(&mut self, deterministic: bool) -> Tensor<B, 2>{
fn get_actions(&mut self, deterministic: bool) -> Tensor<B, 2> {
if deterministic {
self.mode()
} else {
@@ -36,7 +38,7 @@ where

/// Continuous actions are usually considered to be independent,
/// so we can sum components of the ``log_prob`` or the entropy.
///
///
/// # Shapes
/// t: (batch, n_actions) or (batch)
/// return: (batch) for (batch, n_actions) input, or (1) for (batch) input
@@ -49,28 +51,38 @@ where
// }
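// Editor's sketch (not part of this diff): the docstring above describes
// collapsing a (batch, n_actions) log-prob into a per-sample (batch,) value by
// summing over the action dimension. With burn's tensor API that could look
// roughly like the helper below; the method names (sum_dim, squeeze) are
// assumptions and may need adjusting for the pinned burn version.
fn sum_independent_dims<B: Backend>(log_prob: Tensor<B, 2>) -> Tensor<B, 1> {
    // (batch, n_actions) -> (batch, 1) -> (batch,)
    log_prob.sum_dim(1).squeeze(1)
}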

#[derive(Debug, Module)]
pub struct DiagGaussianDistribution<B: Backend>{
pub struct DiagGaussianDistribution<B: Backend> {
means: Linear<B>,
log_std: Param<Tensor<B, 1>>,
dist: Normal<B, 2>,
}

impl<B: Backend> DiagGaussianDistribution<B>{
pub fn new(latent_dim: usize, action_dim: usize, log_std_init: f32, device: &B::Device) -> Self {
impl<B: Backend> DiagGaussianDistribution<B> {
pub fn new(
latent_dim: usize,
action_dim: usize,
log_std_init: f32,
device: &B::Device,
) -> Self {
// create the distribution with dummy values for now
let loc: Tensor<B, 2> = Tensor::ones(Shape::new([action_dim]), &Default::default()).unsqueeze_dim(0);
let std: Tensor<B, 2> = Tensor::ones(Shape::new([action_dim]), &Default::default()).mul_scalar(log_std_init).unsqueeze_dim(0);
let loc: Tensor<B, 2> =
Tensor::ones(Shape::new([action_dim]), &Default::default()).unsqueeze_dim(0);
let std: Tensor<B, 2> = Tensor::ones(Shape::new([action_dim]), &Default::default())
.mul_scalar(log_std_init)
.unsqueeze_dim(0);
let dist: Normal<B, 2> = Normal::new(loc, std);

Self {
Self {
means: LinearConfig::new(latent_dim, action_dim).init(device),
log_std: Param::from_tensor(Tensor::ones(Shape::new([action_dim]), device).mul_scalar(log_std_init)),
dist: dist.no_grad()
log_std: Param::from_tensor(
Tensor::ones(Shape::new([action_dim]), device).mul_scalar(log_std_init),
),
dist: dist.no_grad(),
}
}
}

impl<B: Backend>ActionDistribution<B> for DiagGaussianDistribution<B>{
impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
fn log_prob(&self, sample: Tensor<B, 2>) -> Tensor<B, 2> {
self.dist.log_prob(sample)
}
@@ -86,24 +98,29 @@ impl<B: Backend>ActionDistribution<B> for DiagGaussianDistribution<B>{
fn sample(&mut self) -> Tensor<B, 2> {
self.dist.rsample()
}

fn actions_from_obs(&mut self, obs: Tensor<B, 2>) -> Tensor<B, 2> {
let scale = self.log_std.val().clone().exp().unsqueeze_dim(0).repeat(0, obs.shape().dims[0]);
let scale = self
.log_std
.val()
.clone()
.exp()
.unsqueeze_dim(0)
.repeat(0, obs.shape().dims[0]);
let mean = self.means.forward(obs);
self.dist = Normal::new(mean, scale).no_grad();

self.sample()
}
}

impl<B: Backend> Policy<B> for DiagGaussianDistribution<B>{
impl<B: Backend> Policy<B> for DiagGaussianDistribution<B> {
fn update(&mut self, from: &Self, tau: Option<f32>) {
self.means = update_linear(&from.means, self.means.clone(), tau);
//TODO: update self.log_std
}
}
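// Editor's sketch (not part of this diff): one way the TODO above could be
// addressed is to Polyak-average log_std alongside the linear layer, mirroring
// whatever tau convention update_linear uses (here tau is assumed to be the
// weight on the source parameters).
fn soft_update_log_std<B: Backend>(
    from: &Param<Tensor<B, 1>>,
    to: &Param<Tensor<B, 1>>,
    tau: f32,
) -> Param<Tensor<B, 1>> {
    Param::from_tensor(from.val().mul_scalar(tau) + to.val().mul_scalar(1.0 - tau))
}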


// #[derive(Clone, Debug)]
// pub struct StateDependentNoiseDistribution<B: Backend>{
// means: Linear<B>,
@@ -118,14 +135,17 @@ impl<B: Backend> Policy<B> for DiagGaussianDistribution<B>{

#[cfg(test)]
mod test {
use burn::{backend::{Autodiff, NdArray}, tensor::{Distribution, Shape, Tensor}};
use burn::{
backend::{Autodiff, NdArray},
tensor::{Distribution, Shape, Tensor},
};

use crate::common::distributions::action_distribution::ActionDistribution;

use super::DiagGaussianDistribution;

#[test]
fn test_diag_gaussian_dist(){
fn test_diag_gaussian_dist() {
type Backend = Autodiff<NdArray>;
let latent_size = 10;
let action_size = 3;
@@ -140,15 +160,15 @@ mod test {
// create some dummy obs
let batch_size = 6;
let dummy_obs: Tensor<Backend, 2> = Tensor::random(
Shape::new([batch_size, latent_size]),
Distribution::Normal(0.0, 1.0),
Shape::new([batch_size, latent_size]),
Distribution::Normal(0.0, 1.0),
&Default::default(),
);

let action_sample = dist.actions_from_obs(dummy_obs);
let log_prob = dist.log_prob(action_sample);

// build a dummy loss function on the log prob and
// build a dummy loss function on the log prob and
// make sure we can do a backwards pass
let dummy_loss = log_prob.sub_scalar(0.1).powi_scalar(2).mean();
let _grads = dummy_loss.backward();
Expand All @@ -157,4 +177,4 @@ mod test {
dist.mode();
dist.entropy();
}
}
}
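A brief usage sketch of the trait shown above (not part of this diff), assuming a mutable dist that was already conditioned on an observation via actions_from_obs:

    // Deterministic evaluation returns the mode of the Gaussian...
    let greedy_action = dist.get_actions(true);
    // ...while stochastic exploration draws a reparameterised sample.
    let exploratory_action = dist.get_actions(false);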