Skip to content

Commit

Permalink
Merge branch 'review_submission'
Browse files Browse the repository at this point in the history
  • Loading branch information
sai-prasanna committed Mar 13, 2024
2 parents dac754e + 5b53fa7 commit 5c03c40
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 2,048 deletions.
63 changes: 0 additions & 63 deletions 001_adversary.sh

This file was deleted.

11 changes: 5 additions & 6 deletions train_jobs.sh → 001_main_experiments.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

#SBATCH --array=0-359
#SBATCH --array=0-479
#SBATCH --partition alldlc_gpu-rtx2080
#SBATCH --job-name CMbRL_odin
#SBATCH --output logs/slurm/%x-%A-%a-HelloCluster.out
#SBATCH --error logs/slurm/%x-%A-%a-HelloCluster.err
#SBATCH --job-name CMbRL
#SBATCH --output logs/slurm/%x-%A-%a.out
#SBATCH --error logs/slurm/%x-%A-%a.err
#SBATCH --mem 32GB
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
Expand All @@ -21,7 +21,7 @@ start=`date +%s`

tasks=("carl_classic_cartpole" "carl_dmc_walker")
seeds=("0" "42" "1337" "13" "71" "1994" "1997" "908" "2102" "3")
schemes=("enc_obs_dec_obs_default" "enc_img_dec_img_default" "enc_obs_dec_obs" "enc_img_dec_img" "enc_obs_ctx_dec_obs_ctx" "enc_img_ctx_dec_img_ctx" "enc_obs_dec_obs_pgm_ctx" "enc_img_dec_img_pgm_ctx" "enc_obs_dec_obs_pgm_ctx_adv" "enc_img_dec_img_pgm_ctx_adv")
schemes=("enc_obs_dec_obs_default" "enc_img_dec_img_default" "enc_obs_dec_obs" "enc_img_dec_img" "enc_obs_ctx_dec_obs_ctx" "enc_img_ctx_dec_img_ctx" "enc_obs_dec_obs_pgm_ctx" "enc_img_dec_img_pgm_ctx")
contexts=("single_0" "single_1" "double_box")

