-
Notifications
You must be signed in to change notification settings - Fork 555
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add semi-supervised meta actor-critic (#137)
* wip to add smac * finish adding smac (v1) * update README * add missing files * fix saving bug
- Loading branch information
Showing
50 changed files
with
6,365 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
Requirements that differ from base requirements: | ||
- python 3.6.5 | ||
- joblib==0.9.4 | ||
- numpy==1.18.5 | ||
|
||
Running these examples requires first generating the data, updating the main launch script to point to that generated data, and then launching the SMAC experiments. | ||
|
||
This can be done by first running | ||
```bash | ||
python examples/smac/generate_{ant|cheetah}_data.py | ||
``` | ||
which runs [PEARL](https://github.com/katerakelly/oyster) to generate multi-task data. | ||
This script will generate a directory and file of the form | ||
``` | ||
LOCAL_LOG_DIR/<experiment_prefix>/<foldername>/extra_snapshot_itrXYZ.cpkl | ||
``` | ||
|
||
You can then update the `examples/smac/{ant|cheetah}.py` file, where it says `TODO: update to point to correct file` to point to this file. | ||
Finally, run the SMAC script | ||
```bash | ||
python examples/smac/{ant|cheetah}.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from rlkit.torch.smac.base_config import DEFAULT_CONFIG | ||
from rlkit.torch.smac.launcher import smac_experiment | ||
import rlkit.util.hyperparameter as hyp | ||
|
||
|
||
# @click.command() | ||
# @click.option('--debug', is_flag=True, default=False) | ||
# @click.option('--dry', is_flag=True, default=False) | ||
# @click.option('--suffix', default=None) | ||
# @click.option('--nseeds', default=1) | ||
# @click.option('--mode', default='here_no_doodad') | ||
# def main(debug, dry, suffix, nseeds, mode): | ||
def main(): | ||
debug = True | ||
dry = False | ||
mode = 'here_no_doodad' | ||
suffix = '' | ||
nseeds = 1 | ||
gpu = True | ||
|
||
path_parts = __file__.split('/') | ||
suffix = '' if suffix is None else '--{}'.format(suffix) | ||
exp_name = 'pearl-awac-{}--{}{}'.format( | ||
path_parts[-2].replace('_', '-'), | ||
path_parts[-1].split('.')[0].replace('_', '-'), | ||
suffix, | ||
) | ||
|
||
if debug or dry: | ||
exp_name = 'dev--' + exp_name | ||
mode = 'here_no_doodad' | ||
nseeds = 1 | ||
|
||
variant = DEFAULT_CONFIG.copy() | ||
variant["env_name"] = "ant-dir" | ||
variant["env_params"]["direction_in_degrees"] = True | ||
search_space = { | ||
'load_buffer_kwargs.pretrain_buffer_path': [ | ||
"results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file | ||
], | ||
'saved_tasks_path': [ | ||
"examples/smac/ant_tasks.joblib", # TODO: update to point to correct file | ||
], | ||
'load_buffer_kwargs.start_idx': [ | ||
-1200, | ||
], | ||
'seed': list(range(nseeds)), | ||
} | ||
from rlkit.launchers.launcher_util import run_experiment | ||
sweeper = hyp.DeterministicHyperparameterSweeper( | ||
search_space, default_parameters=variant, | ||
) | ||
for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): | ||
variant['exp_id'] = exp_id | ||
run_experiment( | ||
smac_experiment, | ||
unpack_variant=True, | ||
exp_prefix=exp_name, | ||
mode=mode, | ||
variant=variant, | ||
use_gpu=gpu, | ||
) | ||
|
||
print(exp_name) | ||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from rlkit.launchers.launcher_util import run_experiment | ||
from rlkit.torch.smac.launcher import smac_experiment | ||
from rlkit.torch.smac.base_config import DEFAULT_CONFIG | ||
import rlkit.util.hyperparameter as hyp | ||
|
||
|
||
# @click.command() | ||
# @click.option('--debug', is_flag=True, default=False) | ||
# @click.option('--dry', is_flag=True, default=False) | ||
# @click.option('--suffix', default=None) | ||
# @click.option('--nseeds', default=1) | ||
# @click.option('--mode', default='here_no_doodad') | ||
# def main(debug, dry, suffix, nseeds, mode): | ||
def main(): | ||
debug = True | ||
dry = False | ||
mode = 'here_no_doodad' | ||
suffix = '' | ||
nseeds = 1 | ||
gpu=True | ||
|
||
path_parts = __file__.split('/') | ||
suffix = '' if suffix is None else '--{}'.format(suffix) | ||
exp_name = 'pearl-awac-{}--{}{}'.format( | ||
path_parts[-2].replace('_', '-'), | ||
path_parts[-1].split('.')[0].replace('_', '-'), | ||
suffix, | ||
) | ||
|
||
if debug or dry: | ||
exp_name = 'dev--' + exp_name | ||
mode = 'here_no_doodad' | ||
nseeds = 1 | ||
|
||
print(exp_name) | ||
|
||
variant = DEFAULT_CONFIG.copy() | ||
variant["env_name"] = "cheetah-vel" | ||
search_space = { | ||
'load_buffer_kwargs.pretrain_buffer_path': [ | ||
"results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file | ||
], | ||
'saved_tasks_path': [ | ||
"examples/smac/cheetah_tasks.joblib", # TODO: update to point to correct file | ||
], | ||
'load_macaw_buffer_kwargs.rl_buffer_start_end_idxs': [ | ||
[(0, 1200)], | ||
], | ||
'load_macaw_buffer_kwargs.encoder_buffer_start_end_idxs': [ | ||
[(-400, None)], | ||
], | ||
'load_macaw_buffer_kwargs.encoder_buffer_matches_rl_buffer': [ | ||
False, | ||
], | ||
'algo_kwargs.use_rl_buffer_for_enc_buffer': [ | ||
False, | ||
], | ||
'algo_kwargs.train_encoder_decoder_in_unsupervised_phase': [ | ||
False, | ||
], | ||
'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [ | ||
False, | ||
], | ||
'algo_kwargs.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase': [ | ||
True, | ||
], | ||
'pretrain_offline_algo_kwargs.logging_period': [ | ||
25000, | ||
], | ||
'algo_kwargs.num_iterations': [ | ||
51, | ||
], | ||
'seed': list(range(nseeds)), | ||
} | ||
sweeper = hyp.DeterministicHyperparameterSweeper( | ||
search_space, default_parameters=variant, | ||
) | ||
for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): | ||
variant['exp_id'] = exp_id | ||
run_experiment( | ||
smac_experiment, | ||
unpack_variant=True, | ||
exp_prefix=exp_name, | ||
mode=mode, | ||
variant=variant, | ||
use_gpu=gpu, | ||
) | ||
|
||
print(exp_name) | ||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import rlkit.util.hyperparameter as hyp | ||
from rlkit.launchers.launcher_util import run_experiment | ||
from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG | ||
from rlkit.torch.smac.pearl_launcher import pearl_experiment | ||
from rlkit.util.io import load_local_or_remote_file | ||
|
||
|
||
# @click.command() | ||
# @click.option('--debug', is_flag=True, default=False) | ||
# @click.option('--dry', is_flag=True, default=False) | ||
# @click.option('--suffix', default=None) | ||
# @click.option('--nseeds', default=1) | ||
# @click.option('--mode', default='here_no_doodad') | ||
# def main(debug, dry, suffix, nseeds, mode): | ||
def main(): | ||
debug = True | ||
dry = False | ||
mode = 'here_no_doodad' | ||
suffix = '' | ||
nseeds = 1 | ||
gpu = True | ||
|
||
path_parts = __file__.split('/') | ||
suffix = '' if suffix is None else '--{}'.format(suffix) | ||
exp_name = 'pearl-awac-{}--{}{}'.format( | ||
path_parts[-2].replace('_', '-'), | ||
path_parts[-1].split('.')[0].replace('_', '-'), | ||
suffix, | ||
) | ||
|
||
if debug or dry: | ||
exp_name = 'dev--' + exp_name | ||
mode = 'here_no_doodad' | ||
nseeds = 1 | ||
|
||
if dry: | ||
mode = 'here_no_doodad' | ||
|
||
print(exp_name) | ||
|
||
task_data = load_local_or_remote_file( | ||
"examples/smac/ant_tasks.joblib", # TODO: update to point to correct file | ||
file_type='joblib') | ||
tasks = task_data['tasks'] | ||
search_space = { | ||
'seed': list(range(nseeds)), | ||
} | ||
variant = DEFAULT_PEARL_CONFIG.copy() | ||
variant["env_name"] = "ant-dir" | ||
# variant["train_task_idxs"] = list(range(100)) | ||
# variant["eval_task_idxs"] = list(range(100, 120)) | ||
variant["env_params"]["fixed_tasks"] = [t['goal'] for t in tasks] | ||
variant["env_params"]["direction_in_degrees"] = True | ||
variant["trainer_kwargs"]["train_context_decoder"] = True | ||
variant["trainer_kwargs"]["backprop_q_loss_into_encoder"] = True | ||
variant["saved_tasks_path"] = "examples/smac/ant_tasks.joblib" # TODO: update to point to correct file | ||
|
||
sweeper = hyp.DeterministicHyperparameterSweeper( | ||
search_space, default_parameters=variant, | ||
) | ||
for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): | ||
variant['exp_id'] = exp_id | ||
run_experiment( | ||
pearl_experiment, | ||
unpack_variant=True, | ||
exp_prefix=exp_name, | ||
mode=mode, | ||
variant=variant, | ||
time_in_mins=3 * 24 * 60 - 1, | ||
use_gpu=gpu, | ||
) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
PEARL Experiment | ||
""" | ||
|
||
import rlkit.util.hyperparameter as hyp | ||
from rlkit.launchers.launcher_util import run_experiment | ||
from rlkit.torch.smac.base_config import DEFAULT_PEARL_CONFIG | ||
|
||
from rlkit.torch.smac.pearl_launcher import pearl_experiment | ||
from rlkit.util.io import load_local_or_remote_file | ||
|
||
|
||
# @click.command() | ||
# @click.option('--debug', is_flag=True, default=False) | ||
# @click.option('--dry', is_flag=True, default=False) | ||
# @click.option('--suffix', default=None) | ||
# @click.option('--nseeds', default=1) | ||
# @click.option('--mode', default='here_no_doodad') | ||
# def main(debug, dry, suffix, nseeds, mode): | ||
def main(): | ||
debug = True | ||
dry = False | ||
mode = 'here_no_doodad' | ||
suffix = '' | ||
nseeds = 1 | ||
gpu = True | ||
|
||
path_parts = __file__.split('/') | ||
suffix = '' if suffix is None else '--{}'.format(suffix) | ||
exp_name = 'pearl-awac-{}--{}{}'.format( | ||
path_parts[-2].replace('_', '-'), | ||
path_parts[-1].split('.')[0].replace('_', '-'), | ||
suffix, | ||
) | ||
|
||
if debug or dry: | ||
exp_name = 'dev--' + exp_name | ||
mode = 'here_no_doodad' | ||
nseeds = 1 | ||
|
||
if dry: | ||
mode = 'here_no_doodad' | ||
|
||
print(exp_name) | ||
|
||
search_space = { | ||
'seed': list(range(nseeds)), | ||
} | ||
variant = DEFAULT_PEARL_CONFIG.copy() | ||
variant["env_name"] = "cheetah-vel" | ||
variant['trainer_kwargs']["train_context_decoder"] = True | ||
variant["saved_tasks_path"] = "examples/smac/cheetah_tasks.joblib" # TODO: update to point to correct file | ||
|
||
sweeper = hyp.DeterministicHyperparameterSweeper( | ||
search_space, default_parameters=variant, | ||
) | ||
for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): | ||
variant['exp_id'] = exp_id | ||
run_experiment( | ||
pearl_experiment, | ||
unpack_variant=True, | ||
exp_prefix=exp_name, | ||
mode=mode, | ||
variant=variant, | ||
time_in_mins=3 * 24 * 60 - 1, | ||
use_gpu=gpu, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.