Add semi-supervised meta actor-critic (#137)
* wip to add smac

* finish adding smac (v1)

* update README

* add missing files

* fix saving bug
vitchyr authored Aug 7, 2021
1 parent 354f14c commit 6a13e1b
Showing 50 changed files with 6,365 additions and 10 deletions.
19 changes: 14 additions & 5 deletions README.md
@@ -2,6 +2,10 @@
Reinforcement learning framework and algorithms implemented in PyTorch.

Implemented algorithms:
- Semi-supervised Meta Actor-Critic (SMAC)
  - [example script](examples/smac/ant.py)
  - [paper](https://arxiv.org/abs/2107.03974)
  - [documentation](docs/SMAC.md)
- Skew-Fit
  - [example script](examples/skewfit/sawyer_door.py)
  - [paper](https://arxiv.org/abs/1903.03698)
@@ -222,8 +226,11 @@ Reinforcement Learning with Imagined Goals (RIG), run
# References
The algorithms are based on the following papers

[Offline Meta-Reinforcement Learning with Online Self-Supervision](https://arxiv.org/abs/2107.03974).
Vitchyr H. Pong, Ashvin Nair, Laura Smith, Catherine Huang, Sergey Levine. arXiv preprint, 2021.

[Skew-Fit: State-Covering Self-Supervised Reinforcement Learning](https://arxiv.org/abs/1903.03698).
Vitchyr H. Pong*, Murtaza Dalal*, Steven Lin*, Ashvin Nair, Shikhar Bahl, Sergey Levine. ICML, 2020.

[Visual Reinforcement Learning with Imagined Goals](https://arxiv.org/abs/1807.04742).
Ashvin Nair*, Vitchyr Pong*, Murtaza Dalal, Shikhar Bahl, Steven Lin, Sergey Levine. NeurIPS 2018.
@@ -250,12 +257,14 @@ Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. ICML, 2018.
Scott Fujimoto, Herke van Hoof, David Meger. ICML, 2018.

# Credits
This repository was developed primarily by [Vitchyr Pong](https://github.com/vitchyr) until July 2021, when it was transferred to the RAIL Berkeley organization; it is now primarily maintained by [Ashvin Nair](https://github.com/anair13).
Other major collaborators and contributions:
- [Murtaza Dalal](https://github.com/mdalal2020)
- [Steven Lin](https://github.com/stevenlin1111)

A lot of the coding infrastructure is based on [rllab](https://github.com/rll/rllab).
The serialization and logger code are basically a carbon copy of the rllab versions.

The Dockerfile is based on the [OpenAI mujoco-py Dockerfile](https://github.com/openai/mujoco-py/blob/master/Dockerfile).

The SMAC code builds on the [PEARL code](https://github.com/katerakelly/oyster), which in turn was built on an older version of rlkit.
22 changes: 22 additions & 0 deletions docs/SMAC.md
@@ -0,0 +1,22 @@
Requirements that differ from base requirements:
- python 3.6.5
- joblib==0.9.4
- numpy==1.18.5
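
For example, one way to set these up in a fresh environment (a sketch assuming conda and pip; adapt to your own tooling):
```bash
# Hypothetical environment name; any Python 3.6.5 environment works.
conda create -n smac python=3.6.5
conda activate smac
pip install joblib==0.9.4 numpy==1.18.5
```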

Running these examples requires first generating the data, updating the main launch script to point to that generated data, and then launching the SMAC experiments.

This can be done by first running
```bash
python examples/smac/generate_{ant|cheetah}_data.py
```
which runs [PEARL](https://github.com/katerakelly/oyster) to generate multi-task data.
This script will generate a directory and file of the form
```
LOCAL_LOG_DIR/<experiment_prefix>/<foldername>/extra_snapshot_itrXYZ.cpkl
```

Then update `examples/smac/{ant|cheetah}.py` where it says `TODO: update to point to correct file`, so that it points to this file, as in the sketch below.
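
For example, the edited entries in `examples/smac/ant.py` might look like the following sketch; both paths are placeholders for files generated on your machine:
```python
search_space = {
    'load_buffer_kwargs.pretrain_buffer_path': [
        # placeholder: the snapshot written by the data-generation run
        "LOCAL_LOG_DIR/<experiment_prefix>/<foldername>/extra_snapshot_itr100.cpkl",
    ],
    'saved_tasks_path': [
        # the task file for this environment
        "examples/smac/ant_tasks.joblib",
    ],
    # ... remaining entries unchanged ...
}
```
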
Finally, run the SMAC script:
```bash
python examples/smac/{ant|cheetah}.py
```
71 changes: 71 additions & 0 deletions examples/smac/ant.py
@@ -0,0 +1,71 @@
from rlkit.launchers.launcher_util import run_experiment
from rlkit.torch.smac.base_config import DEFAULT_CONFIG
from rlkit.torch.smac.launcher import smac_experiment
import rlkit.util.hyperparameter as hyp


# @click.command()
# @click.option('--debug', is_flag=True, default=False)
# @click.option('--dry', is_flag=True, default=False)
# @click.option('--suffix', default=None)
# @click.option('--nseeds', default=1)
# @click.option('--mode', default='here_no_doodad')
# def main(debug, dry, suffix, nseeds, mode):
def main():
    debug = True
    dry = False
    mode = 'here_no_doodad'
    suffix = ''
    nseeds = 1
    gpu = True

    # Derive the experiment name from this script's path,
    # e.g. "pearl-awac-smac--ant".
    path_parts = __file__.split('/')
    suffix = '' if suffix is None else '--{}'.format(suffix)
    exp_name = 'pearl-awac-{}--{}{}'.format(
        path_parts[-2].replace('_', '-'),
        path_parts[-1].split('.')[0].replace('_', '-'),
        suffix,
    )

    if debug or dry:
        exp_name = 'dev--' + exp_name
        mode = 'here_no_doodad'
        nseeds = 1

    variant = DEFAULT_CONFIG.copy()
    variant["env_name"] = "ant-dir"
    variant["env_params"]["direction_in_degrees"] = True
    search_space = {
        'load_buffer_kwargs.pretrain_buffer_path': [
            "results/.../extra_snapshot_itr100.cpkl",  # TODO: update to point to correct file
        ],
        'saved_tasks_path': [
            "examples/smac/ant_tasks.joblib",  # TODO: update to point to correct file
        ],
        'load_buffer_kwargs.start_idx': [
            -1200,
        ],
        'seed': list(range(nseeds)),
    }
    # Launch one run per hyperparameter combination (here, one per seed).
    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            smac_experiment,
            unpack_variant=True,
            exp_prefix=exp_name,
            mode=mode,
            variant=variant,
            use_gpu=gpu,
        )

    print(exp_name)


if __name__ == "__main__":
    main()

Binary file added examples/smac/ant_tasks.joblib
96 changes: 96 additions & 0 deletions examples/smac/cheetah.py
@@ -0,0 +1,96 @@
from rlkit.launchers.launcher_util import run_experiment
from rlkit.torch.smac.launcher import smac_experiment
from rlkit.torch.smac.base_config import DEFAULT_CONFIG
import rlkit.util.hyperparameter as hyp


# @click.command()
# @click.option('--debug', is_flag=True, default=False)
# @click.option('--dry', is_flag=True, default=False)
# @click.option('--suffix', default=None)
# @click.option('--nseeds', default=1)
# @click.option('--mode', default='here_no_doodad')
# def main(debug, dry, suffix, nseeds, mode):
def main():
    debug = True
    dry = False
    mode = 'here_no_doodad'
    suffix = ''
    nseeds = 1
    gpu = True

    # Derive the experiment name from this script's path,
    # e.g. "pearl-awac-smac--cheetah".
    path_parts = __file__.split('/')
    suffix = '' if suffix is None else '--{}'.format(suffix)
    exp_name = 'pearl-awac-{}--{}{}'.format(
        path_parts[-2].replace('_', '-'),
        path_parts[-1].split('.')[0].replace('_', '-'),
        suffix,
    )

    if debug or dry:
        exp_name = 'dev--' + exp_name
        mode = 'here_no_doodad'
        nseeds = 1

    print(exp_name)

    variant = DEFAULT_CONFIG.copy()
    variant["env_name"] = "cheetah-vel"
    search_space = {
        'load_buffer_kwargs.pretrain_buffer_path': [
            "results/.../extra_snapshot_itr100.cpkl",  # TODO: update to point to correct file
        ],
        'saved_tasks_path': [
            "examples/smac/cheetah_tasks.joblib",  # TODO: update to point to correct file
        ],
        'load_macaw_buffer_kwargs.rl_buffer_start_end_idxs': [
            [(0, 1200)],
        ],
        'load_macaw_buffer_kwargs.encoder_buffer_start_end_idxs': [
            [(-400, None)],
        ],
        'load_macaw_buffer_kwargs.encoder_buffer_matches_rl_buffer': [
            False,
        ],
        'algo_kwargs.use_rl_buffer_for_enc_buffer': [
            False,
        ],
        'algo_kwargs.train_encoder_decoder_in_unsupervised_phase': [
            False,
        ],
        'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [
            False,
        ],
        'algo_kwargs.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase': [
            True,
        ],
        'pretrain_offline_algo_kwargs.logging_period': [
            25000,
        ],
        'algo_kwargs.num_iterations': [
            51,
        ],
        'seed': list(range(nseeds)),
    }
    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            smac_experiment,
            unpack_variant=True,
            exp_prefix=exp_name,
            mode=mode,
            variant=variant,
            use_gpu=gpu,
        )

    print(exp_name)


if __name__ == "__main__":
    main()

Binary file added examples/smac/cheetah_tasks.joblib
75 changes: 75 additions & 0 deletions examples/smac/generate_ant_data.py
@@ -0,0 +1,75 @@
import rlkit.util.hyperparameter as hyp
from rlkit.launchers.launcher_util import run_experiment
from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG
from rlkit.torch.smac.pearl_launcher import pearl_experiment
from rlkit.util.io import load_local_or_remote_file


# @click.command()
# @click.option('--debug', is_flag=True, default=False)
# @click.option('--dry', is_flag=True, default=False)
# @click.option('--suffix', default=None)
# @click.option('--nseeds', default=1)
# @click.option('--mode', default='here_no_doodad')
# def main(debug, dry, suffix, nseeds, mode):
def main():
    debug = True
    dry = False
    mode = 'here_no_doodad'
    suffix = ''
    nseeds = 1
    gpu = True

    path_parts = __file__.split('/')
    suffix = '' if suffix is None else '--{}'.format(suffix)
    exp_name = 'pearl-awac-{}--{}{}'.format(
        path_parts[-2].replace('_', '-'),
        path_parts[-1].split('.')[0].replace('_', '-'),
        suffix,
    )

    if debug or dry:
        exp_name = 'dev--' + exp_name
        mode = 'here_no_doodad'
        nseeds = 1

    if dry:
        mode = 'here_no_doodad'

    print(exp_name)

    # Load the saved task set so PEARL trains on the same tasks SMAC will use.
    task_data = load_local_or_remote_file(
        "examples/smac/ant_tasks.joblib",  # TODO: update to point to correct file
        file_type='joblib')
    tasks = task_data['tasks']
    search_space = {
        'seed': list(range(nseeds)),
    }
    variant = DEFAULT_PEARL_CONFIG.copy()
    variant["env_name"] = "ant-dir"
    # variant["train_task_idxs"] = list(range(100))
    # variant["eval_task_idxs"] = list(range(100, 120))
    variant["env_params"]["fixed_tasks"] = [t['goal'] for t in tasks]
    variant["env_params"]["direction_in_degrees"] = True
    variant["trainer_kwargs"]["train_context_decoder"] = True
    variant["trainer_kwargs"]["backprop_q_loss_into_encoder"] = True
    variant["saved_tasks_path"] = "examples/smac/ant_tasks.joblib"  # TODO: update to point to correct file

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space, default_parameters=variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            pearl_experiment,
            unpack_variant=True,
            exp_prefix=exp_name,
            mode=mode,
            variant=variant,
            time_in_mins=3 * 24 * 60 - 1,  # just under three days
            use_gpu=gpu,
        )


if __name__ == "__main__":
    main()

72 changes: 72 additions & 0 deletions examples/smac/generate_cheetah_data.py
@@ -0,0 +1,72 @@
"""
PEARL Experiment
"""

import rlkit.util.hyperparameter as hyp
from rlkit.launchers.launcher_util import run_experiment
from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG

from rlkit.torch.smac.pearl_launcher import pearl_experiment
from rlkit.util.io import load_local_or_remote_file


# @click.command()
# @click.option('--debug', is_flag=True, default=False)
# @click.option('--dry', is_flag=True, default=False)
# @click.option('--suffix', default=None)
# @click.option('--nseeds', default=1)
# @click.option('--mode', default='here_no_doodad')
# def main(debug, dry, suffix, nseeds, mode):
def main():
debug = True
dry = False
mode = 'here_no_doodad'
suffix = ''
nseeds = 1
gpu = True

path_parts = __file__.split('/')
suffix = '' if suffix is None else '--{}'.format(suffix)
exp_name = 'pearl-awac-{}--{}{}'.format(
path_parts[-2].replace('_', '-'),
path_parts[-1].split('.')[0].replace('_', '-'),
suffix,
)

if debug or dry:
exp_name = 'dev--' + exp_name
mode = 'here_no_doodad'
nseeds = 1

if dry:
mode = 'here_no_doodad'

print(exp_name)

search_space = {
'seed': list(range(nseeds)),
}
variant = DEFAULT_PEARL_CONFIG.copy()
variant["env_name"] = "cheetah-vel"
variant['trainer_kwargs']["train_context_decoder"] = True
variant["saved_tasks_path"] = "examples/smac/cheetah_tasks.joblib" # TODO: update to point to correct file

sweeper = hyp.DeterministicHyperparameterSweeper(
search_space, default_parameters=variant,
)
for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
variant['exp_id'] = exp_id
run_experiment(
pearl_experiment,
unpack_variant=True,
exp_prefix=exp_name,
mode=mode,
variant=variant,
time_in_mins=3 * 24 * 60 - 1,
use_gpu=gpu,
)


if __name__ == "__main__":
main()

4 changes: 4 additions & 0 deletions rlkit/core/logging.py
@@ -217,6 +217,10 @@ def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'):
            joblib.dump(data, file_name, compress=3)
        elif mode == 'pickle':
            pickle.dump(data, open(file_name, "wb"))
        elif mode == 'cloudpickle':
            import cloudpickle
            full_filename = file_name + ".cpkl"
            cloudpickle.dump(data, open(full_filename, "wb"))
        else:
            raise ValueError("Invalid mode: {}".format(mode))
        return file_name
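
For reference, a minimal sketch of how this new mode might be invoked (the import path and payload below are assumptions for illustration, not taken from this diff):

```python
from rlkit.core import logger  # assumed location of rlkit's module-level logger

# Hypothetical payload; in SMAC this holds replay-buffer snapshots and the like.
snapshot = {'itr': 100, 'note': 'example payload'}

# mode='cloudpickle' appends ".cpkl", producing e.g. extra_snapshot_itr100.cpkl
# in the experiment's log directory.
logger.save_extra_data(snapshot, file_name='extra_snapshot_itr100', mode='cloudpickle')
```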