n_tasks=${#tasks[@]}
Expand All @@ -46,7 +46,6 @@ if [ "$scheme" == "enc_obs_dec_obs_default" ]; then
if [ "$context" != "single_0" ]; then
exit 0
fi

scheme="enc_obs_dec_obs"
context="default"
elif [ "$scheme" == "enc_img_dec_img_default" ]; then
Expand Down
14 changes: 5 additions & 9 deletions schedule_specific_jobs_cartpole.sh → 002_cartpole_experts.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

#SBATCH --array=0-9
#SBATCH --array=0-20
#SBATCH --partition alldlc_gpu-rtx2080
#SBATCH --job-name CMbRL_specific_cartpole
#SBATCH --output logs/slurm/%x-%A-%a-HelloCluster.out
#SBATCH --error logs/slurm/%x-%A-%a-HelloCluster.err
#SBATCH --output logs/slurm/%x-%A-%a.out
#SBATCH --error logs/slurm/%x-%A-%a.err
#SBATCH --mem 32GB
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
Expand All @@ -25,11 +25,7 @@ schemes=("enc_obs_dec_obs")
# gravity [0.98, 2.45, 3.92, 4.9, 9.8, 14.7, 15.68, 17.64, 19.6]
# length [0.1 0.2 0.3 0.5 0.7 0.8 0.9 1.0]
# grav, len [(2.45, 0.2), (17.64, 0.2), (17.64, 0.9), (2.45, 0.9)]
# contexts=("specific_0-0.98" "specific_0-2.45" "specific_0-3.92" "specific_0-4.9" "specific_0-9.8" "specific_0-14.7" "specific_0-15.68" "specific_0-17.64" "specific_0-19.6" "specific_1-0.1" "specific_1-0.2" "specific_1-0.3" "specific_1-0.5" "specific_1-0.7" "specific_1-0.8" "specific_1-0.9" "specific_1-1.0")

#contexts=("specific_0-2.45_1-0.2" "specific_0-17.64_1-0.2" "specific_0-17.64_1-0.9" "specific_0-2.45_1-0.9")

contexts=("specific_0-2.45_1-0.2" "specific_0-17.64_1-0.9")
contexts=("specific_0-0.98" "specific_0-2.45" "specific_0-3.92" "specific_0-4.9" "specific_0-9.8" "specific_0-14.7" "specific_0-15.68" "specific_0-17.64" "specific_0-19.6" "specific_1-0.1" "specific_1-0.2" "specific_1-0.3" "specific_1-0.5" "specific_1-0.7" "specific_1-0.8" "specific_1-0.9" "specific_1-1.0" "specific_0-2.45_1-0.2" "specific_0-17.64_1-0.2" "specific_0-17.64_1-0.9" "specific_0-2.45_1-0.9")

n_tasks=${#tasks[@]}
n_seeds=${#seeds[@]}
Expand All @@ -52,7 +48,7 @@ group_name="${task}_${context}_${scheme}"

python -m contextual_mbrl.dreamer.train --configs carl $scheme --task $task --env.carl.context $context --seed $seed --logdir logs/specific/$group_name/$seed --wandb.project '' --jax.policy_devices 0 --jax.train_devices 1 --run.steps 50000
python -m contextual_mbrl.dreamer.eval --logdir logs/specific/$group_name/$seed

python -m contextual_mbrl.dreamer.eval --logdir logs/specific/$group_name/$seed --random_policy True
end=`date +%s`
runtime=$((end-start))

Expand Down
17 changes: 6 additions & 11 deletions schedule_specific_jobs_walker.sh → 003_walker_experts.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

#SBATCH --array=0-19
#SBATCH --array=0-20
#SBATCH --partition alldlc_gpu-rtx2080
#SBATCH --job-name CMbRL_specific_walker
#SBATCH --output logs/slurm/%x-%A-%a-HelloCluster.out
#SBATCH --error logs/slurm/%x-%A-%a-HelloCluster.err
#SBATCH --output logs/slurm/%x-%A-%a.out
#SBATCH --error logs/slurm/%x-%A-%a.err
#SBATCH --mem 32GB
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
Expand All @@ -27,8 +27,7 @@ schemes=("enc_obs_dec_obs")
# actuator strength [0.1, 0.3, 0.5, 1.0, 1.5, 1.6, 1.8, 2.0]
# grav, strength (2.45, 0.3), (17.64, 0.3), (17.64, 1.8), (2.45, 1.8)

#contexts=("specific_0-0.98" "specific_0-2.45" "specific_0-3.92" "specific_0-4.9" "specific_0-9.81" "specific_0-14.7" "specific_0-15.68" "specific_0-17.64" "specific_0-19.6" "specific_1-0.1" "specific_1-0.3" "specific_1-0.5" "specific_1-1.0" "specific_1-1.5" "specific_1-1.6" "specific_1-1.8" "specific_1-2.0")
contexts=("specific_0-2.45_1-0.3" "specific_0-17.64_1-0.3" "specific_0-17.64_1-1.8" "specific_0-2.45_1-1.8")
contexts=("specific_0-0.98" "specific_0-2.45" "specific_0-3.92" "specific_0-4.9" "specific_0-9.81" "specific_0-14.7" "specific_0-15.68" "specific_0-17.64" "specific_0-19.6" "specific_1-0.1" "specific_1-0.3" "specific_1-0.5" "specific_1-1.0" "specific_1-1.5" "specific_1-1.6" "specific_1-1.8" "specific_1-2.0" "specific_0-2.45_1-0.3" "specific_0-17.64_1-0.3" "specific_0-17.64_1-1.8" "specific_0-2.45_1-1.8")


n_tasks=${#tasks[@]}
Expand All @@ -48,13 +47,9 @@ context=${contexts[$context_index]}

group_name="${task}_${context}_${scheme}"

# if logdir is not there train
if [ ! -d logs/specific/$group_name/$seed ]; then
echo "Training $group_name $seed"
python -m contextual_mbrl.dreamer.train --configs carl $scheme --task $task --env.carl.context $context --seed $seed --logdir logs/specific/$group_name/$seed --jax.policy_devices 0 --jax.train_devices 1 --run.steps 100000 --wandb.project ''
fi

python -m contextual_mbrl.dreamer.train --configs carl $scheme --task $task --env.carl.context $context --seed $seed --logdir logs/specific/$group_name/$seed --jax.policy_devices 0 --jax.train_devices 1 --run.steps 500000 --wandb.project ''
python -m contextual_mbrl.dreamer.eval --logdir logs/specific/$group_name/$seed
python -m contextual_mbrl.dreamer.eval --logdir logs/specific/$group_name/$seed --random_policy True

end=`date +%s`
runtime=$((end-start))
Expand Down
8 changes: 5 additions & 3 deletions schedule_dreams.sh → 004_collect_dreams.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#SBATCH --partition alldlc_gpu-rtx2080
#SBATCH --job-name CMbRL_array
#SBATCH --output logs/slurm/%x-%A-%a-HelloCluster.out
#SBATCH --error logs/slurm/%x-%A-%a-HelloCluster.err
#SBATCH --output logs/slurm/%x-%A-%a.out
#SBATCH --error logs/slurm/%x-%A-%a.err
#SBATCH --mem 32GB
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
Expand Down Expand Up @@ -43,7 +43,9 @@ do
group_name="${task}_${context}_${scheme}_normalized"

if [ -d "logs/$group_name/$seed" ]; then
python -m contextual_mbrl.dreamer.record_cart_length_dreams --logdir logs/$group_name/$seed
python -m contextual_mbrl.dreamer.record_dreams --logdir logs/$group_name/$seed --ctx_id 1
python -m contextual_mbrl.dreamer.record_dreams --logdir logs/$group_name/$seed --ctx_id 1 --counterfactual_ctx 0.1
python -m contextual_mbrl.dreamer.record_dreams --logdir logs/$group_name/$seed --ctx_id 1 --counterfactual_ctx 1.0
fi
done

Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 Sai
Copyright (c) 2023

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
69 changes: 63 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,70 @@
# contextual_mbrl
# Dreaming of Many Worlds: Learning Contextual World Models Aids Zero-Shot Generalization

Do world models actually model the world :p

## Setup
## Code

The training and evaluation code for our experiments is in the `contextual_mbrl` directory. `contextual_mbrl/dreamer/envs.py` defines all our context variations for training and evaluation, and `contextual_mbrl/dreamer/configs.yaml` defines all the configurations and hyperparameters for our runs. `contextual_mbrl/dreamer/record_dreams.py` allows us to record the extrapolated/counterfactual dreams.

### VScode
Install the recommended extensions when you open the project
The code in `dreamerv3_compat` is taken from the [fork](https://github.com/Kinds-of-Intelligence-CFI/dreamerv3-compat) which adds gymnasium support to the [official DreamerV3](https://github.com/danijar/dreamerv3) codebase. We changed it to incorporate our cRSSM method. Our changes are localized mainly to `dreamerv3_compat/dreamerv3/nets.py` and `dreamerv3_compat/dreamerv3/agent.py`.

## Setup

### Conda
First setup a miniconda environment.

`conda env create --name c_mbrl --file environment.yml`

## Experiments

### Main experiments

All our main experiments can be replicated by running `sbatch 001_main_experiments.sh` with Slurm, or by modifying the script to suit your training environment. To replicate individual experiments, you have to select one value for each of the 4 main options.

1. **Task**
`carl_classic_cartpole`/`carl_dmc_walker`
2. **Modality**
`img` for pixel modality and `obs` for featurized modality
3. **Method**
`enc_{$modality}_dec_{$modality}_pgm_ctx` is the cRSSM setting,
`enc_{$modality}_dec_{$modality}` is the hidden-context setting,
`enc_{$modality}_ctx_dec_{$modality}_ctx` is the concat-context setting
4. **Training context**
`default`: Only on default context (makes sense to pair only with the hidden-context setting)
`single_0`: vary the first context (gravity for cartpole and walker)
`single_1`: vary the second context (length for cartpole and actuator strength for walker)

Then run the following commands with your preferred settings and seed to run training, followed by evaluation in the required context regions.

``` bash
python -m contextual_mbrl.dreamer.train --configs carl $scheme --task $task --env.carl.context $training_context --seed $seed --logdir logs/$experiment_name/$seed --wandb.project '' --run.steps $steps
python -m contextual_mbrl.dreamer.eval --logdir logs/$experiment_name/$seed
```


### Experts and random policy

To train the experts, evaluate their mean returns and evaluate the models, run `002_cartpole_experts.sh` and `003_walker_experts.sh`.

An example of training a cartpole expert with gravity 17.64 and length 0.9 is

```
python -m contextual_mbrl.dreamer.train --configs carl enc_obs_dec_obs --task carl_classic_cartpole --env.carl.context specific_0-17.64_1-0.9 --seed 0 --logdir logs/specific/carl_classic_cartpole_specific_0-17.64_1-0.9/0 --wandb.project '' --run.steps 50000
python -m contextual_mbrl.dreamer.eval --logdir logs/specific/carl_classic_cartpole_specific_0-17.64_1-0.9/0
python -m contextual_mbrl.dreamer.eval --logdir logs/specific/carl_classic_cartpole_specific_0-17.64_1-0.9/0 --random_policy True
```

### Dreams
The extrapolated and counterfactual dreams of the trained cartpole models can be obtained using `004_collect_dreams.sh`. To record the dreams individually for a given context id (refer to `envs.py` for the context id to context mapping for different environments) and a given experiment run, use:

```bash
python -m contextual_mbrl.dreamer.record_dreams --logdir logs/$experiment_name/$seed --ctx_id 1
```

Use the `counterfactual_ctx` flag to provide a counterfactual value and record dreams in different true contexts while conditioning on this counterfactual value.

```bash
python -m contextual_mbrl.dreamer.record_dreams --logdir logs/$experiment_name/$seed --ctx_id 1 --counterfactual_ctx 1.0
```
## Results

`conda create --name c_mbrl --file environment.yml`
The figures can be plotted from the evaluations for all experiments using the notebooks `plot_analysis.ipynb` and `plot_analysis_rliable.ipynb`
Loading

0 comments on commit 5c03c40

Please sign in to comment.