From 6a13e1b63d3febb1057f46e6bb6b86948b9cab1e Mon Sep 17 00:00:00 2001 From: Vitchyr Pong Date: Fri, 6 Aug 2021 20:49:42 -0400 Subject: [PATCH] Add semi-supervised meta actor-critic (#137) * wip to add smac * finish adding smac (v1) * update README * add missing files * fix saving bug --- README.md | 19 +- docs/SMAC.md | 22 + examples/smac/ant.py | 71 ++ examples/smac/ant_tasks.joblib | Bin 0 -> 2977 bytes examples/smac/cheetah.py | 96 ++ examples/smac/cheetah_tasks.joblib | Bin 0 -> 3246 bytes examples/smac/generate_ant_data.py | 75 ++ examples/smac/generate_cheetah_data.py | 72 ++ rlkit/core/logging.py | 4 + rlkit/core/meta_rl_algorithm.py | 1078 +++++++++++++++++ rlkit/core/simple_offline_rl_algorithm.py | 161 +++ rlkit/core/timer.py | 51 + .../meta_learning_replay_buffer.py | 183 +++ .../multitask_replay_buffer.py | 300 +++++ rlkit/data_management/simple_replay_buffer.py | 2 +- rlkit/envs/pearl_envs/__init__.py | 57 + rlkit/envs/pearl_envs/ant.py | 70 ++ rlkit/envs/pearl_envs/ant_dir.py | 81 ++ rlkit/envs/pearl_envs/ant_goal.py | 45 + rlkit/envs/pearl_envs/ant_multitask_base.py | 43 + rlkit/envs/pearl_envs/ant_normal.py | 24 + rlkit/envs/pearl_envs/assets/ant.xml | 86 ++ .../pearl_envs/assets/low_gear_ratio_ant.xml | 84 ++ rlkit/envs/pearl_envs/half_cheetah.py | 26 + rlkit/envs/pearl_envs/half_cheetah_dir.py | 60 + rlkit/envs/pearl_envs/half_cheetah_vel.py | 65 + .../pearl_envs/hopper_rand_params_wrapper.py | 17 + rlkit/envs/pearl_envs/humanoid_dir.py | 59 + rlkit/envs/pearl_envs/mujoco_env.py | 62 + rlkit/envs/pearl_envs/point_robot.py | 170 +++ .../pearl_envs/rand_param_envs/__init__.py | 0 rlkit/envs/pearl_envs/rand_param_envs/base.py | 139 +++ .../rand_param_envs/hopper_rand_params.py | 54 + .../rand_param_envs/pr2_env_reach.py | 82 ++ .../rand_param_envs/walker2d_rand_params.py | 58 + .../pearl_envs/walker_rand_params_wrapper.py | 17 + rlkit/envs/pearl_envs/wrappers.py | 156 +++ rlkit/launchers/launcher_util.py | 22 +- rlkit/torch/smac/agent.py | 289 +++++ rlkit/torch/smac/base_config.py | 164 +++ rlkit/torch/smac/diagnostics.py | 27 + rlkit/torch/smac/launcher.py | 264 ++++ rlkit/torch/smac/launcher_util.py | 390 ++++++ rlkit/torch/smac/networks.py | 83 ++ rlkit/torch/smac/pearl.py | 347 ++++++ rlkit/torch/smac/pearl_launcher.py | 173 +++ rlkit/torch/smac/sampler.py | 308 +++++ rlkit/torch/smac/smac.py | 668 ++++++++++ rlkit/util/io.py | 8 +- rlkit/util/wrapper.py | 43 + 50 files changed, 6365 insertions(+), 10 deletions(-) create mode 100644 docs/SMAC.md create mode 100644 examples/smac/ant.py create mode 100644 examples/smac/ant_tasks.joblib create mode 100644 examples/smac/cheetah.py create mode 100755 examples/smac/cheetah_tasks.joblib create mode 100644 examples/smac/generate_ant_data.py create mode 100644 examples/smac/generate_cheetah_data.py create mode 100644 rlkit/core/meta_rl_algorithm.py create mode 100644 rlkit/core/simple_offline_rl_algorithm.py create mode 100644 rlkit/core/timer.py create mode 100644 rlkit/data_management/meta_learning_replay_buffer.py create mode 100644 rlkit/data_management/multitask_replay_buffer.py create mode 100644 rlkit/envs/pearl_envs/__init__.py create mode 100644 rlkit/envs/pearl_envs/ant.py create mode 100644 rlkit/envs/pearl_envs/ant_dir.py create mode 100644 rlkit/envs/pearl_envs/ant_goal.py create mode 100644 rlkit/envs/pearl_envs/ant_multitask_base.py create mode 100644 rlkit/envs/pearl_envs/ant_normal.py create mode 100644 rlkit/envs/pearl_envs/assets/ant.xml create mode 100644 rlkit/envs/pearl_envs/assets/low_gear_ratio_ant.xml 
create mode 100644 rlkit/envs/pearl_envs/half_cheetah.py create mode 100644 rlkit/envs/pearl_envs/half_cheetah_dir.py create mode 100644 rlkit/envs/pearl_envs/half_cheetah_vel.py create mode 100644 rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py create mode 100644 rlkit/envs/pearl_envs/humanoid_dir.py create mode 100644 rlkit/envs/pearl_envs/mujoco_env.py create mode 100644 rlkit/envs/pearl_envs/point_robot.py create mode 100644 rlkit/envs/pearl_envs/rand_param_envs/__init__.py create mode 100644 rlkit/envs/pearl_envs/rand_param_envs/base.py create mode 100644 rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py create mode 100644 rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py create mode 100644 rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py create mode 100644 rlkit/envs/pearl_envs/walker_rand_params_wrapper.py create mode 100644 rlkit/envs/pearl_envs/wrappers.py create mode 100644 rlkit/torch/smac/agent.py create mode 100644 rlkit/torch/smac/base_config.py create mode 100644 rlkit/torch/smac/diagnostics.py create mode 100644 rlkit/torch/smac/launcher.py create mode 100644 rlkit/torch/smac/launcher_util.py create mode 100644 rlkit/torch/smac/networks.py create mode 100644 rlkit/torch/smac/pearl.py create mode 100644 rlkit/torch/smac/pearl_launcher.py create mode 100644 rlkit/torch/smac/sampler.py create mode 100644 rlkit/torch/smac/smac.py create mode 100644 rlkit/util/wrapper.py diff --git a/README.md b/README.md index d77572a6b..4fd665194 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ Reinforcement learning framework and algorithms implemented in PyTorch. Implemented algorithms: + - Semi-supervised Meta Actor Critic + - [example script](examples/smac/ant.py) + - [paper](https://arxiv.org/abs/2107.03974) + - [Documentation](docs/SMAC.md) - Skew-Fit - [example script](examples/skewfit/sawyer_door.py) - [paper](https://arxiv.org/abs/1903.03698) @@ -222,8 +226,11 @@ Reinforcement Learning with Imagined Goals (RIG), run # References The algorithms are based on the following papers +[Offline Meta-Reinforcement Learning with Online Self-Supervision](https://arxiv.org/abs/2107.03974) +Vitchyr H. Pong, Ashvin Nair, Laura Smith, Catherine Huang, Sergey Levine. arXiv preprint, 2021. + [Skew-Fit: State-Covering Self-Supervised Reinforcement Learning](https://arxiv.org/abs/1903.03698). -Vitchyr H. Pong*, Murtaza Dalal*, Steven Lin*, Ashvin Nair, Shikhar Bahl, Sergey Levine. arXiv preprint, 2019. +Vitchyr H. Pong*, Murtaza Dalal*, Steven Lin*, Ashvin Nair, Shikhar Bahl, Sergey Levine. ICML, 2020. [Visual Reinforcement Learning with Imagined Goals](https://arxiv.org/abs/1807.04742). Ashvin Nair*, Vitchyr Pong*, Murtaza Dalal, Shikhar Bahl, Steven Lin, Sergey Levine. NeurIPS 2018. @@ -250,12 +257,14 @@ Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. ICML, 2018. Scott Fujimoto, Herke van Hoof, David Meger. ICML, 2018. # Credits +This repository was initially developed primarily by [Vitchyr Pong](https://github.com/vitchyr), until July 2021, at which point it was transferred to the RAIL Berkeley organization and is primarily maintained by [Ashvin Nair](https://github.com/anair13). +Other major collaborators and contributions: + - [Murtaza Dalal](https://github.com/mdalal2020) + - [Steven Lin](https://github.com/stevenlin1111) + A lot of the coding infrastructure is based on [rllab](https://github.com/rll/rllab). The serialization and logger code are basically a carbon copy of the rllab versions. 
The Dockerfile is based on the [OpenAI mujoco-py Dockerfile](https://github.com/openai/mujoco-py/blob/master/Dockerfile). -Other major collaborators and contributions: - - [Murtaza Dalal](https://github.com/mdalal2020) - - [Steven Lin](https://github.com/stevenlin1111) - - [Ashvin Nair](https://github.com/anair13) +The SMAC code builds off of the [PEARL code](https://github.com/katerakelly/oyster), which built off of an older RLKit version. diff --git a/docs/SMAC.md b/docs/SMAC.md new file mode 100644 index 000000000..3bda6e044 --- /dev/null +++ b/docs/SMAC.md @@ -0,0 +1,22 @@ +Requirements that differ from base requirements: + - python 3.6.5 + - joblib==0.9.4 + - numpy==1.18.5 + +Running these examples requires first generating the data, updating the main launch script to point to that generated data, and then launching the SMAC experiments. + +This can be done by first running +```bash +python examples/smac/generate_{ant|cheetah}_data.py +``` +which runs [PEARL](https://github.com/katerakelly/oyster) to generate multi-task data. +This script will generate a directory and file of the form +``` +LOCAL_LOG_DIR///extra_snapshot_itrXYZ.cpkl +``` + +You can then update the `examples/smac/{ant|cheetah}.py` file, where it says `TODO: update to point to correct file` to point to this file. +Finally, run the SMAC script +```bash +python examples/smac/{ant|cheetah}.py +``` diff --git a/examples/smac/ant.py b/examples/smac/ant.py new file mode 100644 index 000000000..66226aea2 --- /dev/null +++ b/examples/smac/ant.py @@ -0,0 +1,71 @@ +from rlkit.torch.smac.base_config import DEFAULT_CONFIG +from rlkit.torch.smac.launcher import smac_experiment +import rlkit.util.hyperparameter as hyp + + +# @click.command() +# @click.option('--debug', is_flag=True, default=False) +# @click.option('--dry', is_flag=True, default=False) +# @click.option('--suffix', default=None) +# @click.option('--nseeds', default=1) +# @click.option('--mode', default='here_no_doodad') +# def main(debug, dry, suffix, nseeds, mode): +def main(): + debug = True + dry = False + mode = 'here_no_doodad' + suffix = '' + nseeds = 1 + gpu = True + + path_parts = __file__.split('/') + suffix = '' if suffix is None else '--{}'.format(suffix) + exp_name = 'pearl-awac-{}--{}{}'.format( + path_parts[-2].replace('_', '-'), + path_parts[-1].split('.')[0].replace('_', '-'), + suffix, + ) + + if debug or dry: + exp_name = 'dev--' + exp_name + mode = 'here_no_doodad' + nseeds = 1 + + variant = DEFAULT_CONFIG.copy() + variant["env_name"] = "ant-dir" + variant["env_params"]["direction_in_degrees"] = True + search_space = { + 'load_buffer_kwargs.pretrain_buffer_path': [ + "results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file + ], + 'saved_tasks_path': [ + "examples/smac/ant_tasks.joblib", # TODO: update to point to correct file + ], + 'load_buffer_kwargs.start_idx': [ + -1200, + ], + 'seed': list(range(nseeds)), + } + from rlkit.launchers.launcher_util import run_experiment + sweeper = hyp.DeterministicHyperparameterSweeper( + search_space, default_parameters=variant, + ) + for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): + variant['exp_id'] = exp_id + run_experiment( + smac_experiment, + unpack_variant=True, + exp_prefix=exp_name, + mode=mode, + variant=variant, + use_gpu=gpu, + ) + + print(exp_name) + + + + +if __name__ == "__main__": + main() + diff --git a/examples/smac/ant_tasks.joblib b/examples/smac/ant_tasks.joblib new file mode 100644 index 
0000000000000000000000000000000000000000..2954be8dfd89388f4c4dc9ef2f5341695dc2f61b GIT binary patch literal 2977
[base85-encoded binary data for examples/smac/ant_tasks.joblib omitted]
za(zSP`fkbft&r>c9M?B7uJ2G>-+s8hw{U&q;QFq>^(CL{>v?pji}emUw3qV9aP5_R z^y2zL&h=HA>q{`#*H^ADo?KrUxxVaieNE%~g2naKi0exY*VhrQFA`i|0k}TBxjuKf XJ~6pI>$pCpxISN^Q@#HWnrs0=k%#lB literal 0 HcmV?d00001 diff --git a/examples/smac/cheetah.py b/examples/smac/cheetah.py new file mode 100644 index 000000000..4ac5cd99b --- /dev/null +++ b/examples/smac/cheetah.py @@ -0,0 +1,96 @@ +from rlkit.launchers.launcher_util import run_experiment +from rlkit.torch.smac.launcher import smac_experiment +from rlkit.torch.smac.base_config import DEFAULT_CONFIG +import rlkit.util.hyperparameter as hyp + + +# @click.command() +# @click.option('--debug', is_flag=True, default=False) +# @click.option('--dry', is_flag=True, default=False) +# @click.option('--suffix', default=None) +# @click.option('--nseeds', default=1) +# @click.option('--mode', default='here_no_doodad') +# def main(debug, dry, suffix, nseeds, mode): +def main(): + debug = True + dry = False + mode = 'here_no_doodad' + suffix = '' + nseeds = 1 + gpu=True + + path_parts = __file__.split('/') + suffix = '' if suffix is None else '--{}'.format(suffix) + exp_name = 'pearl-awac-{}--{}{}'.format( + path_parts[-2].replace('_', '-'), + path_parts[-1].split('.')[0].replace('_', '-'), + suffix, + ) + + if debug or dry: + exp_name = 'dev--' + exp_name + mode = 'here_no_doodad' + nseeds = 1 + + print(exp_name) + + variant = DEFAULT_CONFIG.copy() + variant["env_name"] = "cheetah-vel" + search_space = { + 'load_buffer_kwargs.pretrain_buffer_path': [ + "results/.../extra_snapshot_itr100.cpkl" # TODO: update to point to correct file + ], + 'saved_tasks_path': [ + "examples/smac/cheetah_tasks.joblib", # TODO: update to point to correct file + ], + 'load_macaw_buffer_kwargs.rl_buffer_start_end_idxs': [ + [(0, 1200)], + ], + 'load_macaw_buffer_kwargs.encoder_buffer_start_end_idxs': [ + [(-400, None)], + ], + 'load_macaw_buffer_kwargs.encoder_buffer_matches_rl_buffer': [ + False, + ], + 'algo_kwargs.use_rl_buffer_for_enc_buffer': [ + False, + ], + 'algo_kwargs.train_encoder_decoder_in_unsupervised_phase': [ + False, + ], + 'algo_kwargs.freeze_encoder_buffer_in_unsupervised_phase': [ + False, + ], + 'algo_kwargs.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase': [ + True, + ], + 'pretrain_offline_algo_kwargs.logging_period': [ + 25000, + ], + 'algo_kwargs.num_iterations': [ + 51, + ], + 'seed': list(range(nseeds)), + } + sweeper = hyp.DeterministicHyperparameterSweeper( + search_space, default_parameters=variant, + ) + for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()): + variant['exp_id'] = exp_id + run_experiment( + smac_experiment, + unpack_variant=True, + exp_prefix=exp_name, + mode=mode, + variant=variant, + use_gpu=gpu, + ) + + print(exp_name) + + + + +if __name__ == "__main__": + main() + diff --git a/examples/smac/cheetah_tasks.joblib b/examples/smac/cheetah_tasks.joblib new file mode 100755 index 0000000000000000000000000000000000000000..8c5d69b7c63940b5aab7f749dd91e95f1d49baec GIT binary patch literal 3246 zcmV;f3{mr1Mlg6WH8eLMARr(hARr(hARr(hcwSZ5d3X)i8^H0%LKcD$VryuSR0x`& zTFfQ(*LqY%iHIbeoFtNW5;O>E8H$#lC25sbZM9Xkib_>0EsbjJMeR}A*cxjrvHZ?? z&ogT75?17*DP$$@O9bVu-nP? 
z2Z~O+pEj_Ag&i4olFhJN#Hop`Z!VL^>9Vjh!!EK#+wDlj7DX}L9~jux!fp(^%htng z->x2U#qXZG3wv7Fi(zlsz;20kJ!e&`>fypZ7WQS>PqzMcd#Os{nusBH3>;wLK!!}V zL3ZoSdL0RTe9yo)Ega17E!l?H%_k=~G}iB_fiV^iW%#yi!|XOMxc<>8o#e3$w=kCB z2-)K7wkdS`z$KL~8#vO!c!r~7OR(FJE5rYNzTt*}i54a?94*@zyT$moYj8Z!&4t{; zWQIbv6ub3XzQ+2Z?4g0F7N#*AD_gqVntikN@wCQ824+}@3{BZG?bhkH8(~{c$oH3J z;W&olWt(8PeR=!*Q-l99@Er@Y8BUaKQu(>{UbJo2tj7jUwlIg`6xpWQ?e5{T*8-o& z&w84L?=qY&+YGxMjm!zpFOWFA*7AN>XyN}DE|P7r-S(t~{kUA_R*T8Qq+|F=^Y&-2XI=5D2tzGgt zzqN1|!+hDkvs+=!iA$2_>^E??h2JyWBij$<_tE^K{#)+IYxhSB_cAPy?I*iE88Glr z>4Dn@{%qkshWlmv#coMHOS<+uCy%qx!UGI{mF=M2#QN&h+AopU=WiDN&hU_Ihs)1< zcjNq^q4M)QV&PGS$7DNhx02!&g)OUjy6}XBe=s~L+bO%f+$3+qhrOQ}c-q208J>~t zZ29$4yV5I%gUSp%XW@B<7i7C=w_nCZO?*%yU+ zwuN^X-j(fM`Rk2c*QsQ=uw+<&TtD zc^;lx_>AFm*)O5HaKAS5?s94V1H}4jiuVv*x)aaq_zFL#{8me%keSxD|&3;s<-n*BwA^AUUSpAcyPoeL&Rgi?%?Dy`|j{hvS#)^}KdMUN35bw9KL8-qvoI!&TnBFR-khJig)NV#$rrF3#aTxt}q0 z>m2!fBgw^+8>L->!>zv3uHLNh`-YiFE{WV|?Z!Bq=in!kPG-pS%gH5^6WXOXT&qUc z=jVSVUpJLp8o9CBr90f+q+RbygzKgpCTv{p5$9!_{ zllwrs1rFES{cg((*Y6tULUR8jw@ACi4rgBc`Sgb~-h+h}<&mK6bdqpP!E4i=8eXZR#hpRuN&6e5=p=#BlkPGL)sm7xS+q&ujbGE z)-aEdJ4)`DcE=sARes52SAksrPLTV9+)3?DIoxs2$Vof9%Jum)xj)IB(eA9nC0^~~ z_RXvU!#qdsJh=s9j@(F|4a9SaBhI&&#=_8oBD^0+dUN z2~dCPG2@fR#;HH}#3heSOioD3G;7EjYu1#HVd`%-;0_+(30_bUDuFlnfG_w#WvBxF zP!*~{bqIhOP!j^77Sx6yr~`GO9@K{h@Dc<=LudpqLkKj6CeRc@Aq-xDW)Kd|p#?-h zOK1hH;Z=x)C};z1p&h&i?csHJ13ExQ=mZA1pfhxVXy^*vpgZ({p3n<=13(|>3;m!! z41j^aU=X|sgW)Y00x>WY-iBc?9AaSv#KB02hf$D_6jNQk*QD|BssDYef%H~u(OXS~ zBp3~2fI~6}NP$#HgRzhf8Gv9yCS<`l7!MQR9ms}>FbO6@4orcmFb&>Kit+o;N118X zrY8M{wZJ0Kj)kHmwX%F$-m=2@SpfE{2)KfkMmRfEWgPA;{Wg) zyoBH54|y4XA=Fn8p}t=T_2ogRZw5k*{X)(3LJjOfP2)n1+Ct6ILJi46O~68pw?fUX zLJg)uO`SrGltRsmLJfmLO?pC&aYD^#LJeR-O;C+5?5!;Dp+} zgxZ>f+INK7RD{|wgxU^-TI+>c$c0*^g<67zT3>})JcU{rg 0: + # always add the last epoch in case user had an OBO error + self.save_extra_manual_epoch_set.add(num_iterations - 1) + + self.save_extra_manual_beginning_epoch_list = save_extra_manual_beginning_epoch_list + self.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase = ( + use_encoder_snapshot_for_reward_pred_in_unsupervised_phase + ) + self.env = env + self.agent = agent + self.trainer = trainer + self.exploration_agent = agent # Can potentially use a different policy purely for exploration rather than also solving tasks, currently not being used + self.train_task_indices = train_task_indices + self.exploration_task_indices = train_task_indices + self.offline_train_task_indices = train_task_indices + self.eval_task_indices = eval_task_indices + self.train_tasks = train_tasks + self.eval_tasks = eval_tasks + self.meta_batch = meta_batch + self.num_iterations = num_iterations + self.num_train_steps_per_itr = num_train_steps_per_itr + self.num_initial_steps = num_initial_steps + self.num_tasks_sample = num_tasks_sample + self.num_steps_prior = num_steps_prior + self.num_steps_posterior = num_steps_posterior + self.num_extra_rl_steps_posterior = num_extra_rl_steps_posterior + self.num_evals = num_evals + self.num_steps_per_eval = num_steps_per_eval + self.batch_size = batch_size + self.embedding_batch_size = embedding_batch_size + self.embedding_mini_batch_size = embedding_mini_batch_size + self.max_path_length = max_path_length + self.discount = discount + self.replay_buffer_size = replay_buffer_size + self.reward_scale = reward_scale + self.update_post_train = update_post_train + self.post_train_funcs = [] + self.num_exp_traj_eval = num_exp_traj_eval + self.eval_deterministic = eval_deterministic + self.render = render + self.sparse_rewards = 
sparse_rewards + self.use_next_obs_in_context = use_next_obs_in_context + self.save_replay_buffer = save_replay_buffer + self.save_algorithm = save_algorithm + self.save_environment = save_environment + if num_iterations_with_reward_supervision is None: + num_iterations_with_reward_supervision = np.inf + self.num_iterations_with_reward_supervision = num_iterations_with_reward_supervision + self.freeze_encoder_buffer_in_unsupervised_phase = ( + freeze_encoder_buffer_in_unsupervised_phase + ) + self.clear_encoder_buffer_before_every_update = ( + clear_encoder_buffer_before_every_update + ) + self.expl_data_collector = exploration_data_collector + self.eval_data_collector = evaluation_data_collector + + self.eval_statistics = None + self.render_eval_paths = render_eval_paths + self.dump_eval_paths = dump_eval_paths + self.plotter = plotter + + self.exploration_resample_latent_period = exploration_resample_latent_period + self.exploration_update_posterior_period = exploration_update_posterior_period + self.sampler = SMACInPlacePathSampler( + env=env, + policy=agent, + max_path_length=self.max_path_length, + ) + + self.meta_replay_buffer = None + self.replay_buffer = None + self.enc_replay_buffer = None + self.meta_replay_buffer = MetaLearningReplayBuffer( + self.replay_buffer_size, + env, + self.train_task_indices, + use_next_obs_in_context=use_next_obs_in_context, + sparse_rewards=sparse_rewards, + mini_buffer_max_size=self.max_path_length + max( + self.num_steps_prior, + self.num_steps_posterior, + self.num_extra_rl_steps_posterior, + ), + sample_buffer_in_proportion_to_size=sample_buffer_in_proportion_to_size, + ) + self.replay_buffer = MultiTaskReplayBuffer( + self.replay_buffer_size, + env, + self.train_task_indices, + use_next_obs_in_context=use_next_obs_in_context, + sparse_rewards=sparse_rewards, + env_info_sizes=env_info_sizes, + ) + if self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer = self.replay_buffer + else: + self.enc_replay_buffer = MultiTaskReplayBuffer( + self.replay_buffer_size, + env, + self.train_task_indices, + use_next_obs_in_context=use_next_obs_in_context, + sparse_rewards=sparse_rewards, + use_ground_truth_context=use_ground_truth_context, + ground_truth_tasks=train_tasks, + env_info_sizes=env_info_sizes, + ) + + self._n_env_steps_total = 0 + self._n_train_steps_total = 0 + self._n_rollouts_total = 0 + self._do_train_time = 0 + self._epoch_start_time = None + self._algo_start_time = None + self._old_table_keys = None + self._current_path_builder = PathBuilder() + self._exploration_paths = [] + self.in_unsupervised_phase = False + self._debug_use_ground_truth_context = use_ground_truth_context + + self._reward_decoder_buffer = self.enc_replay_buffer + self.fake_task_idx_to_z = {} + + def train(self): + ''' + meta-training loop + ''' + start_time = time.time() + print("starting to pretrain") + self.pretrain() + print("done pretraining after time:", time.time() - start_time) + params = self.get_epoch_snapshot(-1) + logger.save_itr_params(-1, params) + gt.reset() + gt.set_def_unique(False) + self._current_path_builder = PathBuilder() + + # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate + for it_ in gt.timed_for( + range(self.num_iterations), + save_itrs=True, + ): + self._start_epoch(it_) + self.training_mode(True) + if it_ == 0 and self.num_initial_steps > 0: + print('collecting initial pool of data for train and eval') + # temp for evaluating + for task_idx in self.train_task_indices: + if 
self.expl_data_collector: + init_expl_paths = self.expl_data_collector.collect_new_paths( + max_path_length=self.max_path_length, + num_steps=self.num_initial_steps, + discard_incomplete_paths=False, + task_idx=task_idx, + ) + self.replay_buffer.add_paths(task_idx, init_expl_paths) + if not self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer.add_paths(task_idx, init_expl_paths) + self.expl_data_collector.end_epoch(-1) + else: + self.collect_exploration_data( + self.num_initial_steps, 1, np.inf, task_idx) + self.in_unsupervised_phase = (it_ >= self.num_iterations_with_reward_supervision) + if it_ == self.num_iterations_with_reward_supervision: + self._transition_to_unsupervised() + update_encoder_buffer = not ( + self.in_unsupervised_phase + and self.freeze_encoder_buffer_in_unsupervised_phase + ) and not self.use_rl_buffer_for_enc_buffer + clear_encoder_buffer = ( + update_encoder_buffer + and self.clear_encoder_buffer_before_every_update + ) and not self.use_rl_buffer_for_enc_buffer + # TODO: propogate unsupervised mode elegantly + # Sample data from train tasks. + for i in range(self.num_tasks_sample): + if len(self.exploration_task_indices) == 0: + # do no data collection + break + task_idx = np.random.choice(self.exploration_task_indices) + if clear_encoder_buffer: + self.enc_replay_buffer.task_buffers[task_idx].clear() + # collect some trajectories with z ~ prior + if self.num_steps_prior > 0: + if self.expl_data_collector: + # TODO: implement + new_expl_paths = self.expl_data_collector.collect_new_paths( + task_idx=task_idx, + max_path_length=self.max_path_length, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=np.inf, + num_steps=self.num_steps_prior, + use_predicted_reward=self.in_unsupervised_phase, + discard_incomplete_paths=False, + ) + self.replay_buffer.add_paths(task_idx, new_expl_paths) + self._n_env_steps_total += sum( + len(p['actions']) for p in new_expl_paths + ) + self._n_rollouts_total += len(new_expl_paths) + if update_encoder_buffer: + self.enc_replay_buffer.add_paths(task_idx, new_expl_paths) + else: + self.collect_exploration_data( + num_samples=self.num_steps_prior, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=np.inf, + add_to_enc_buffer=update_encoder_buffer, + use_predicted_reward=self.in_unsupervised_phase, + task_idx=task_idx, + # TODO: figure out if I want to replace this? 
+ # it's only used when `clear_encoder_buffer_before_every_update` is True + # and when `freeze_encoder_buffer_in_unsupervised_phase` is False + # and when we're in unsupervised phase + ) + # collect some trajectories with z ~ posterior + if self.num_steps_posterior > 0: + if self.expl_data_collector: + # TODO: implement + new_expl_paths = self.expl_data_collector.collect_new_paths( + task_idx=task_idx, + max_path_length=self.max_path_length, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=self.update_post_train, + num_steps=self.num_steps_posterior, + use_predicted_reward=self.in_unsupervised_phase, + discard_incomplete_paths=False, + ) + self.replay_buffer.add_paths(task_idx, new_expl_paths) + self._n_env_steps_total += sum( + len(p['actions']) for p in new_expl_paths + ) + self._n_rollouts_total += len(new_expl_paths) + if update_encoder_buffer and not self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer.add_paths(task_idx, new_expl_paths) + else: + self.collect_exploration_data( + num_samples=self.num_steps_posterior, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=self.update_post_train, + add_to_enc_buffer=update_encoder_buffer, + use_predicted_reward=self.in_unsupervised_phase, + task_idx=task_idx, + ) + # even if encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior + if self.num_extra_rl_steps_posterior > 0: + # TODO: implement + if self.expl_data_collector: + new_expl_paths = self.expl_data_collector.collect_new_paths( + task_idx=task_idx, + max_path_length=self.max_path_length, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=self.update_post_train, + num_steps=self.num_extra_rl_steps_posterior, + use_predicted_reward=self.in_unsupervised_phase, + discard_incomplete_paths=False, + ) + self.replay_buffer.add_paths(task_idx, new_expl_paths) + self._n_env_steps_total += sum( + len(p['actions']) for p in new_expl_paths + ) + self._n_rollouts_total += len(new_expl_paths) + if not self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer.add_paths(task_idx, new_expl_paths) + else: + add_to_enc_buffer = ( + self.debug_enc_buffer_matches_rl_buffer + and not self.use_rl_buffer_for_enc_buffer + ) + self.collect_exploration_data( + num_samples=self.num_extra_rl_steps_posterior, + resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=self.update_post_train, + add_to_enc_buffer=add_to_enc_buffer, + use_predicted_reward=self.in_unsupervised_phase, + task_idx=task_idx, + ) + gt.stamp('sample') + + # Sample train tasks and compute gradient updates on parameters. 
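+ # Each train step below either samples a full meta-batch from the
+ # meta-learning buffer, or samples `meta_batch` task indices, draws a
+ # context batch from the encoder buffer and an RL batch from the replay
+ # buffer for those tasks, and hands both to the trainer.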
+ for train_step in range(self.num_train_steps_per_itr): + if self.use_meta_learning_buffer: + batch = self.meta_replay_buffer.sample_meta_batch( + rl_batch_size=self.batch_size, + meta_batch_size=self.meta_batch, + embedding_batch_size=self.embedding_batch_size, + ) + self.trainer.train(batch) + else: + indices = np.random.choice(self.train_task_indices, self.meta_batch) + + mb_size = self.embedding_mini_batch_size + num_updates = self.embedding_batch_size // mb_size + + # sample context batch + # context_batch = self.sample_context(indices) + context_batch = self.enc_replay_buffer.sample_context( + indices, + self.embedding_batch_size + ) + + # zero out context and hidden encoder state + # self.agent.clear_z(num_tasks=len(indices)) + + # do this in a loop so we can truncate backprop in the recurrent encoder + for i in range(num_updates): + if self._debug_use_ground_truth_context: + context = context_batch + else: + context = context_batch[:, i * mb_size: i * mb_size + mb_size, :] + # batch = self.sample_batch(indices) + batch = self.replay_buffer.sample_batch(indices, self.batch_size) + batch['context'] = context + batch['task_indices'] = indices + self.trainer.train(batch) + self._n_train_steps_total += 1 + + # stop backprop + # self.agent.detach_z() + # train_data = self.replay_buffer.random_batch(self.batch_size) + gt.stamp('train') + + self.training_mode(False) + + # eval + self._try_to_eval(it_) + + self._end_epoch(it_) + + def _transition_to_unsupervised(self): + self._reward_decoder_buffer = copy.deepcopy(self.enc_replay_buffer) + self.trainer.train_encoder_decoder = self.train_encoder_decoder_in_unsupervised_phase + self.trainer.train_agent = self.train_agent_in_unsupervised_phase + self.agent.use_context_encoder_snapshot_for_reward_pred = ( + self.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase + ) + + def pretrain(self): + """ + Do anything before the main training phase. + """ + # HACK: I'm assuming the train and eval task indices are consecutive. 
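+ # Self-generated ("fake") tasks get indices appended after the offline
+ # train/eval task indices, and each is assigned a latent z sampled from
+ # the prior; initial exploration data for them is collected below.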
+ num_existing_tasks = len(self.offline_train_task_indices) + len(self.eval_task_indices) + fake_task_idxs = list(range( + num_existing_tasks, + num_existing_tasks + self._num_tasks_to_generate, + )) + if self.add_exploration_data_to == 'self_generated_tasks': + self.exploration_task_indices = fake_task_idxs + elif self.add_exploration_data_to == 'train_tasks': + self.exploration_task_indices = self.offline_train_task_indices + elif self.add_exploration_data_to == 'train_and_self_generated_tasks': + self.exploration_task_indices = ( + self.offline_train_task_indices + fake_task_idxs + ) + elif self.add_exploration_data_to == 'none': + self.exploration_task_indices = [] + self.num_tasks_sample = 0 + else: + raise ValueError(self.add_exploration_data_to) + self.fake_task_idx_to_z = { + task_idx: ptu.get_numpy(self.agent.latent_prior.sample()) + for task_idx in fake_task_idxs + } + for task_idx in self.fake_task_idx_to_z: + if not self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer.create_new_task_buffer(task_idx) + self.replay_buffer.create_new_task_buffer(task_idx) + self.collect_exploration_data( + self.num_initial_steps_self_generated_tasks, 1, np.inf, task_idx, + ) + + def collect_exploration_data(self, num_samples, + resample_latent_period, update_posterior_period, task_idx, add_to_enc_buffer=True, use_predicted_reward=False, + ): + ''' + get trajectories from current env in batch mode with given policy + collect complete trajectories until the number of collected transitions >= num_samples + + :param agent: policy to rollout + :param num_samples: total number of transitions to sample + :param resample_latent_period: how often to resample latent context z (in units of trajectories) + :param update_posterior_period: how often to update q(z | c) from which z is sampled (in units of trajectories) + :param add_to_enc_buffer: whether to add collected data to encoder replay buffer + :param use_predicted_reward: whether to replace the env reward with the predicted reward to simulate not having access to rewards. 
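+ :param task_idx: index of the task to collect data for; for self-generated tasks, the env reward is always replaced with the reward predicted from that task's fixed latent z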
+ ''' + # start from the prior + self.agent.clear_z() + + num_transitions = 0 + init_context = None + while num_transitions < num_samples: + initialized_z_reward = None + initial_reward_context = None + if task_idx in self.fake_task_idx_to_z: + initialized_z_reward = self.fake_task_idx_to_z[task_idx] + use_predicted_reward = True + else: + if use_predicted_reward: + if self.use_meta_learning_buffer: + initial_reward_context = self.meta_replay_buffer.sample_context( + self.embedding_batch_size + ) + else: + initial_reward_context = self._reward_decoder_buffer.sample_context( + task_idx, + self.embedding_batch_size + ) + # TODO: replace with sampler + paths, n_samples = self.sampler.obtain_samples( + max_samples=num_samples - num_transitions, + max_trajs=update_posterior_period, + accum_context=self._condition_on_posterior_guided_data_when_exploring, + resample_latent_period=resample_latent_period, + update_posterior_period=self.exploration_update_posterior_period, + use_predicted_reward=use_predicted_reward, + task_idx=task_idx, + initial_context=init_context, + initial_reward_context=initial_reward_context, + initialized_z_reward=initialized_z_reward, + ) + num_transitions += n_samples + self._n_rollouts_total += len(paths) + if self.use_meta_learning_buffer: + self.meta_replay_buffer.add_paths(paths) + else: + self.replay_buffer.add_paths(task_idx, paths) + if add_to_enc_buffer and not self.use_rl_buffer_for_enc_buffer: + self.enc_replay_buffer.add_paths(task_idx, paths) + if update_posterior_period != np.inf: + # init_context = self.sample_context(task_idx) + if self._condition_on_posterior_guided_data_when_exploring: + init_context = paths[-1]['context'] # TODO clean hack + else: + # TODO: check if it matters which version I use: above code or below? + init_context = self.enc_replay_buffer.sample_context( + task_idx, + self.embedding_batch_size + ) + init_context = ptu.from_numpy(init_context) + self._n_env_steps_total += num_transitions + + def _try_to_eval(self, epoch): + if epoch % self.logging_period != 0: + return + if epoch in self.save_extra_manual_epoch_set: + logger.save_extra_data( + self.get_extra_data_to_save(epoch), + file_name='extra_snapshot_itr{}'.format(epoch), + mode='cloudpickle', + ) + if self._save_extra_every_epoch: + logger.save_extra_data(self.get_extra_data_to_save(epoch)) + gt.stamp('save-extra') + if self._can_evaluate(): + self.evaluate(epoch) + gt.stamp('eval') + + params = self.get_epoch_snapshot(epoch) + logger.save_itr_params(epoch, params) + gt.stamp('save-snapshot') + table_keys = logger.get_table_key_set() + if self._old_table_keys is not None: + assert table_keys == self._old_table_keys, ( + "Table keys cannot change from iteration to iteration." 
+ ) + self._old_table_keys = table_keys + + logger.record_dict( + self.trainer.get_diagnostics(), + prefix='trainer/', + ) + + logger.record_tabular( + "Number of train steps total", + self._n_train_steps_total, + ) + logger.record_tabular( + "Number of env steps total", + self._n_env_steps_total, + ) + logger.record_tabular( + "Number of rollouts total", + self._n_rollouts_total, + ) + + times_itrs = gt.get_times().stamps.itrs + train_time = times_itrs['train'][-1] + sample_time = times_itrs['sample'][-1] + save_extra_time = times_itrs['save-extra'][-1] + save_snapshot_time = times_itrs['save-snapshot'][-1] + eval_time = times_itrs['eval'][-1] if epoch > 0 else 0 + epoch_time = train_time + sample_time + save_extra_time + eval_time + total_time = gt.get_times().total + + logger.record_tabular('in_unsupervised_model', + float(self.in_unsupervised_phase)) + logger.record_tabular('Train Time (s)', train_time) + logger.record_tabular('(Previous) Eval Time (s)', eval_time) + logger.record_tabular('Sample Time (s)', sample_time) + logger.record_tabular('Save Extra Time (s)', save_extra_time) + logger.record_tabular('Save Snapshot Time (s)', save_snapshot_time) + logger.record_tabular('Epoch Time (s)', epoch_time) + logger.record_tabular('Total Train Time (s)', total_time) + + logger.record_tabular("Epoch", epoch) + logger.dump_tabular(with_prefix=False, with_timestamp=False) + else: + logger.log("Skipping eval for now.") + + def _can_evaluate(self): + """ + One annoying thing about the logger table is that the keys at each + iteration need to be the exact same. So unless you can compute + everything, skip evaluation. + + A common example for why you might want to skip evaluation is that at + the beginning of training, you may not have enough data for a + validation and training set. + + :return: + """ + # eval collects its own context, so can eval any time + return True + + def _can_train(self): + return all([self.replay_buffer.num_steps_can_sample(idx) >= self.batch_size for idx in self.train_task_indices]) + + def _get_action_and_info(self, agent, observation): + """ + Get an action to take in the environment. + :param observation: + :return: + """ + agent.set_num_steps_total(self._n_env_steps_total) + return agent.get_action(observation,) + + def _start_epoch(self, epoch): + self._epoch_start_time = time.time() + self._exploration_paths = [] + self._do_train_time = 0 + logger.push_prefix('Iteration #%d | ' % epoch) + if epoch in self.save_extra_manual_beginning_epoch_list: + logger.save_extra_data( + self.get_extra_data_to_save(epoch), + file_name='extra_snapshot_beginning_itr{}'.format(epoch), + mode='cloudpickle', + ) + + def _end_epoch(self, epoch): + for post_train_func in self.post_train_funcs: + post_train_func(self, epoch) + + self.trainer.end_epoch(epoch) + logger.log("Epoch Duration: {0}".format( + time.time() - self._epoch_start_time + )) + logger.log("Started Training: {0}".format(self._can_train())) + logger.pop_prefix() + + ##### Snapshotting utils ##### + def get_epoch_snapshot(self, epoch): + snapshot = {'epoch': epoch} + for k, v in self.trainer.get_snapshot().items(): + snapshot['trainer/' + k] = v + snapshot['env'] = self.env + snapshot['env_sampler'] = self.sampler + snapshot['agent'] = self.agent + snapshot['exploration_agent'] = self.exploration_agent + return snapshot + + def get_extra_data_to_save(self, epoch): + """ + Save things that shouldn't be saved every snapshot but rather + overwritten every time. 
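+ Depending on the save_* flags, this includes the environment, the replay buffers, and the whole algorithm object.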
+ :param epoch: + :return: + """ + if self.render: + self.training_env.render(close=True) + data_to_save = dict( + epoch=epoch, + ) + if self.save_environment: + data_to_save['env'] = self.training_env + if self.save_replay_buffer: + data_to_save['replay_buffer'] = self.replay_buffer + data_to_save['enc_replay_buffer'] = self.enc_replay_buffer + data_to_save['meta_replay_buffer'] = self.meta_replay_buffer + if self.save_algorithm: + data_to_save['algorithm'] = self + return data_to_save + + def collect_paths(self, idx, epoch, run): + self.agent.clear_z() + paths = [] + num_transitions = 0 + num_trajs = 0 + init_context = None + infer_posterior_at_start = False + while num_transitions < self.num_steps_per_eval: + # We follow the PEARL protocol and never update the posterior or resample z within an episode during evaluation. + if idx in self.fake_task_idx_to_z: + initialized_z_reward = self.fake_task_idx_to_z[idx] + else: + initialized_z_reward = None + loop_paths, num = self.sampler.obtain_samples( + deterministic=self.eval_deterministic, + max_samples=self.num_steps_per_eval - num_transitions, + max_trajs=1, + accum_context=True, + initial_context=init_context, + task_idx=idx, + resample_latent_period=self.exploration_resample_latent_period, # PEARL had this=0. + update_posterior_period=0, # following PEARL protocol + infer_posterior_at_start=infer_posterior_at_start, + initialized_z_reward=initialized_z_reward, + use_predicted_reward=initialized_z_reward is not None, + ) + paths += loop_paths + num_transitions += num + num_trajs += 1 + # accumulated contexts across rollouts + init_context = paths[-1]['context'] # TODO clean hack + if num_trajs >= self.num_exp_traj_eval: + infer_posterior_at_start = True + + if self.sparse_rewards: + for p in paths: + sparse_rewards = np.stack(e['sparse_reward'] for e in p['env_infos']).reshape(-1, 1) + p['rewards'] = sparse_rewards + + goal = self.env._goal + for path in paths: + path['goal'] = goal # goal + + # save the paths for visualization, only useful for point mass + if self.dump_eval_paths and epoch >= 0: + logger.save_extra_data(paths, file_name='eval_trajectories/task{}-epoch{}-run{}'.format(idx, epoch, run)) + + return paths + + def _do_eval(self, indices, epoch): + final_returns = [] + online_returns = [] + task_idx_to_final_context = {} + for idx in indices: + all_rets = [] + for r in range(self.num_evals): + paths = self.collect_paths(idx, epoch, r) + all_rets.append([eval_util.get_average_returns([p]) for p in paths]) + task_idx_to_final_context[idx] = paths[-1]['context'] + final_returns.append(np.mean([a[-1] for a in all_rets])) + # record online returns for the first n trajectories + n = min([len(a) for a in all_rets]) + all_rets = [a[:n] for a in all_rets] + all_rets = np.mean(np.stack(all_rets), axis=0) # avg return per nth rollout + online_returns.append(all_rets) + n = min([len(t) for t in online_returns]) + online_returns = [t[:n] for t in online_returns] + return final_returns, online_returns, task_idx_to_final_context + + def evaluate(self, epoch): + if self.eval_statistics is None: + self.eval_statistics = OrderedDict() + + ### sample trajectories from prior for debugging / visualization + if self.dump_eval_paths: + # 100 arbitrarily chosen for visualizations of point_robot trajectories + # just want stochasticity of z, not the policy + self.agent.clear_z() + prior_paths, _ = self.sampler.obtain_samples( + deterministic=self.eval_deterministic, + max_samples=self.max_path_length * 20, + accum_context=False, + 
resample_latent_period=self.exploration_resample_latent_period, + update_posterior_period=self.exploration_update_posterior_period, # following PEARL protocol + ) + logger.save_extra_data(prior_paths, file_name='eval_trajectories/prior-epoch{}'.format(epoch)) + ### train tasks + if self._num_tasks_to_eval_on >= len(self.train_task_indices): + indices = self.train_task_indices + else: + # eval on a subset of train tasks in case num train tasks is huge + indices = np.random.choice(self.offline_train_task_indices, self._num_tasks_to_eval_on) + # logger.log('evaluating on {} train tasks'.format(len(indices))) + ### eval train tasks with posterior sampled from the training replay buffer + train_returns = [] + for idx in indices: + self.env.reset_task(idx) + paths = [] + for _ in range(self.num_steps_per_eval // self.max_path_length): + # init_context = self.sample_context(idx) + if self.use_meta_learning_buffer: + init_context = self.meta_replay_buffer._sample_contexts( + [idx], + self.embedding_batch_size + ) + else: + init_context = self.enc_replay_buffer.sample_context( + idx, + self.embedding_batch_size + ) + if self.eval_data_collector: + p = self.eval_data_collector.collect_new_paths( + num_steps=self.max_path_length, # TODO: also cap num trajs + max_path_length=self.max_path_length, + discard_incomplete_paths=False, + accum_context=False, + resample_latent_period=0, + update_posterior_period=0, + initial_context=init_context, + task_idx=idx, + ) + else: + init_context = ptu.from_numpy(init_context) + # TODO: replace with sampler + # self.agent.infer_posterior(context) + p, _ = self.sampler.obtain_samples( + deterministic=self.eval_deterministic, + max_samples=self.max_path_length, + accum_context=False, + max_trajs=1, + resample_latent_period=0, + update_posterior_period=0, + initial_context=init_context, + task_idx=idx, + ) + paths += p + + if self.sparse_rewards: + for p in paths: + sparse_rewards = np.stack(e['sparse_reward'] for e in p['env_infos']).reshape(-1, 1) + p['rewards'] = sparse_rewards + + train_returns.append(eval_util.get_average_returns(paths)) + + train_returns_offline_buffer = self._get_returns_init_from_offline_buffer(indices) + # train_returns = np.mean(train_returns) + ### eval train tasks with on-policy data to match eval of test tasks + train_final_returns, train_online_returns, train_task_to_final_context = ( + self._do_eval(indices, epoch) + ) + # logger.log('train online returns') + # logger.log(train_online_returns) + + ### test tasks + # logger.log('evaluating on {} test tasks'.format(len(self.eval_task_indices))) + test_final_returns, test_online_returns, test_task_to_final_context =( + self._do_eval(self.eval_task_indices, epoch) + ) + # logger.log('test online returns') + # logger.log(test_online_returns) + # save the final posterior + self.agent.log_diagnostics(self.eval_statistics) + + z_dist_log = self._get_z_distribution_log(train_task_to_final_context) + append_log(self.eval_statistics, z_dist_log, prefix='trainer/train_tasks/') + + if hasattr(self.env, "log_diagnostics"): + self.env.log_diagnostics(paths, prefix=None) + + avg_train_online_return = np.mean(np.stack(train_online_returns), axis=0) + avg_test_online_return = np.mean(np.stack(test_online_returns), axis=0) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/init_from_offline_buffer/train_tasks/all_returns', + train_returns_offline_buffer, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/init_from_buffer/train_tasks/all_returns', + 
train_returns, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/train_tasks/final_returns', + train_final_returns, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/test_tasks/final_returns', + test_final_returns, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/train_tasks/all_returns', + avg_train_online_return, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/test_tasks/all_returns', + avg_test_online_return, + )) + + if len(self.fake_task_idx_to_z) > 0: + self_generated_indices = np.random.choice( + np.array(list(self.fake_task_idx_to_z.keys())), + self._num_tasks_to_eval_on, + ) + self_generated_final_returns, self_generated_online_returns, _ = self._do_eval(self_generated_indices, epoch) + avg_self_generated_return = np.mean(np.stack(self_generated_online_returns)) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/generated_tasks/final_returns', + self_generated_final_returns, + )) + self.eval_statistics.update(eval_util.create_stats_ordered_dict( + 'eval/adaptation/generated_tasks/all_returns', + avg_self_generated_return, + )) + + try: + import os + import psutil + process = psutil.Process(os.getpid()) + self.eval_statistics['RAM Usage (Mb)'] = int(process.memory_info().rss / 1000000) + except ImportError: + pass + logger.save_extra_data(avg_train_online_return, file_name='online-train-epoch{}'.format(epoch)) + logger.save_extra_data(avg_test_online_return, file_name='online-test-epoch{}'.format(epoch)) + + for key, value in self.eval_statistics.items(): + logger.record_tabular(key, value) + self.eval_statistics = None + + if self.render_eval_paths: + self.env.render_paths(paths) + + if self.plotter: + self.plotter.draw() + + # @abc.abstractmethod + def training_mode(self, mode): + """ + Set training mode to `mode`. + :param mode: If True, training will happen (e.g. set the dropout + probabilities to not all ones). 
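+ The default implementation here is a no-op.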
+ """ + pass + + def to(self, device=None): + self.trainer.to(device=device) + + def _get_z_distribution_log(self, idx_to_final_context): + """Log diagnostics about the shift in z-distribution""" + logs = OrderedDict() + task_indices = list(sorted(idx_to_final_context.keys())) + offline_context1 = self._reward_decoder_buffer.sample_context( + task_indices, + self.embedding_batch_size + ) + context_distrib1 = self.agent.latent_posterior(offline_context1) + offline_context2 = self._reward_decoder_buffer.sample_context( + task_indices, + self.embedding_batch_size + ) + context_distrib2 = self.agent.latent_posterior(offline_context2) + + context_distribs1 = [ + Normal(m, s) for m, s in + zip(context_distrib1.mean, context_distrib1.stddev) + ] + context_distribs2 = [ + Normal(m, s) for m, s in + zip(context_distrib2.mean, context_distrib2.stddev) + ] + within_task_kl_2samples, between_task_kl_2samples = ( + self._compute_within_and_between_task_kl( + context_distribs1, + context_distribs2, + ) + ) + + logs['two_offline_z_posteriors/within_task_kl'] = ( + within_task_kl_2samples + ) + logs['two_offline_z_posteriors/between_task_kl'] = ( + between_task_kl_2samples + ) + offline_z_posterior_kl_prior = kl_divergence( + context_distrib1, self.agent.latent_prior).sum(dim=1) + logs.update(eval_util.create_stats_ordered_dict( + 'offline_z_posterior/kl_prior', + ptu.get_numpy(offline_z_posterior_kl_prior), + )) + + context_distribs_online = [] + for idx in task_indices: + context = idx_to_final_context[idx] + context_distrib = self.agent.latent_posterior(context) + context_distribs_online.append(context_distrib) + + within_task_kl_off_on, between_task_kl_off_on = ( + self._compute_within_and_between_task_kl( + context_distribs_online, + context_distribs2, + ) + ) + online_posterior_kls = np.array([ + ptu.get_numpy(kl_divergence(q, self.agent.latent_prior).sum()) + for q in context_distribs_online + ]) + logs.update(eval_util.create_stats_ordered_dict( + 'online_z_posterior/kl_prior', + online_posterior_kls, + )) + logs['offline_vs_online_z_posterior/within_task_kl'] = ( + within_task_kl_off_on + ) + logs['offline_vs_online_z_posterior/between_task_kl'] = ( + between_task_kl_off_on + ) + logs['offline_vs_online_z_posterior/within_task_kl/normalized'] = ( + (within_task_kl_off_on - within_task_kl_2samples) / ( + (between_task_kl_2samples - within_task_kl_2samples) + ) + ) + return logs + + def _compute_within_and_between_task_kl(self, context_distribs1, + context_distribs2): + n_tasks = len(context_distribs1) + divergences = np.zeros((n_tasks, n_tasks)) + for i1, d1 in enumerate(context_distribs1): + for i2, d2 in enumerate(context_distribs2): + kl = kl_divergence(d1, d2).sum().item() + divergences[i1, i2] = kl + within_task_avg_kl = (divergences * np.eye(n_tasks)).sum() / n_tasks + between_task_avg_kl = (divergences * (1-np.eye(n_tasks))).sum() / ( + n_tasks * (n_tasks - 1) + ) + return within_task_avg_kl, between_task_avg_kl + + def _get_returns_init_from_offline_buffer(self, indices): + train_returns = [] + for idx in indices: + self.env.reset_task(idx) + paths = [] + for _ in range(self.num_steps_per_eval // self.max_path_length): + init_context = self._reward_decoder_buffer.sample_context( + idx, + self.embedding_batch_size + ) + init_context = ptu.from_numpy(init_context) + p, _ = self.sampler.obtain_samples( + deterministic=self.eval_deterministic, + max_samples=self.max_path_length, + accum_context=False, + max_trajs=1, + resample_latent_period=0, + update_posterior_period=0, + 
initial_context=init_context, + task_idx=idx, + ) + paths += p + + if self.sparse_rewards: + for p in paths: + sparse_rewards = np.stack(e['sparse_reward'] for e in p['env_infos']).reshape(-1, 1) + p['rewards'] = sparse_rewards + + train_returns.append(eval_util.get_average_returns(paths)) + return train_returns diff --git a/rlkit/core/simple_offline_rl_algorithm.py b/rlkit/core/simple_offline_rl_algorithm.py new file mode 100644 index 000000000..c6a9e388e --- /dev/null +++ b/rlkit/core/simple_offline_rl_algorithm.py @@ -0,0 +1,161 @@ +from collections import OrderedDict + +import numpy as np + + +from rlkit.core.timer import timer +from rlkit.core import logger +from rlkit.core.logging import add_prefix +from rlkit.torch.core import np_to_pytorch_batch + + +def _get_epoch_timings(): + times_itrs = timer.get_times() + times = OrderedDict() + for key in sorted(times_itrs): + time = times_itrs[key] + times['time/{} (s)'.format(key)] = time + return times + + +class SimpleOfflineRlAlgorithm(object): + def __init__( + self, + trainer, + replay_buffer, + batch_size, + logging_period, + num_batches, + ): + self.trainer = trainer + self.replay_buffer = replay_buffer + self.batch_size = batch_size + self.num_batches = num_batches + self.logging_period = logging_period + + def train(self): + # first train only the Q function + iteration = 0 + for i in range(self.num_batches): + train_data = self.replay_buffer.random_batch(self.batch_size) + train_data = np_to_pytorch_batch(train_data) + obs = train_data['observations'] + next_obs = train_data['next_observations'] + train_data['observations'] = obs + train_data['next_observations'] = next_obs + self.trainer.train_from_torch(train_data) + if i % self.logging_period == 0: + stats_with_prefix = add_prefix( + self.trainer.eval_statistics, prefix="trainer/") + self.trainer.end_epoch(iteration) + iteration += 1 + logger.record_dict(stats_with_prefix) + logger.dump_tabular(with_prefix=True, with_timestamp=False) + + +class OfflineMetaRLAlgorithm(object): + def __init__( + self, + # main objects needed + meta_replay_buffer, + replay_buffer, + task_embedding_replay_buffer, + trainer, + train_tasks, + # settings + batch_size, + logging_period, + meta_batch_size, + num_batches, + task_embedding_batch_size, + extra_eval_fns=(), + use_meta_learning_buffer=False, + ): + self.trainer = trainer + self.meta_replay_buffer = meta_replay_buffer + self.replay_buffer = replay_buffer + self.task_embedding_replay_buffer = task_embedding_replay_buffer + self.batch_size = batch_size + self.task_embedding_batch_size = task_embedding_batch_size + self.num_batches = num_batches + self.logging_period = logging_period + self.train_tasks = train_tasks + self.meta_batch_size = meta_batch_size + self._extra_eval_fns = extra_eval_fns + self.use_meta_learning_buffer = use_meta_learning_buffer + + def train(self): + # first train only the Q function + iteration = 0 + timer.return_global_times = True + timer.reset() + for i in range(self.num_batches): + if self.use_meta_learning_buffer: + train_data = self.meta_replay_buffer.sample_meta_batch( + rl_batch_size=self.batch_size, + meta_batch_size=self.meta_batch_size, + embedding_batch_size=self.task_embedding_batch_size, + ) + train_data = np_to_pytorch_batch(train_data) + else: + task_indices = np.random.choice( + self.train_tasks, self.meta_batch_size, + ) + train_data = self.replay_buffer.sample_batch( + task_indices, + self.batch_size, + ) + train_data = np_to_pytorch_batch(train_data) + obs = train_data['observations'] + next_obs = 
train_data['next_observations'] + train_data['observations'] = obs + train_data['next_observations'] = next_obs + train_data['context'] = ( + self.task_embedding_replay_buffer.sample_context( + task_indices, + self.task_embedding_batch_size, + )) + timer.start_timer('train', unique=False) + self.trainer.train_from_torch(train_data) + timer.stop_timer('train') + if i % self.logging_period == 0 or i == self.num_batches - 1: + stats_with_prefix = add_prefix( + self.trainer.eval_statistics, prefix="trainer/") + self.trainer.end_epoch(iteration) + logger.record_dict(stats_with_prefix) + timer.start_timer('extra_fns', unique=False) + for fn in self._extra_eval_fns: + extra_stats = fn() + logger.record_dict(extra_stats) + timer.stop_timer('extra_fns') + + + # TODO: evaluate during offline RL + # eval_stats = self.get_eval_statistics() + # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/") + # logger.record_dict(eval_stats_with_prefix) + + logger.record_tabular('iteration', iteration) + logger.record_dict(_get_epoch_timings()) + try: + import os + import psutil + process = psutil.Process(os.getpid()) + logger.record_tabular('RAM Usage (Mb)', int(process.memory_info().rss / 1000000)) + except ImportError: + pass + logger.dump_tabular(with_prefix=True, with_timestamp=False) + iteration += 1 + + def to(self, device): + self.trainer.to(device) + # def get_eval_statistics(self): + # ### train tasks + # # eval on a subset of train tasks for speed + # stats = OrderedDict() + # indices = np.random.choice(self.train_task_indices, len(self.eval_task_indices)) + # for key, path_collector in self.path_collectors.item(): + # paths = path_collector.collect_paths() + # returns = eval_util.get_average_returns(paths) + # stats[key + '/AverageReturns'] = returns + # return stats diff --git a/rlkit/core/timer.py b/rlkit/core/timer.py new file mode 100644 index 000000000..21cfffd4b --- /dev/null +++ b/rlkit/core/timer.py @@ -0,0 +1,51 @@ +import time + +from collections import defaultdict + + +class Timer: + def __init__(self, return_global_times=False): + self.stamps = None + self.epoch_start_time = None + self.global_start_time = time.time() + self._return_global_times = return_global_times + + self.reset() + + def reset(self): + self.stamps = defaultdict(lambda: 0) + self.start_times = {} + self.epoch_start_time = time.time() + + def start_timer(self, name, unique=True): + if unique: + assert name not in self.start_times.keys() + self.start_times[name] = time.time() + + def stop_timer(self, name): + assert name in self.start_times.keys() + start_time = self.start_times[name] + end_time = time.time() + self.stamps[name] += (end_time - start_time) + + def get_times(self): + global_times = {} + cur_time = time.time() + global_times['epoch_time'] = (cur_time - self.epoch_start_time) + if self._return_global_times: + global_times['global_time'] = (cur_time - self.global_start_time) + return { + **self.stamps.copy(), + **global_times, + } + + @property + def return_global_times(self): + return self._return_global_times + + @return_global_times.setter + def return_global_times(self, value): + self._return_global_times = value + + +timer = Timer() diff --git a/rlkit/data_management/meta_learning_replay_buffer.py b/rlkit/data_management/meta_learning_replay_buffer.py new file mode 100644 index 000000000..955c00784 --- /dev/null +++ b/rlkit/data_management/meta_learning_replay_buffer.py @@ -0,0 +1,183 @@ +import numpy as np +import random + +from rlkit.data_management.simple_replay_buffer import ( + 
SimpleReplayBuffer as RLKitSimpleReplayBuffer +) + +from rlkit.envs.env_utils import get_dim + + +class MetaLearningReplayBuffer(object): + def __init__( + self, + max_replay_buffer_size, + env, + task_indices, + use_next_obs_in_context, + sparse_rewards, + mini_buffer_max_size, + use_ground_truth_context=False, + ground_truth_tasks=None, + sample_buffer_in_proportion_to_size=False, + ): + """ + This has a separate mini-replay buffer for each set of tasks + """ + self.max_replay_buffer_size = max_replay_buffer_size + self.use_next_obs_in_context = use_next_obs_in_context + self.sparse_rewards = sparse_rewards + self.env = env + self._ob_space = env.observation_space + self._action_space = env.action_space + self.use_ground_truth_context = use_ground_truth_context + self.task_indices = task_indices + self.ground_truth_tasks = ground_truth_tasks + if use_ground_truth_context: + assert ground_truth_tasks is not None + self.env_info_sizes = dict() + if sparse_rewards: + self.env_info_sizes['sparse_reward'] = 1 + self.task_buffers = [] + self.mini_buffer_max_size = mini_buffer_max_size + self._num_steps_can_sample = 0 + + self.sample_buffer_in_proportion_to_size = ( + sample_buffer_in_proportion_to_size + ) + + def create_buffer(self, size=None): + if size is None: + size = self.mini_buffer_max_size + return RLKitSimpleReplayBuffer( + max_replay_buffer_size=size, + observation_dim=get_dim(self._ob_space), + action_dim=get_dim(self._action_space), + env_info_sizes=self.env_info_sizes, + ) + + @property + def num_steps_can_sample(self): + return self._num_steps_can_sample + + def add_paths(self, paths): + new_buffer = self.create_buffer() + for path in paths: + new_buffer.add_path(path) + self._num_steps_can_sample += new_buffer.num_steps_can_sample() + self.append_buffer(new_buffer) + + def append_buffer(self, new_buffer): + self.task_buffers.append(new_buffer) + while self.num_steps_can_sample > self.max_replay_buffer_size: + self._remove_task_buffer() + + def _remove_task_buffer(self): + buffer_to_remove = random.choice(self.task_buffers) + self.task_buffers.remove(buffer_to_remove) + self._num_steps_can_sample -= buffer_to_remove.num_steps_can_sample() + + def sample_batch(self, indices, batch_size): + """ + sample batch of training data from a list of tasks for training the + actor-critic. 
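+ Each task in indices contributes batch_size transitions; the per-task batches are stacked, so e.g. sample_batch([0, 3], 256)['observations'] has shape (2, 256, obs_dim).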
+ + :param indices: task indices + :param batch_size: batch size for each task index + :return: + """ + # TODO: replace with pythonplusplus.treemap + # this batch consists of transitions sampled randomly from replay buffer + # rewards are always dense + # batches = [np_to_pytorch_batch(self.replay_buffer.random_batch(idx, batch_size=self.batch_size)) for idx in indices] + batches = [self.random_batch(idx, batch_size=batch_size) for idx in indices] + unpacked = [self.unpack_batch(batch) for batch in batches] + # group like elements together + unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))] + # unpacked = [torch.cat(x, dim=0) for x in unpacked] + unpacked = [np.concatenate(x, axis=0) for x in unpacked] + + obs, actions, rewards, next_obs, terms = unpacked + return { + 'observations': obs, + 'actions': actions, + 'rewards': rewards, + 'next_observations': next_obs, + 'terminals': terms, + } + + def _sample_contexts(self, indices, batch_size): + ''' sample batch of context from a list of tasks from the replay buffer ''' + # make method work given a single task index + if not hasattr(indices, '__iter__'): + indices = [indices] + batches = [ + self.random_batch( + idx, + batch_size=batch_size, + sequence=False) + for idx in indices + ] + if any(b is None for b in batches): + import ipdb; ipdb.set_trace() + return None + if self.use_ground_truth_context: + return np.array([self.ground_truth_tasks[i] for i in indices]) + context = [self.unpack_batch(batch) for batch in batches] + # group like elements together + context = [[x[i] for x in context] for i in range(len(context[0]))] + # context = [torch.cat(x, dim=0) for x in context] + context = [np.concatenate(x, axis=0) for x in context] + # full context consists of [obs, act, rewards, next_obs, terms] + # if dynamics don't change across tasks, don't include next_obs + # don't include terminals in context + if self.use_next_obs_in_context: + context = np.concatenate(context[:-1], axis=2) + else: + context = np.concatenate(context[:-2], axis=2) + return context + + def random_batch(self, task, batch_size, sequence=False): + if sequence: + batch = self.task_buffers[task].random_sequence(batch_size) + else: + try: + batch = self.task_buffers[task].random_batch(batch_size) + except KeyError: + import ipdb; ipdb.set_trace() + print(task) + return batch + + def sample_meta_batch(self, meta_batch_size, rl_batch_size, embedding_batch_size): + possible_indices = np.arange(len(self.task_buffers)) + if self.sample_buffer_in_proportion_to_size: + sizes = np.array([buffer.num_steps_can_sample() for buffer in self.task_buffers]) + sample_probs = sizes / np.sum(sizes) + indices = np.random.choice( + possible_indices, + meta_batch_size, + p=sample_probs, + ) + else: + indices = np.random.choice(possible_indices, meta_batch_size) + batch = self.sample_batch(indices, rl_batch_size) + context = self._sample_contexts(indices, embedding_batch_size) + batch['context'] = context + return batch + + def sample_context(self, batch_size): + possible_indices = np.arange(len(self.task_buffers)) + index = np.random.choice(possible_indices) + return self._sample_contexts([index], batch_size) + + def unpack_batch(self, batch): + ''' unpack a batch and return individual elements ''' + o = batch['observations'][None, ...] + a = batch['actions'][None, ...] + if self.sparse_rewards: + r = batch['sparse_rewards'][None, ...] + else: + r = batch['rewards'][None, ...] + no = batch['next_observations'][None, ...] + t = batch['terminals'][None, ...] 
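+ # each element keeps a leading axis of size 1 so per-task batches can be concatenated along axis 0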
+ return [o, a, r, no, t] diff --git a/rlkit/data_management/multitask_replay_buffer.py b/rlkit/data_management/multitask_replay_buffer.py new file mode 100644 index 000000000..7e658e180 --- /dev/null +++ b/rlkit/data_management/multitask_replay_buffer.py @@ -0,0 +1,300 @@ +import numpy as np + +from rlkit.data_management.replay_buffer import ReplayBuffer +from rlkit.data_management.simple_replay_buffer import ( + SimpleReplayBuffer as RLKitSimpleReplayBuffer +) +from gym.spaces import Box, Discrete, Tuple + + +class MultiTaskReplayBuffer(object): + def __init__( + self, + max_replay_buffer_size, + env, + task_indices, + use_next_obs_in_context, + sparse_rewards, + use_ground_truth_context=False, + ground_truth_tasks=None, + env_info_sizes=None, + ): + """ + :param max_replay_buffer_size: + :param env: + :param task_indices: for multi-task setting + """ + if env_info_sizes is None: + env_info_sizes = {} + self.use_next_obs_in_context = use_next_obs_in_context + self.sparse_rewards = sparse_rewards + self.env = env + self._ob_space = env.observation_space + self._action_space = env.action_space + self.use_ground_truth_context = use_ground_truth_context + self.task_indices = task_indices + self.ground_truth_tasks = ground_truth_tasks + if use_ground_truth_context: + assert ground_truth_tasks is not None + if sparse_rewards: + env_info_sizes['sparse_reward'] = 1 + self.task_buffers = dict([(idx, RLKitSimpleReplayBuffer( + max_replay_buffer_size=max_replay_buffer_size, + observation_dim=get_dim(self._ob_space), + action_dim=get_dim(self._action_space), + env_info_sizes=env_info_sizes, + )) for idx in task_indices]) + self._max_replay_buffer_size = max_replay_buffer_size + self._env_info_sizes = env_info_sizes + + def create_new_task_buffer(self, task_idx): + if task_idx in self.task_buffers: + raise IndexError("task_idx already exists: {}".format(task_idx)) + new_task_buffer = RLKitSimpleReplayBuffer( + max_replay_buffer_size=self._max_replay_buffer_size, + observation_dim=get_dim(self._ob_space), + action_dim=get_dim(self._action_space), + env_info_sizes=self._env_info_sizes, + ) + self.task_buffers[task_idx] = new_task_buffer + + def add_sample(self, task, observation, action, reward, terminal, + next_observation, **kwargs): + + if isinstance(self._action_space, Discrete): + action = np.eye(self._action_space.n)[action] + self.task_buffers[task].add_sample( + observation, action, reward, terminal, + next_observation, **kwargs) + + def terminate_episode(self, task): + self.task_buffers[task].terminate_episode() + + def random_batch(self, task, batch_size, sequence=False): + if sequence: + batch = self.task_buffers[task].random_sequence(batch_size) + else: + try: + batch = self.task_buffers[task].random_batch(batch_size) + except KeyError: + import ipdb; ipdb.set_trace() + print(task) + return batch + + def num_steps_can_sample(self, task): + return self.task_buffers[task].num_steps_can_sample() + + def add_path(self, task, path): + self.task_buffers[task].add_path(path) + + def add_paths(self, task, paths): + for path in paths: + self.task_buffers[task].add_path(path) + + def clear_buffer(self, task): + self.task_buffers[task].clear() + + def clear_all_buffers(self): + for buffer in self.task_buffers.values(): + buffer.clear() + + def sample_batch(self, indices, batch_size): + """ + sample batch of training data from a list of tasks for training the + actor-critic. 
+ + :param indices: task indices + :param batch_size: batch size for each task index + :return: + """ + # TODO: replace with pythonplusplus.treemap + # this batch consists of transitions sampled randomly from replay buffer + # rewards are always dense + # batches = [np_to_pytorch_batch(self.replay_buffer.random_batch(idx, batch_size=self.batch_size)) for idx in indices] + batches = [self.random_batch(idx, batch_size=batch_size) for idx in indices] + unpacked = [self.unpack_batch(batch) for batch in batches] + # group like elements together + unpacked = [[x[i] for x in unpacked] for i in range(len(unpacked[0]))] + # unpacked = [torch.cat(x, dim=0) for x in unpacked] + unpacked = [np.concatenate(x, axis=0) for x in unpacked] + + obs, actions, rewards, next_obs, terms = unpacked + return { + 'observations': obs, + 'actions': actions, + 'rewards': rewards, + 'next_observations': next_obs, + 'terminals': terms, + } + + def sample_context(self, indices, batch_size): + ''' sample batch of context from a list of tasks from the replay buffer ''' + # make method work given a single task index + if not hasattr(indices, '__iter__'): + indices = [indices] + batches = [ + self.random_batch( + idx, + batch_size=batch_size, + sequence=False) + for idx in indices + ] + if any(b is None for b in batches): + import ipdb; ipdb.set_trace() + return None + if self.use_ground_truth_context: + return np.array([self.ground_truth_tasks[i] for i in indices]) + context = [self.unpack_batch(batch) for batch in batches] + # group like elements together + context = [[x[i] for x in context] for i in range(len(context[0]))] + # context = [torch.cat(x, dim=0) for x in context] + context = [np.concatenate(x, axis=0) for x in context] + # full context consists of [obs, act, rewards, next_obs, terms] + # if dynamics don't change across tasks, don't include next_obs + # don't include terminals in context + if self.use_next_obs_in_context: + context = np.concatenate(context[:-1], axis=2) + else: + context = np.concatenate(context[:-2], axis=2) + return context + + def unpack_batch(self, batch): + ''' unpack a batch and return individual elements ''' + o = batch['observations'][None, ...] + a = batch['actions'][None, ...] + if self.sparse_rewards: + r = batch['sparse_rewards'][None, ...] + else: + r = batch['rewards'][None, ...] + no = batch['next_observations'][None, ...] + t = batch['terminals'][None, ...] 
+ return [o, a, r, no, t] + +def get_dim(space): + if isinstance(space, Box): + return space.low.size + elif isinstance(space, Discrete): + return space.n + elif isinstance(space, Tuple): + return sum(get_dim(subspace) for subspace in space.spaces) + elif hasattr(space, 'flat_dim'): + return space.flat_dim + else: + # import OldBox here so it is not necessary to have rand_param_envs + # installed if not running the rand_param envs + from rand_param_envs.gym.spaces.box import Box as OldBox + if isinstance(space, OldBox): + return space.low.size + else: + raise TypeError("Unknown space: {}".format(space)) + + +# WARNING: deprecated +class SimpleReplayBuffer(ReplayBuffer): + def __init__( + self, max_replay_buffer_size, observation_dim, action_dim, + ): + print("WARNING: will deprecate this SimpleReplayBuffer.") + self._observation_dim = observation_dim + self._action_dim = action_dim + self._max_replay_buffer_size = max_replay_buffer_size + self._observations = np.zeros((max_replay_buffer_size, observation_dim)) + # It's a bit memory inefficient to save the observations twice, + # but it makes the code *much* easier since you no longer have to + # worry about termination conditions. + self._next_obs = np.zeros((max_replay_buffer_size, observation_dim)) + self._actions = np.zeros((max_replay_buffer_size, action_dim)) + # Make everything a 2D np array to make it easier for other code to + # reason about the shape of the data + self._rewards = np.zeros((max_replay_buffer_size, 1)) + self._sparse_rewards = np.zeros((max_replay_buffer_size, 1)) + # self._terminals[i] = a terminal was received at time i + self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8') + self.clear() + + def add_sample(self, observation, action, reward, terminal, + next_observation, **kwargs): + self._observations[self._top] = observation + self._actions[self._top] = action + self._rewards[self._top] = reward + self._terminals[self._top] = terminal + self._next_obs[self._top] = next_observation + self._sparse_rewards[self._top] = kwargs['env_info'].get('sparse_reward', 0) + self._advance() + + def terminate_episode(self): + # store the episode beginning once the episode is over + # n.b. 
allows last episode to loop but whatever + self._episode_starts.append(self._cur_episode_start) + self._cur_episode_start = self._top + + def size(self): + return self._size + + def clear(self): + self._top = 0 + self._size = 0 + self._episode_starts = [] + self._cur_episode_start = 0 + + def _advance(self): + self._top = (self._top + 1) % self._max_replay_buffer_size + if self._size < self._max_replay_buffer_size: + self._size += 1 + + def sample_data(self, indices): + return dict( + observations=self._observations[indices], + actions=self._actions[indices], + rewards=self._rewards[indices], + terminals=self._terminals[indices], + next_observations=self._next_obs[indices], + sparse_rewards=self._sparse_rewards[indices], + ) + + def random_batch(self, batch_size): + ''' batch of unordered transitions ''' + indices = np.random.randint(0, self._size, batch_size) + return self.sample_data(indices) + + def random_sequence(self, batch_size): + ''' batch of trajectories ''' + # take random trajectories until we have enough + i = 0 + indices = [] + while len(indices) < batch_size: + # TODO hack to not deal with wrapping episodes, just don't take the last one + start = np.random.choice(self.episode_starts[:-1]) + pos_idx = self._episode_starts.index(start) + indices += list(range(start, self._episode_starts[pos_idx + 1])) + i += 1 + # cut off the last traj if needed to respect batch size + indices = indices[:batch_size] + return self.sample_data(indices) + + def num_steps_can_sample(self): + return self._size + + def copy_data(self, other_buffer: 'SimpleReplayBuffer'): + start_i = self._top + end_i = self._top + other_buffer._top + if end_i > self._max_replay_buffer_size: + raise NotImplementedError() + self._observations[start_i:end_i] = ( + other_buffer._observations[:other_buffer._top].copy() + ) + self._actions[start_i:end_i] = ( + other_buffer._actions[:other_buffer._top].copy() + ) + self._rewards[start_i:end_i] = ( + other_buffer._rewards[:other_buffer._top].copy() + ) + self._terminals[start_i:end_i] = ( + other_buffer._terminals[:other_buffer._top].copy() + ) + self._next_obs[start_i:end_i] = ( + other_buffer._next_obs[:other_buffer._top].copy() + ) + self._sparse_rewards[start_i:end_i] = ( + other_buffer._sparse_rewards[:other_buffer._top].copy() + ) diff --git a/rlkit/data_management/simple_replay_buffer.py b/rlkit/data_management/simple_replay_buffer.py index b39668659..ec01983be 100644 --- a/rlkit/data_management/simple_replay_buffer.py +++ b/rlkit/data_management/simple_replay_buffer.py @@ -35,7 +35,7 @@ def __init__( self._env_infos = {} for key, size in env_info_sizes.items(): self._env_infos[key] = np.zeros((max_replay_buffer_size, size)) - self._env_info_keys = env_info_sizes.keys() + self._env_info_keys = list(env_info_sizes.keys()) self._replace = replace diff --git a/rlkit/envs/pearl_envs/__init__.py b/rlkit/envs/pearl_envs/__init__.py new file mode 100644 index 000000000..9574e8272 --- /dev/null +++ b/rlkit/envs/pearl_envs/__init__.py @@ -0,0 +1,57 @@ +from rlkit.envs.pearl_envs.ant_normal import AntNormal +from rlkit.envs.pearl_envs.ant_dir import AntDirEnv +from rlkit.envs.pearl_envs.ant_goal import AntGoalEnv +from rlkit.envs.pearl_envs.half_cheetah_dir import HalfCheetahDirEnv +from rlkit.envs.pearl_envs.half_cheetah_vel import HalfCheetahVelEnv +from rlkit.envs.pearl_envs.hopper_rand_params_wrapper import \ + HopperRandParamsWrappedEnv +from rlkit.envs.pearl_envs.humanoid_dir import HumanoidDirEnv +from rlkit.envs.pearl_envs.point_robot import PointEnv, SparsePointEnv 
+from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import \ + Walker2DRandParamsEnv +from rlkit.envs.pearl_envs.walker_rand_params_wrapper import \ + WalkerRandParamsWrappedEnv + +ENVS = {} + + +def register_env(name): + """Registers a env by name for instantiation in rlkit.""" + + def register_env_fn(fn): + if name in ENVS: + raise ValueError("Cannot register duplicate env {}".format(name)) + if not callable(fn): + raise TypeError("env {} must be callable".format(name)) + ENVS[name] = fn + return fn + + return register_env_fn + + +def _register_env(name, fn): + """Registers a env by name for instantiation in rlkit.""" + if name in ENVS: + raise ValueError("Cannot register duplicate env {}".format(name)) + if not callable(fn): + raise TypeError("env {} must be callable".format(name)) + ENVS[name] = fn + + +def register_pearl_envs(): + _register_env('sparse-point-robot', SparsePointEnv) + _register_env('ant-normal', AntNormal) + _register_env('ant-dir', AntDirEnv) + _register_env('ant-goal', AntGoalEnv) + _register_env('cheetah-dir', HalfCheetahDirEnv) + _register_env('cheetah-vel', HalfCheetahVelEnv) + _register_env('humanoid-dir', HumanoidDirEnv) + _register_env('point-robot', PointEnv) + _register_env('walker-rand-params', WalkerRandParamsWrappedEnv) + _register_env('hopper-rand-params', HopperRandParamsWrappedEnv) + +# automatically import any envs in the envs/ directory +# for file in os.listdir(os.path.dirname(__file__)): +# if file.endswith('.py') and not file.startswith('_'): +# module = file[:file.find('.py')] +# importlib.import_module('rlkit.envs.pearl_envs.' + module) diff --git a/rlkit/envs/pearl_envs/ant.py b/rlkit/envs/pearl_envs/ant.py new file mode 100644 index 000000000..c154d5fde --- /dev/null +++ b/rlkit/envs/pearl_envs/ant.py @@ -0,0 +1,70 @@ +import numpy as np + +from .mujoco_env import MujocoEnv + + +class AntEnv(MujocoEnv): + def __init__(self, use_low_gear_ratio=False): + # self.init_serialization(locals()) + if use_low_gear_ratio: + xml_path = 'low_gear_ratio_ant.xml' + else: + xml_path = 'ant.xml' + super().__init__( + xml_path, + frame_skip=5, + automatically_set_obs_and_action_space=True, + ) + + def step(self, a): + torso_xyz_before = self.get_body_com("torso") + self.do_simulation(a, self.frame_skip) + torso_xyz_after = self.get_body_com("torso") + torso_velocity = torso_xyz_after - torso_xyz_before + forward_reward = torso_velocity[0]/self.dt + ctrl_cost = 0. # .5 * np.square(a).sum() + contact_cost = 0.5 * 1e-3 * np.sum( + np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) + survive_reward = 0. # 1.0 + reward = forward_reward - ctrl_cost - contact_cost + survive_reward + state = self.state_vector() + notdone = np.isfinite(state).all() \ + and state[2] >= 0.2 and state[2] <= 1.0 + done = not notdone + ob = self._get_obs() + return ob, reward, done, dict( + reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + reward_survive=survive_reward, + torso_velocity=torso_velocity, + ) + + def _get_obs(self): + # this is gym ant obs, should use rllab? 
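+ # qpos[:2] (the root x-y position) is excluded, so the observation is translation-invariant in the plane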
+ # if position is needed, override this in subclasses + return np.concatenate([ + self.sim.data.qpos.flat[2:], + self.sim.data.qvel.flat, + ]) + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + try: + from multiworld.envs.mujoco.cameras import create_camera_init + self.camera_init = create_camera_init( + lookat=(0, 0, 0), + distance=10, + elevation=-45, + azimuth=90, + trackbodyid=self.sim.model.body_name2id('torso'), + ) + self.camera_init(self.viewer.cam) + except ImportError as e: + pass +# \ No newline at end of file diff --git a/rlkit/envs/pearl_envs/ant_dir.py b/rlkit/envs/pearl_envs/ant_dir.py new file mode 100644 index 000000000..157829dbb --- /dev/null +++ b/rlkit/envs/pearl_envs/ant_dir.py @@ -0,0 +1,81 @@ +import numpy as np + +from rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv + + +class AntDirEnv(MultitaskAntEnv): + + def __init__( + self, + task=None, + n_tasks=2, + fixed_tasks=None, + forward_backward=False, + direction_in_degrees=False, + **kwargs + ): + if task is None: + task = {} + self.fixed_tasks = fixed_tasks + self.direction_in_degrees = direction_in_degrees + self.quick_init(locals()) + self.forward_backward = forward_backward + super(AntDirEnv, self).__init__(task, n_tasks, **kwargs) + + def step(self, action): + torso_xyz_before = np.array(self.get_body_com("torso")) + + if self.direction_in_degrees: + goal = self._goal / 180 * np.pi + else: + goal = self._goal + direct = (np.cos(goal), np.sin(goal)) + + self.do_simulation(action, self.frame_skip) + torso_xyz_after = np.array(self.get_body_com("torso")) + torso_velocity = torso_xyz_after - torso_xyz_before + forward_reward = np.dot((torso_velocity[:2]/self.dt), direct) + + ctrl_cost = .5 * np.square(action).sum() + contact_cost = 0.5 * 1e-3 * np.sum( + np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) + survive_reward = 1.0 + reward = forward_reward - ctrl_cost - contact_cost + survive_reward + state = self.state_vector() + notdone = np.isfinite(state).all() \ + and state[2] >= 0.2 and state[2] <= 1.0 + done = not notdone + ob = self._get_obs() + return ob, reward, done, dict( + reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + reward_survive=survive_reward, + torso_velocity=torso_velocity, + torso_xy=self.sim.data.qpos.flat[:2].copy(), + ) + + def sample_tasks(self, num_tasks): + if self.forward_backward: + assert num_tasks == 2 + if self.direction_in_degrees: + directions = np.array([0., 180]) + else: + directions = np.array([0., np.pi]) + elif self.fixed_tasks: + directions = np.array(self.fixed_tasks) + else: + if self.direction_in_degrees: + directions = np.random.uniform(0., 360, size=(num_tasks,)) + else: + directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) + tasks = [{'goal': desired_dir} for desired_dir in directions] + return tasks + + def task_to_vec(self, task): + direction = task['goal'] + if self.direction_in_degrees: + normalized_direction = direction / 360 + else: + normalized_direction = direction / (2*np.pi) + return np.array([normalized_direction]) diff --git a/rlkit/envs/pearl_envs/ant_goal.py b/rlkit/envs/pearl_envs/ant_goal.py new file mode 100644 index 000000000..001958aad --- /dev/null +++ b/rlkit/envs/pearl_envs/ant_goal.py @@ -0,0 +1,45 @@ +import numpy as np + +from 
rlkit.envs.pearl_envs.ant_multitask_base import MultitaskAntEnv + + +# Copy task structure from https://github.com/jonasrothfuss/ProMP/blob/master/meta_policy_search/envs/mujoco_envs/ant_rand_goal.py +class AntGoalEnv(MultitaskAntEnv): + def __init__(self, task={}, n_tasks=2, randomize_tasks=True, **kwargs): + self.quick_init(locals()) + super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs) + + def step(self, action): + self.do_simulation(action, self.frame_skip) + xposafter = np.array(self.get_body_com("torso")) + + goal_reward = -np.sum(np.abs(xposafter[:2] - self._goal)) # make it happy, not suicidal + + ctrl_cost = .1 * np.square(action).sum() + contact_cost = 0.5 * 1e-3 * np.sum( + np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) + survive_reward = 0.0 + reward = goal_reward - ctrl_cost - contact_cost + survive_reward + state = self.state_vector() + done = False + ob = self._get_obs() + return ob, reward, done, dict( + goal_forward=goal_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + reward_survive=survive_reward, + ) + + def sample_tasks(self, num_tasks): + a = np.random.random(num_tasks) * 2 * np.pi + r = 3 * np.random.random(num_tasks) ** 0.5 + goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1) + tasks = [{'goal': goal} for goal in goals] + return tasks + + def _get_obs(self): + return np.concatenate([ + self.sim.data.qpos.flat, + self.sim.data.qvel.flat, + np.clip(self.sim.data.cfrc_ext, -1, 1).flat, + ]) diff --git a/rlkit/envs/pearl_envs/ant_multitask_base.py b/rlkit/envs/pearl_envs/ant_multitask_base.py new file mode 100644 index 000000000..5f7685c11 --- /dev/null +++ b/rlkit/envs/pearl_envs/ant_multitask_base.py @@ -0,0 +1,43 @@ +from rlkit.envs.pearl_envs.ant import AntEnv + + +class MultitaskAntEnv(AntEnv): + def __init__(self, task=None, n_tasks=2, + randomize_tasks=True, + **kwargs): + if task is None: + task = {} + self._task = task + self.tasks = self.sample_tasks(n_tasks) + self._goal = self.tasks[0]['goal'] + super(MultitaskAntEnv, self).__init__(**kwargs) + + """ + def step(self, action): + xposbefore = self.sim.data.qpos[0] + self.do_simulation(action, self.frame_skip) + xposafter = self.sim.data.qpos[0] + + forward_vel = (xposafter - xposbefore) / self.dt + forward_reward = -1.0 * abs(forward_vel - self._goal_vel) + ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) + + observation = self._get_obs() + reward = forward_reward - ctrl_cost + done = False + infos = dict(reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, task=self._task) + return (observation, reward, done, infos) + """ + + + def get_all_task_idx(self): + return range(len(self.tasks)) + + def reset_task(self, idx): + try: + self._task = self.tasks[idx] + except IndexError as e: + import ipdb; ipdb.set_trace() + self._goal = self._task['goal'] # assume parameterization of task by single vector + self.reset() diff --git a/rlkit/envs/pearl_envs/ant_normal.py b/rlkit/envs/pearl_envs/ant_normal.py new file mode 100644 index 000000000..db76d411d --- /dev/null +++ b/rlkit/envs/pearl_envs/ant_normal.py @@ -0,0 +1,24 @@ +from gym.envs.mujoco import AntEnv + + +class AntNormal(AntEnv): + def __init__( + self, + *args, + n_tasks=2, # number of distinct tasks in this domain, shoudl equal sum of train and eval tasks + randomize_tasks=True, # shuffle the tasks after creating them + **kwargs + ): + self.tasks = [0 for _ in range(n_tasks)] + self._goal = 0 + super().__init__(*args, **kwargs) + + def get_all_task_idx(self): + return self.tasks + + def reset_task(self, idx): + # not 
tasks. just give the same reward every time step. + pass + + def sample_tasks(self, num_tasks): + return [0 for _ in range(num_tasks)] diff --git a/rlkit/envs/pearl_envs/assets/ant.xml b/rlkit/envs/pearl_envs/assets/ant.xml new file mode 100644 index 000000000..8ae1bc865 --- /dev/null +++ b/rlkit/envs/pearl_envs/assets/ant.xml @@ -0,0 +1,86 @@ + + + + diff --git a/rlkit/envs/pearl_envs/assets/low_gear_ratio_ant.xml b/rlkit/envs/pearl_envs/assets/low_gear_ratio_ant.xml new file mode 100644 index 000000000..c2a2711f4 --- /dev/null +++ b/rlkit/envs/pearl_envs/assets/low_gear_ratio_ant.xml @@ -0,0 +1,84 @@ + + + + diff --git a/rlkit/envs/pearl_envs/half_cheetah.py b/rlkit/envs/pearl_envs/half_cheetah.py new file mode 100644 index 000000000..33c5baecd --- /dev/null +++ b/rlkit/envs/pearl_envs/half_cheetah.py @@ -0,0 +1,26 @@ +import numpy as np +from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_ + +class HalfCheetahEnv(HalfCheetahEnv_): + def _get_obs(self): + return np.concatenate([ + self.sim.data.qpos.flat[1:], + self.sim.data.qvel.flat, + self.get_body_com("torso").flat, + ]).astype(np.float32).flatten() + + def viewer_setup(self): + camera_id = self.model.camera_name2id('track') + self.viewer.cam.type = 2 + self.viewer.cam.fixedcamid = camera_id + self.viewer.cam.distance = self.model.stat.extent * 0.35 + # Hide the overlay + self.viewer._hide_overlay = True + + def render(self, mode='human', width=500, height=500, **kwargs): + if mode == 'rgb_array': + self._get_viewer(mode).render(width=width, height=height) + data = self._get_viewer(mode).read_pixels(width, height, depth=False)[::-1, :, :] + return data + elif mode == 'human': + self._get_viewer(mode).render() diff --git a/rlkit/envs/pearl_envs/half_cheetah_dir.py b/rlkit/envs/pearl_envs/half_cheetah_dir.py new file mode 100644 index 000000000..9ed21f4aa --- /dev/null +++ b/rlkit/envs/pearl_envs/half_cheetah_dir.py @@ -0,0 +1,60 @@ +import numpy as np + +from .half_cheetah import HalfCheetahEnv + + +class HalfCheetahDirEnv(HalfCheetahEnv): + """Half-cheetah environment with target direction, as described in [1]. The + code is adapted from + https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py + + The half-cheetah follows the dynamics from MuJoCo [2], and receives at each + time step a reward composed of a control cost and a reward equal to its + velocity in the target direction. The tasks are generated by sampling the + target directions from a Bernoulli distribution on {-1, 1} with parameter + 0.5 (-1: backward, +1: forward). 
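+ In this implementation __init__ always creates the two tasks {-1, +1}; sample_tasks draws new directions from the same Bernoulli distribution.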
+ + [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic + Meta-Learning for Fast Adaptation of Deep Networks", 2017 + (https://arxiv.org/abs/1703.03400) + [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for + model-based control", 2012 + (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) + """ + def __init__(self, task={}, n_tasks=2, randomize_tasks=False): + directions = [-1, 1] + self.tasks = [{'direction': direction} for direction in directions] + self._task = task + self._goal_dir = task.get('direction', 1) + self._goal = self._goal_dir + super(HalfCheetahDirEnv, self).__init__() + + def step(self, action): + xposbefore = self.sim.data.qpos[0] + self.do_simulation(action, self.frame_skip) + xposafter = self.sim.data.qpos[0] + + forward_vel = (xposafter - xposbefore) / self.dt + forward_reward = self._goal_dir * forward_vel + ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) + + observation = self._get_obs() + reward = forward_reward - ctrl_cost + done = False + infos = dict(reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, task=self._task) + return (observation, reward, done, infos) + + def sample_tasks(self, num_tasks): + directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1 + tasks = [{'direction': direction} for direction in directions] + return tasks + + def get_all_task_idx(self): + return list(range(len(self.tasks))) + + def reset_task(self, idx): + self._task = self.tasks[idx] + self._goal_dir = self._task['direction'] + self._goal = self._goal_dir + self.reset() diff --git a/rlkit/envs/pearl_envs/half_cheetah_vel.py b/rlkit/envs/pearl_envs/half_cheetah_vel.py new file mode 100644 index 000000000..1263f6043 --- /dev/null +++ b/rlkit/envs/pearl_envs/half_cheetah_vel.py @@ -0,0 +1,65 @@ +import numpy as np + +from .half_cheetah import HalfCheetahEnv + + +class HalfCheetahVelEnv(HalfCheetahEnv): + """Half-cheetah environment with target velocity, as described in [1]. The + code is adapted from + https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py + + The half-cheetah follows the dynamics from MuJoCo [2], and receives at each + time step a reward composed of a control cost and a penalty equal to the + difference between its current velocity and the target velocity. The tasks + are generated by sampling the target velocities from the uniform + distribution on [0, 2]. 
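+ (Note: sample_tasks in this implementation draws target velocities uniformly from [0, 3] with a fixed seed.)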
+ + [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic + Meta-Learning for Fast Adaptation of Deep Networks", 2017 + (https://arxiv.org/abs/1703.03400) + [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for + model-based control", 2012 + (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) + """ + def __init__(self, task={}, presampled_tasks=None, n_tasks=2, randomize_tasks=True): + self._task = task + self.tasks = presampled_tasks or self.sample_tasks(n_tasks) + self._goal_vel = self.tasks[0].get('velocity', 0.0) + self._goal = self._goal_vel + super(HalfCheetahVelEnv, self).__init__() + + def step(self, action): + xposbefore = self.sim.data.qpos[0] + self.do_simulation(action, self.frame_skip) + xposafter = self.sim.data.qpos[0] + + forward_vel = (xposafter - xposbefore) / self.dt + forward_reward = -1.0 * abs(forward_vel - self._goal_vel) + ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) + + observation = self._get_obs() + reward = forward_reward - ctrl_cost + done = False + infos = dict( + reward_forward=forward_reward, + reward_ctrl=-ctrl_cost, + goal_vel=self._goal_vel, + forward_vel=forward_vel, + xposbefore=xposbefore, + ) + return (observation, reward, done, infos) + + def sample_tasks(self, num_tasks): + np.random.seed(1337) + velocities = np.random.uniform(0.0, 3.0, size=(num_tasks,)) + tasks = [{'velocity': velocity} for velocity in velocities] + return tasks + + def get_all_task_idx(self): + return range(len(self.tasks)) + + def reset_task(self, idx): + self._task = self.tasks[idx] + self._goal_vel = self._task['velocity'] + self._goal = self._goal_vel + self.reset() diff --git a/rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py b/rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py new file mode 100644 index 000000000..67b2171f2 --- /dev/null +++ b/rlkit/envs/pearl_envs/hopper_rand_params_wrapper.py @@ -0,0 +1,17 @@ +from rlkit.envs.pearl_envs.rand_param_envs.hopper_rand_params import HopperRandParamsEnv + + +class HopperRandParamsWrappedEnv(HopperRandParamsEnv): + def __init__(self, n_tasks=2, randomize_tasks=True): + super(HopperRandParamsWrappedEnv, self).__init__() + self.tasks = self.sample_tasks(n_tasks) + self.reset_task(0) + + def get_all_task_idx(self): + return range(len(self.tasks)) + + def reset_task(self, idx): + self._task = self.tasks[idx] + self._goal = idx + self.set_task(self._task) + self.reset() diff --git a/rlkit/envs/pearl_envs/humanoid_dir.py b/rlkit/envs/pearl_envs/humanoid_dir.py new file mode 100644 index 000000000..4c4c03d1d --- /dev/null +++ b/rlkit/envs/pearl_envs/humanoid_dir.py @@ -0,0 +1,59 @@ +import numpy as np +from gym.envs.mujoco import HumanoidEnv as HumanoidEnv + + +def mass_center(model, sim): + mass = np.expand_dims(model.body_mass, 1) + xpos = sim.data.xipos + return (np.sum(mass * xpos, 0) / np.sum(mass)) + + +class HumanoidDirEnv(HumanoidEnv): + + def __init__(self, task={}, n_tasks=2, randomize_tasks=True): + self.tasks = self.sample_tasks(n_tasks) + self.reset_task(0) + super(HumanoidDirEnv, self).__init__() + + def step(self, action): + pos_before = np.copy(mass_center(self.model, self.sim)[:2]) + self.do_simulation(action, self.frame_skip) + pos_after = mass_center(self.model, self.sim)[:2] + + alive_bonus = 5.0 + data = self.sim.data + goal_direction = (np.cos(self._goal), np.sin(self._goal)) + lin_vel_cost = 0.25 * np.sum(goal_direction * (pos_after - pos_before)) / self.model.opt.timestep + quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum() + quad_impact_cost = .5e-6 * 
np.square(data.cfrc_ext).sum() + quad_impact_cost = min(quad_impact_cost, 10) + reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus + qpos = self.sim.data.qpos + done = bool((qpos[2] < 1.0) or (qpos[2] > 2.0)) + + return self._get_obs(), reward, done, dict(reward_linvel=lin_vel_cost, + reward_quadctrl=-quad_ctrl_cost, + reward_alive=alive_bonus, + reward_impact=-quad_impact_cost) + + def _get_obs(self): + data = self.sim.data + return np.concatenate([data.qpos.flat[2:], + data.qvel.flat, + data.cinert.flat, + data.cvel.flat, + data.qfrc_actuator.flat, + data.cfrc_ext.flat]) + + def get_all_task_idx(self): + return range(len(self.tasks)) + + def reset_task(self, idx): + self._task = self.tasks[idx] + self._goal = self._task['goal'] # assume parameterization of task by single vector + + def sample_tasks(self, num_tasks): + # velocities = np.random.uniform(0., 1.0 * np.pi, size=(num_tasks,)) + directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,)) + tasks = [{'goal': d} for d in directions] + return tasks diff --git a/rlkit/envs/pearl_envs/mujoco_env.py b/rlkit/envs/pearl_envs/mujoco_env.py new file mode 100644 index 000000000..8ac6a9477 --- /dev/null +++ b/rlkit/envs/pearl_envs/mujoco_env.py @@ -0,0 +1,62 @@ +import os +from os import path + +import mujoco_py +import numpy as np +from gym.envs.mujoco import mujoco_env + +from rlkit.core.serializable import Serializable + +ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') + + +class MujocoEnv(mujoco_env.MujocoEnv, Serializable): + """ + My own wrapper around MujocoEnv. + + The caller needs to declare + """ + def __init__( + self, + model_path, + frame_skip=1, + model_path_is_local=True, + automatically_set_obs_and_action_space=False, + ): + if model_path_is_local: + model_path = get_asset_xml(model_path) + if automatically_set_obs_and_action_space: + mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) + else: + """ + Code below is copy/pasted from MujocoEnv's __init__ function. 
+ """ + if model_path.startswith("/"): + fullpath = model_path + else: + fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) + if not path.exists(fullpath): + raise IOError("File %s does not exist" % fullpath) + self.frame_skip = frame_skip + self.model = mujoco_py.MjModel(fullpath) + self.data = self.model.data + self.viewer = None + + self.metadata = { + 'render.modes': ['human', 'rgb_array'], + 'video.frames_per_second': int(np.round(1.0 / self.dt)) + } + + self.init_qpos = self.model.data.qpos.ravel().copy() + self.init_qvel = self.model.data.qvel.ravel().copy() + self._seed() + + def init_serialization(self, locals): + Serializable.quick_init(self, locals) + + def log_diagnostics(self, *args, **kwargs): + pass + + +def get_asset_xml(xml_name): + return os.path.join(ENV_ASSET_DIR, xml_name) diff --git a/rlkit/envs/pearl_envs/point_robot.py b/rlkit/envs/pearl_envs/point_robot.py new file mode 100644 index 000000000..c39411cd5 --- /dev/null +++ b/rlkit/envs/pearl_envs/point_robot.py @@ -0,0 +1,170 @@ +import numpy as np +from gym import spaces +from gym import Env + + +class PointEnv(Env): + """ + point robot on a 2-D plane with position control + tasks (aka goals) are positions on the plane + + - tasks sampled from unit square + - reward is L2 distance + """ + GOAL_SIZE = 0.1 # fraction of image + GOAL_COLOR = np.array([0, 255, 0], dtype=np.uint8) + AGENT_SIZE = 0.05 + AGENT_COLOR = np.array([0, 0, 255], dtype=np.uint8) + + def __init__(self, randomize_tasks=False, n_tasks=2): + if randomize_tasks: + goals = [[np.random.uniform(-1., 1.), np.random.uniform(-1., 1.)] for _ in range(n_tasks)] + else: + # some hand-coded goals for debugging + goals = [np.array([10, -10]), + np.array([10, 10]), + np.array([-10, 10]), + np.array([-10, -10]), + np.array([0, 0]), + + np.array([7, 2]), + np.array([0, 4]), + np.array([-6, 9]) + ] + goals = [g / 10. 
for g in goals] + self.goals = goals + + self.reset_task(0) + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,)) + self.action_space = spaces.Box(low=-0.1, high=0.1, shape=(2,)) + + @property + def tasks(self): + return self.goals + + @tasks.setter + def tasks(self, value): + self.goals =value + + def reset_task(self, idx): + ''' reset goal AND reset the agent ''' + self._goal = self.goals[idx] + self.reset() + + def get_all_task_idx(self): + return range(len(self.goals)) + + def reset_model(self): + # reset to a random location on the unit square + self._state = np.random.uniform(-1., 1., size=(2,)) + return self._get_obs() + + def reset(self): + return self.reset_model() + + def _get_obs(self): + return np.copy(self._state) + + def step(self, action): + self._state = self._state + action + x, y = self._state + x -= self._goal[0] + y -= self._goal[1] + reward = - (x ** 2 + y ** 2) ** 0.5 + done = False + ob = self._get_obs() + return ob, reward, done, dict() + + def viewer_setup(self): + print('no viewer') + pass + + def render(self): + print('current state:', self._state) + + def get_image(self, width, height): + white_img = np.zeros((height, width, 3), dtype=np.uint8) + img_with_goal = draw( + self._goal, + width, + height, + white_img, + self.GOAL_SIZE, + self.GOAL_COLOR + ) + final_img = draw( + self._state, + width, + height, + img_with_goal, + self.AGENT_SIZE, + self.AGENT_COLOR + ) + return final_img + + +def draw(xy, width, height, img, size, color): + x, y = xy + x_pixel = map_to_int(x, [-1, 1], [0, width]) + y_pixel = map_to_int(y, [-1, 1], [0, height]) + + x_min = int(max(x_pixel-size * width, 0)) + x_max = int(min(x_pixel+size * width, width)) + + y_min = int(max(y_pixel - size * height, 0)) + y_max = int(min(y_pixel + size * height, height)) + + img[y_min:y_max, x_min:x_max, :] = color + return img + + +def map_to_int(x, in_range, out_range): + min_x, max_x = in_range + min_y, max_y = out_range + normalized_x = (x - min_x) / (max_x - min_x) + return (max_y - min_y) * normalized_x + min_y + + +class SparsePointEnv(PointEnv): + ''' + - tasks sampled from unit half-circle + - reward is L2 distance given only within goal radius + + NOTE that `step()` returns the dense reward because this is used during meta-training + the algorithm should call `sparsify_rewards()` to get the sparse rewards + ''' + def __init__(self, randomize_tasks=False, n_tasks=2, goal_radius=0.2): + super().__init__(randomize_tasks, n_tasks) + self.goal_radius = goal_radius + + if randomize_tasks: + np.random.seed(1337) + radius = 1.0 + angles = np.linspace(0, np.pi, num=n_tasks) + xs = radius * np.cos(angles) + ys = radius * np.sin(angles) + goals = np.stack([xs, ys], axis=1) + np.random.shuffle(goals) + goals = goals.tolist() + + self.goals = goals + self.reset_task(0) + + def sparsify_rewards(self, r): + ''' zero out rewards when outside the goal radius ''' + mask = (r >= -self.goal_radius).astype(np.float32) + r = r * mask + return r + + def reset_model(self): + self._state = np.array([0, 0]) + return self._get_obs() + + def step(self, action): + ob, reward, done, d = super().step(action) + sparse_reward = self.sparsify_rewards(reward) + # make sparse rewards positive + if reward >= -self.goal_radius: + sparse_reward += 1 + d.update({'sparse_reward': sparse_reward}) + return ob, reward, done, d diff --git a/rlkit/envs/pearl_envs/rand_param_envs/__init__.py b/rlkit/envs/pearl_envs/rand_param_envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/rlkit/envs/pearl_envs/rand_param_envs/base.py b/rlkit/envs/pearl_envs/rand_param_envs/base.py new file mode 100644 index 000000000..ff1cfd6e8 --- /dev/null +++ b/rlkit/envs/pearl_envs/rand_param_envs/base.py @@ -0,0 +1,139 @@ +# from rlkit.envs.pearl_envs.rand_param_envs.gym.core import Env +# from rlkit.envs.pearl_envs.rand_param_envs.gym.envs.mujoco import MujocoEnv +import numpy as np +from gym import Env +from gym.envs.mujoco import MujocoEnv +# from rlkit.envs.mujoco.mujoco_env import MujocoEnv + + +class MetaEnv(Env): + def step(self, *args, **kwargs): + return self._step(*args, **kwargs) + + def sample_tasks(self, n_tasks): + """ + Samples task of the meta-environment + + Args: + n_tasks (int) : number of different meta-tasks needed + + Returns: + tasks (list) : an (n_tasks) length list of tasks + """ + raise NotImplementedError + + def set_task(self, task): + """ + Sets the specified task to the current environment + + Args: + task: task of the meta-learning environment + """ + raise NotImplementedError + + def get_task(self): + """ + Gets the task that the agent is performing in the current environment + + Returns: + task: task of the meta-learning environment + """ + raise NotImplementedError + + def log_diagnostics(self, paths, prefix): + """ + Logs env-specific diagnostic information + + Args: + paths (list) : list of all paths collected with this env during this iteration + prefix (str) : prefix for logger + """ + pass + +class RandomEnv(MetaEnv, MujocoEnv): + """ + This class provides functionality for randomizing the physical parameters of a mujoco model + The following parameters are changed: + - body_mass + - body_inertia + - damping coeff at the joints + """ + RAND_PARAMS = ['body_mass', 'dof_damping', 'body_inertia', 'geom_friction'] + RAND_PARAMS_EXTENDED = RAND_PARAMS + ['geom_size'] + + def __init__(self, log_scale_limit, file_name, *args, rand_params=RAND_PARAMS, **kwargs): + MujocoEnv.__init__(self, file_name, 4) + assert set(rand_params) <= set(self.RAND_PARAMS_EXTENDED), \ + "rand_params must be a subset of " + str(self.RAND_PARAMS_EXTENDED) + self.log_scale_limit = log_scale_limit + self.rand_params = rand_params + self.save_parameters() + + def sample_tasks(self, n_tasks): + """ + Generates randomized parameter sets for the mujoco env + + Args: + n_tasks (int) : number of different meta-tasks needed + + Returns: + tasks (list) : an (n_tasks) length list of tasks + """ + param_sets = [] + + for _ in range(n_tasks): + # body mass -> one multiplier for all body parts + + new_params = {} + + if 'body_mass' in self.rand_params: + body_mass_multiplyers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.body_mass.shape) + new_params['body_mass'] = self.init_params['body_mass'] * body_mass_multiplyers + + # body_inertia + if 'body_inertia' in self.rand_params: + body_inertia_multiplyers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.body_inertia.shape) + new_params['body_inertia'] = body_inertia_multiplyers * self.init_params['body_inertia'] + + # damping -> different multiplier for different dofs/joints + if 'dof_damping' in self.rand_params: + dof_damping_multipliers = np.array(1.3) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.dof_damping.shape) + new_params['dof_damping'] = np.multiply(self.init_params['dof_damping'], dof_damping_multipliers) + + # friction at the body components + if 'geom_friction' in self.rand_params: + 
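+ # friction multipliers for each geom, sampled log-uniformly (base 1.5) like the parameters above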
dof_damping_multipliers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.geom_friction.shape) + new_params['geom_friction'] = np.multiply(self.init_params['geom_friction'], dof_damping_multipliers) + + param_sets.append(new_params) + + return param_sets + + def set_task(self, task): + for param, param_val in task.items(): + param_variable = getattr(self.model, param) + assert param_variable.shape == param_val.shape, 'shapes of new parameter value and old one must match' + param_variable[:] = param_val + # setattr(self.model, param, param_val) + self.cur_params = task + + def get_task(self): + return self.cur_params + + def save_parameters(self): + self.init_params = {} + if 'body_mass' in self.rand_params: + self.init_params['body_mass'] = self.model.body_mass + + # body_inertia + if 'body_inertia' in self.rand_params: + self.init_params['body_inertia'] = self.model.body_inertia + + # damping -> different multiplier for different dofs/joints + if 'dof_damping' in self.rand_params: + self.init_params['dof_damping'] = self.model.dof_damping + + # friction at the body components + if 'geom_friction' in self.rand_params: + self.init_params['geom_friction'] = self.model.geom_friction + self.cur_params = self.init_params \ No newline at end of file diff --git a/rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py b/rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py new file mode 100644 index 000000000..1324c8d5e --- /dev/null +++ b/rlkit/envs/pearl_envs/rand_param_envs/hopper_rand_params.py @@ -0,0 +1,54 @@ +import numpy as np +from gym import utils +from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv + + +class HopperRandParamsEnv(RandomEnv, utils.EzPickle): + def __init__(self, log_scale_limit=3.0): + RandomEnv.__init__(self, log_scale_limit, 'hopper.xml', 4) + utils.EzPickle.__init__(self) + + def _step(self, a): + posbefore = self.sim.data.qpos[0] + self.do_simulation(a, self.frame_skip) + posafter, height, ang = self.sim.data.qpos[0:3] + alive_bonus = 1.0 + reward = (posafter - posbefore) / self.dt + reward += alive_bonus + reward -= 1e-3 * np.square(a).sum() + s = self.state_vector() + done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and + (height > .7) and (abs(ang) < .2)) + ob = self._get_obs() + return ob, reward, done, {} + + def _get_obs(self): + return np.concatenate([ + self.sim.data.qpos.flat[1:], + np.clip(self.sim.data.qvel.flat, -10, 10) + ]) + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) + qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.trackbodyid = 2 + self.viewer.cam.distance = self.model.stat.extent * 0.75 + self.viewer.cam.lookat[2] += .8 + self.viewer.cam.elevation = -20 + +if __name__ == "__main__": + + env = HopperRandParamsEnv() + tasks = env.sample_tasks(40) + while True: + env.reset() + env.set_task(np.random.choice(tasks)) + print(env.model.body_mass) + for _ in range(100): + env.render() + env.step(env.action_space.sample()) # take a random action + diff --git a/rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py b/rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py new file mode 100644 index 000000000..b43b349be --- /dev/null +++ b/rlkit/envs/pearl_envs/rand_param_envs/pr2_env_reach.py @@ -0,0 +1,82 @@ +import numpy as np +from gym import utils +from 
rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv +import os + +class PR2Env(RandomEnv, utils.EzPickle): + + FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets/pr2.xml') + + def __init__(self, log_scale_limit=1.): + self.viewer = None + RandomEnv.__init__(self, log_scale_limit, 'pr2.xml', 4) + utils.EzPickle.__init__(self) + + def _get_obs(self): + return np.concatenate([ + self.model.data.qpos.flat[:7], + self.model.data.qvel.flat[:7], # Do not include the velocity of the target (should be 0). + self.get_tip_position().flat, + self.get_vec_tip_to_goal().flat, + ]) + + def get_tip_position(self): + return self.model.data.site_xpos[0] + + def get_vec_tip_to_goal(self): + tip_position = self.get_tip_position() + goal_position = self.goal + vec_tip_to_goal = goal_position - tip_position + return vec_tip_to_goal + + @property + def goal(self): + return self.model.data.qpos.flat[-3:] + + def _step(self, action): + + self.do_simulation(action, self.frame_skip) + + vec_tip_to_goal = self.get_vec_tip_to_goal() + distance_tip_to_goal = np.linalg.norm(vec_tip_to_goal) + + reward = - distance_tip_to_goal + + state = self.state_vector() + notdone = np.isfinite(state).all() + done = not notdone + + ob = self._get_obs() + + return ob, reward, done, {} + + def reset_model(self): + qpos = self.init_qpos + qvel = self.init_qvel + goal = np.random.uniform((0.2, -0.4, 0.5), (0.5, 0.4, 1.5)) + qpos[-3:] = goal + qpos[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) + qvel[:7] += self.np_random.uniform(low=-.005, high=.005, size=7) + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.distance = self.model.stat.extent * 2 + # self.viewer.cam.lookat[2] += .8 + self.viewer.cam.elevation = -50 + # self.viewer.cam.lookat[0] = self.model.stat.center[0] + # self.viewer.cam.lookat[1] = self.model.stat.center[1] + # self.viewer.cam.lookat[2] = self.model.stat.center[2] + + +if __name__ == "__main__": + + env = PR2Env() + tasks = env.sample_tasks(40) + while True: + env.reset() + env.set_task(np.random.choice(tasks)) + print(env.model.body_mass) + for _ in range(100): + env.render() + env.step(env.action_space.sample()) diff --git a/rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py b/rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py new file mode 100644 index 000000000..364f940fc --- /dev/null +++ b/rlkit/envs/pearl_envs/rand_param_envs/walker2d_rand_params.py @@ -0,0 +1,58 @@ +import numpy as np +from gym import utils +from rlkit.envs.pearl_envs.rand_param_envs.base import RandomEnv + + +class Walker2DRandParamsEnv(RandomEnv, utils.EzPickle): + def __init__(self, log_scale_limit=3.0): + RandomEnv.__init__(self, log_scale_limit, 'walker2d.xml', 5) + utils.EzPickle.__init__(self) + + def _step(self, a): + # import ipdb; ipdb.set_trace() + # posbefore = self.model.data.qpos[0, 0] + posbefore = self.sim.data.qpos[0] + self.do_simulation(a, self.frame_skip) + # posafter, height, ang = self.model.data.qpos[0:3, 0] + posafter, height, ang = self.sim.data.qpos[0:3] + alive_bonus = 1.0 + reward = ((posafter - posbefore) / self.dt) + reward += alive_bonus + reward -= 1e-3 * np.square(a).sum() + done = not (height > 0.8 and height < 2.0 and + ang > -1.0 and ang < 1.0) + ob = self._get_obs() + return ob, reward, done, {} + + def _get_obs(self): + # qpos = self.model.data.qpos + # qvel = self.model.data.qvel + qpos = self.sim.data.qpos + qvel = self.sim.data.qvel + return np.concatenate([qpos[1:], np.clip(qvel, -10, 
10)]).ravel() + + def reset_model(self): + self.set_state( + self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq), + self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + ) + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.trackbodyid = 2 + self.viewer.cam.distance = self.model.stat.extent * 0.5 + self.viewer.cam.lookat[2] += .8 + self.viewer.cam.elevation = -20 + +if __name__ == "__main__": + + env = Walker2DRandParamsEnv() + tasks = env.sample_tasks(40) + while True: + env.reset() + env.set_task(np.random.choice(tasks)) + print(env.model.body_mass) + for _ in range(100): + env.render() + env.step(env.action_space.sample()) # take a random action + diff --git a/rlkit/envs/pearl_envs/walker_rand_params_wrapper.py b/rlkit/envs/pearl_envs/walker_rand_params_wrapper.py new file mode 100644 index 000000000..4d1bf6ed2 --- /dev/null +++ b/rlkit/envs/pearl_envs/walker_rand_params_wrapper.py @@ -0,0 +1,17 @@ +from rlkit.envs.pearl_envs.rand_param_envs.walker2d_rand_params import Walker2DRandParamsEnv + + +class WalkerRandParamsWrappedEnv(Walker2DRandParamsEnv): + def __init__(self, n_tasks=2, randomize_tasks=True): + super(WalkerRandParamsWrappedEnv, self).__init__() + self.tasks = self.sample_tasks(n_tasks) + self.reset_task(0) + + def get_all_task_idx(self): + return range(len(self.tasks)) + + def reset_task(self, idx): + self._task = self.tasks[idx] + self._goal = idx + self.set_task(self._task) + self.reset() diff --git a/rlkit/envs/pearl_envs/wrappers.py b/rlkit/envs/pearl_envs/wrappers.py new file mode 100644 index 000000000..667b43203 --- /dev/null +++ b/rlkit/envs/pearl_envs/wrappers.py @@ -0,0 +1,156 @@ +import numpy as np +from gym import Env +from gym.spaces import Box +import mujoco_py + +from rlkit.core.serializable import Serializable + + +class ProxyEnv(Serializable, Env): + def __init__(self, wrapped_env): + Serializable.quick_init(self, locals()) + self._wrapped_env = wrapped_env + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + return self._wrapped_env.reset(**kwargs) + + def step(self, action): + return self._wrapped_env.step(action) + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + def log_diagnostics(self, paths, *args, **kwargs): + if hasattr(self._wrapped_env, 'log_diagnostics'): + self._wrapped_env.log_diagnostics(paths, *args, **kwargs) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self.wrapped_env.terminate() + + +class NormalizedBoxEnv(ProxyEnv, Serializable): + """ + Normalize action to in [-1, 1]. + + Optionally normalize observations and scale reward. + """ + def __init__( + self, + env, + reward_scale=1., + obs_mean=None, + obs_std=None, + ): + # self._wrapped_env needs to be called first because + # Serializable.quick_init calls getattr, on this class. And the + # implementation of getattr (see below) calls self._wrapped_env. + # Without setting this first, the call to self._wrapped_env would call + # getattr again (since it's not set yet) and therefore loop forever. + self._wrapped_env = env + # Or else serialization gets delegated to the wrapped_env. Serialize + # this env separately from the wrapped_env. 
+ self._serializable_initialized = False + Serializable.quick_init(self, locals()) + ProxyEnv.__init__(self, env) + self._should_normalize = not (obs_mean is None and obs_std is None) + if self._should_normalize: + if obs_mean is None: + obs_mean = np.zeros_like(env.observation_space.low) + else: + obs_mean = np.array(obs_mean) + if obs_std is None: + obs_std = np.ones_like(env.observation_space.low) + else: + obs_std = np.array(obs_std) + self._reward_scale = reward_scale + self._obs_mean = obs_mean + self._obs_std = obs_std + ub = np.ones(self._wrapped_env.action_space.shape) + self.action_space = Box(-1 * ub, ub) + + def estimate_obs_stats(self, obs_batch, override_values=False): + if self._obs_mean is not None and not override_values: + raise Exception("Observation mean and std already set. To " + "override, set override_values to True.") + self._obs_mean = np.mean(obs_batch, axis=0) + self._obs_std = np.std(obs_batch, axis=0) + + def _apply_normalize_obs(self, obs): + return (obs - self._obs_mean) / (self._obs_std + 1e-8) + + def __getstate__(self): + d = Serializable.__getstate__(self) + # Add these explicitly in case they were modified + d["_obs_mean"] = self._obs_mean + d["_obs_std"] = self._obs_std + d["_reward_scale"] = self._reward_scale + return d + + def __setstate__(self, d): + Serializable.__setstate__(self, d) + self._obs_mean = d["_obs_mean"] + self._obs_std = d["_obs_std"] + self._reward_scale = d["_reward_scale"] + + def step(self, action): + lb = self._wrapped_env.action_space.low + ub = self._wrapped_env.action_space.high + scaled_action = lb + (action + 1.) * 0.5 * (ub - lb) + scaled_action = np.clip(scaled_action, lb, ub) + + wrapped_step = self._wrapped_env.step(scaled_action) + next_obs, reward, done, info = wrapped_step + if self._should_normalize: + next_obs = self._apply_normalize_obs(next_obs) + return next_obs, reward * self._reward_scale, done, info + + def __str__(self): + return "Normalized: %s" % self._wrapped_env + + def log_diagnostics(self, paths, **kwargs): + if hasattr(self._wrapped_env, "log_diagnostics"): + return self._wrapped_env.log_diagnostics(paths, **kwargs) + else: + return None + + def __getattr__(self, attrname): + return getattr(self._wrapped_env, attrname) + + +class CameraWrapper(object): + + def __init__(self, env, *args, **kwargs): + self._wrapped_env = env + self.initialize_camera() + + def get_image(self, width=256, height=256, camera_name=None): + # use sim.render to avoid MJViewer which doesn't seem to work without display + return self.sim.render( + width=width, + height=height, + camera_name=camera_name, + ) + + def initialize_camera(self): + # set camera parameters for viewing + sim = self.sim + viewer = mujoco_py.MjRenderContextOffscreen(sim) + camera = viewer.cam + camera.type = 1 + camera.trackbodyid = 0 + camera.elevation = -20 + sim.add_render_context(viewer) + + def __getattr__(self, attrname): + return getattr(self._wrapped_env, attrname) diff --git a/rlkit/launchers/launcher_util.py b/rlkit/launchers/launcher_util.py index f6a393c51..ce330cb2f 100644 --- a/rlkit/launchers/launcher_util.py +++ b/rlkit/launchers/launcher_util.py @@ -100,6 +100,7 @@ def run_experiment_here( base_log_dir=None, force_randomize_seed=False, log_dir=None, + unpack_variant=False, **setup_logger_kwargs ): """ @@ -163,7 +164,16 @@ def run_experiment_here( ), actual_log_dir ) - return experiment_function(variant) + if unpack_variant: + raw_variant = variant.copy() + raw_variant.pop('exp_id', None) + raw_variant.pop('seed', None) + 
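+        # These are bookkeeping keys that the launcher adds to the variant,
+        # not arguments of the experiment function, so drop them before
+        # calling experiment_function(**raw_variant).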
raw_variant.pop('exp_prefix', None) + raw_variant.pop('logger_config', None) + raw_variant.pop('instance_type', None) + return experiment_function(**raw_variant) + else: + return experiment_function(variant) def create_exp_name(exp_prefix, exp_id=0, seed=0): @@ -434,6 +444,7 @@ def run_experiment( snapshot_gap=1, base_log_dir=None, local_input_dir_to_mount_point_dict=None, # TODO(vitchyr): test this + unpack_variant=False, # local settings skip_wait=False, # ec2 settings @@ -467,6 +478,14 @@ def foo(variant): `base_log_dir/-my-experiment/-my-experiment-` By default, the base_log_dir is determined by `config.LOCAL_LOG_DIR/` + :param unpack_variant: If True, the function will be called with + ``` + foo(**variant) + ``` + rather than + ``` + foo(variant) + ``` :param method_call: a function that takes in a dictionary as argument :param mode: A string: - 'local' @@ -577,6 +596,7 @@ def foo(variant): snapshot_gap=snapshot_gap, git_infos=git_infos, script_name=main.__file__, + unpack_variant=unpack_variant, ) if mode == 'here_no_doodad': run_experiment_kwargs['base_log_dir'] = base_log_dir diff --git a/rlkit/torch/smac/agent.py b/rlkit/torch/smac/agent.py new file mode 100644 index 000000000..e5f072480 --- /dev/null +++ b/rlkit/torch/smac/agent.py @@ -0,0 +1,289 @@ +"""Code based on https://github.com/katerakelly/oyster""" +import copy + +import numpy as np +import torch +import torch.nn.functional as F +from rlkit.util.wrapper import Wrapper +from torch import nn as nn + +import rlkit.torch.pytorch_util as ptu +from rlkit.policies.base import Policy +from rlkit.torch.distributions import ( + Delta, +) +from rlkit.torch.sac.policies import MakeDeterministic + + +def _product_of_gaussians(mus, sigmas_squared): + ''' + compute mu, sigma of product of gaussians + ''' + sigmas_squared = torch.clamp(sigmas_squared, min=1e-7) + sigma_squared = 1. 
/ torch.sum(torch.reciprocal(sigmas_squared), dim=0) + mu = sigma_squared * torch.sum(mus / sigmas_squared, dim=0) + return mu, sigma_squared + + +def _mean_of_gaussians(mus, sigmas_squared): + ''' + compute mu, sigma of mean of gaussians + ''' + mu = torch.mean(mus, dim=0) + sigma_squared = torch.mean(sigmas_squared, dim=0) + return mu, sigma_squared + + +def _natural_to_canonical(n1, n2): + ''' convert from natural to canonical gaussian parameters ''' + mu = -0.5 * n1 / n2 + sigma_squared = -0.5 * 1 / n2 + return mu, sigma_squared + + +def _canonical_to_natural(mu, sigma_squared): + ''' convert from canonical to natural gaussian parameters ''' + n1 = mu / sigma_squared + n2 = -0.5 * 1 / sigma_squared + return n1, n2 + + +class SmacAgent(nn.Module): + + def __init__(self, + latent_dim, + context_encoder, + policy, + reward_predictor, + use_next_obs_in_context=False, + _debug_ignore_context=False, + _debug_do_not_sqrt=False, + _debug_use_ground_truth_context=False + ): + super().__init__() + self.latent_dim = latent_dim + + self.context_encoder = context_encoder + self.policy = policy + self.reward_predictor = reward_predictor + self.deterministic_policy = MakeDeterministic(self.policy) + self._debug_ignore_context = _debug_ignore_context + self._debug_use_ground_truth_context = _debug_use_ground_truth_context + + # self.recurrent = kwargs['recurrent'] + # self.use_ib = kwargs['use_information_bottleneck'] + # self.sparse_rewards = kwargs['sparse_rewards'] + self.use_next_obs_in_context = use_next_obs_in_context + + # initialize buffers for z dist and z + # use buffers so latent context can be saved along with model weights + self.register_buffer('z', torch.zeros(1, latent_dim)) + self.register_buffer('z_means', torch.zeros(1, latent_dim)) + self.register_buffer('z_vars', torch.zeros(1, latent_dim)) + + self.z_means = None + self.z_vars = None + self.context = None + self.z = None + + # rp = reward predictor + # TODO: add back in reward predictor code + self.z_means_rp = None + self.z_vars_rp = None + self.z_rp = None + self.context_encoder_rp = context_encoder + self._use_context_encoder_snapshot_for_reward_pred = False + + self.latent_prior = torch.distributions.Normal( + ptu.zeros(self.latent_dim), + ptu.ones(self.latent_dim) + ) + + self._debug_do_not_sqrt = _debug_do_not_sqrt + + def clear_z(self, num_tasks=1): + ''' + reset q(z|c) to the prior + sample a new z from the prior + ''' + # reset distribution over z to the prior + mu = ptu.zeros(num_tasks, self.latent_dim) + var = ptu.ones(num_tasks, self.latent_dim) + self.z_means = mu + self.z_vars = var + + @property + def use_context_encoder_snapshot_for_reward_pred(self): + return self._use_context_encoder_snapshot_for_reward_pred + + @use_context_encoder_snapshot_for_reward_pred.setter + def use_context_encoder_snapshot_for_reward_pred(self, value): + if value and not self.use_context_encoder_snapshot_for_reward_pred: + # copy context encoder on switch + self.context_encoder_rp = copy.deepcopy(self.context_encoder) + self.context_encoder_rp.to(ptu.device) + self.reward_predictor = copy.deepcopy(self.reward_predictor) + self.reward_predictor.to(ptu.device) + self._use_context_encoder_snapshot_for_reward_pred = value + + def detach_z(self): + ''' disable backprop through z ''' + self.z = self.z.detach() + if self.recurrent: + self.context_encoder.hidden = self.context_encoder.hidden.detach() + + self.z_rp = self.z_rp.detach() + if self.recurrent: + self.context_encoder_rp.hidden = self.context_encoder_rp.hidden.detach() + + def 
update_context(self, context, inputs): + ''' append single transition to the current context ''' + if self._debug_use_ground_truth_context: + return context + o, a, r, no, d, info = inputs + o = ptu.from_numpy(o[None, None, ...]) + a = ptu.from_numpy(a[None, None, ...]) + r = ptu.from_numpy(np.array([r])[None, None, ...]) + no = ptu.from_numpy(no[None, None, ...]) + + if self.use_next_obs_in_context: + data = torch.cat([o, a, r, no], dim=2) + else: + data = torch.cat([o, a, r], dim=2) + if context is None: + context = data + else: + try: + context = torch.cat([context, data], dim=1) + except Exception as e: + import ipdb; ipdb.set_trace() + return context + + def compute_kl_div(self): + ''' compute KL( q(z|c) || r(z) ) ''' + prior = torch.distributions.Normal(ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim)) + posteriors = [torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(torch.unbind(self.z_means), torch.unbind(self.z_vars))] + kl_divs = [torch.distributions.kl.kl_divergence(post, prior) for post in posteriors] + kl_div_sum = torch.sum(torch.stack(kl_divs)) + return kl_div_sum + + def batched_latent_prior(self, batch_size): + return torch.distributions.Normal( + ptu.zeros(batch_size, self.latent_dim), + ptu.ones(batch_size, self.latent_dim) + ) + + def latent_posterior(self, context, squeeze=False, for_reward_prediction=False): + ''' compute q(z|c) as a function of input context and sample new z from it''' + if isinstance(context, np.ndarray): + context = ptu.from_numpy(context) + if self._debug_use_ground_truth_context: + if squeeze: + context = context.squeeze(dim=0) + return Delta(context) + if for_reward_prediction: + context_encoder = self.context_encoder_rp + else: + context_encoder = self.context_encoder + params = context_encoder(context) + params = params.view(context.size(0), -1, context_encoder.output_size) + mu = params[..., :self.latent_dim] + sigma_squared = F.softplus(params[..., self.latent_dim:]) + z_params = [_product_of_gaussians(m, s) for m, s in zip(torch.unbind(mu), torch.unbind(sigma_squared))] + z_means = torch.stack([p[0] for p in z_params]) + z_vars = torch.stack([p[1] for p in z_params]) + if squeeze: + z_means = z_means.squeeze(dim=0) + z_vars = z_vars.squeeze(dim=0) + if self._debug_do_not_sqrt: + return torch.distributions.Normal(z_means, z_vars) + else: + return torch.distributions.Normal(z_means, torch.sqrt(z_vars)) + + def get_action(self, obs, z, deterministic=False): + ''' sample action from the policy, conditioned on the task embedding ''' + obs = ptu.from_numpy(obs[None]) + if self._debug_ignore_context: + z = ptu.from_numpy(z[None]) * 0 + else: + z = ptu.from_numpy(z[None]) + if len(obs.shape) != len(z.shape): + import ipdb; ipdb.set_trace() + in_ = torch.cat([obs, z], dim=1)[0] + if deterministic: + return self.deterministic_policy.get_action(in_) + else: + return self.policy.get_action(in_) + + def set_num_steps_total(self, n): + self.policy.set_num_steps_total(n) + + def forward( + self, obs, context, + return_task_z=False, + return_latent_posterior=False, + return_latent_posterior_and_task_z=False, + ): + ''' given context, get statistics under the current policy of a set of observations ''' + context_distrib = self.latent_posterior(context) + task_z = context_distrib.rsample() + + t, b, _ = obs.size() + obs = obs.view(t * b, -1) + task_z = [z.repeat(b, 1) for z in task_z] + task_z = torch.cat(task_z, dim=0) + + # run policy, get log probs and new actions + in_ = torch.cat([obs, task_z.detach()], dim=1) + action_distribution 
= self.policy(in_) + # policy_outputs = self.policy(in_, reparameterize=True, return_log_prob=True) + if return_latent_posterior_and_task_z: + return action_distribution, context_distrib, task_z + if return_latent_posterior: + return action_distribution, context_distrib + if return_task_z: + return action_distribution, task_z + else: + return action_distribution + + # return policy_outputs, task_z + + def infer_reward(self, obs, action, z): + obs = ptu.from_numpy(obs[None]) + action = ptu.from_numpy(action[None]) + z = ptu.from_numpy(z[None]) + reward = self.reward_predictor(obs, action, z) + return ptu.get_numpy(reward)[0] + + def log_diagnostics(self, eval_statistics): + ''' + adds logging data about encodings to eval_statistics + ''' + z_mean = np.mean(np.abs(ptu.get_numpy(self.z_means[0]))) + z_sig = np.mean(ptu.get_numpy(self.z_vars[0])) + eval_statistics['Z mean eval'] = z_mean + eval_statistics['Z variance eval'] = z_sig + + # z_mean_rp = np.mean(np.abs(ptu.get_numpy(self.z_means_rp[0]))) + # z_sig_rp = np.mean(ptu.get_numpy(self.z_vars_rp[0])) + # eval_statistics['Z rew-pred mean eval'] = z_mean_rp + # eval_statistics['Z rew-pred variance eval'] = z_sig_rp + + @property + def networks(self): + if self.context_encoder is self.context_encoder_rp: + return [self.context_encoder, self.policy] + else: + return [self.context_encoder, self.context_encoder_rp, self.policy] + + +class MakeSMACAgentDeterministic(Wrapper, Policy): + def __init__(self, stochastic_policy): + super().__init__(stochastic_policy) + self.stochastic_policy = stochastic_policy + + def get_action(self, *args): + return self.stochastic_policy.get_action(*args, deterministic=True) + + def get_actions(self, *args): + return self.stochastic_policy.get_actions(*args, deterministic=True) diff --git a/rlkit/torch/smac/base_config.py b/rlkit/torch/smac/base_config.py new file mode 100644 index 000000000..9e6aa9ac4 --- /dev/null +++ b/rlkit/torch/smac/base_config.py @@ -0,0 +1,164 @@ +DEFAULT_CONFIG = { + "qf_kwargs": { + "hidden_sizes": [300, 300, 300], + }, + "policy_kwargs": { + "hidden_sizes": [300, 300, 300], + }, + "logger_config": { + "snapshot_mode": "gap_and_last", + "snapshot_gap": 25, + }, + "context_decoder_kwargs": { + "hidden_sizes": [64, 64], + }, + "save_video": False, + "save_video_period": 25, + + "pretrain_rl": True, + + "trainer_kwargs": { + "beta": 100.0, + "alpha": 0.0, + "rl_weight": 1.0, + "use_awr_update": True, + "use_reparam_update": False, + "use_automatic_entropy_tuning": False, + "awr_weight": 1.0, + "bc_weight": 0.0, + "compute_bc": False, + "awr_use_mle_for_vf": False, + "awr_sample_actions": False, + "awr_min_q": True, + "reparam_weight": 0.0, + "backprop_q_loss_into_encoder": False, + "train_context_decoder": True, + + "soft_target_tau": 0.005, # for SAC target network update + "target_update_period": 1, + "policy_lr": 3E-4, + "qf_lr": 3E-4, + "context_lr": 3e-4, + "kl_lambda": .1, # weight on KL divergence term in encoder loss + "use_information_bottleneck": True, # False makes latent context deterministic + "use_next_obs_in_context": False, # use next obs if it is useful in distinguishing tasks + "sparse_rewards": False, # whether to sparsify rewards as determined in env + "recurrent": False, # recurrent or permutation-invariant encoder + "discount": 0.99, # RL discount factor + "reward_scale": 5.0, # scale rewards before constructing Bellman update, effectively controls weight on the entropy of the policy + }, + "tags": {}, + "latent_dim": 5, + "algo_kwargs": { + 
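+        # Passed through as **algo_kwargs to MetaRLAlgorithm by the SMAC
+        # launcher; the entries follow the PEARL-style collect / train loop.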
"use_rl_buffer_for_enc_buffer": True, + "freeze_encoder_buffer_in_unsupervised_phase": False, + "clear_encoder_buffer_before_every_update": False, + "num_iterations_with_reward_supervision": 0, + "exploration_resample_latent_period": 1, + "meta_batch": 4, + "embedding_batch_size": 256, + "num_initial_steps": 2000, + "num_steps_prior": 400, + "num_steps_posterior": 0, + "num_extra_rl_steps_posterior": 600, + "num_train_steps_per_itr": 4000, + "num_evals": 4, + "num_steps_per_eval": 600, + "num_exp_traj_eval": 2, + + "num_iterations": 501, # number of data sampling / training iterates + "num_tasks_sample": 5, # number of randomly sampled tasks to collect data for each iteration + "batch_size": 256, # number of transitions in the RL batch + "embedding_mini_batch_size": 64, # number of context transitions to backprop through (should equal the arg above except in the recurrent encoder case) + "max_path_length": 200, # max path length for this environment + "update_post_train": 1, # how often to resample the context when collecting data during training (in trajectories) + "dump_eval_paths": False, # whether to save evaluation trajectories + "save_extra_manual_epoch_list": [0, 50, 100, 200, 300, 400, 500], + "save_extra_manual_beginning_epoch_list": [0], + "save_replay_buffer": False, + "save_algorithm": True, + }, + "online_trainer_kwargs": { + "awr_weight": 1.0, + "reparam_weight": 1.0, + "use_reparam_update": True, + "use_awr_update": True, + }, + "skip_initial_data_collection_if_pretrained": True, + "pretrain_offline_algo_kwargs": { + "batch_size": 128, + "logging_period": 1000, + "meta_batch_size": 4, + "num_batches": 50000, + "task_embedding_batch_size": 64, + }, + "n_train_tasks": 100, + "n_eval_tasks": 20, + "env_params": {}, +} + +DEFAULT_PEARL_CONFIG = { + "qf_kwargs": { + "hidden_sizes": [300, 300, 300], + }, + "vf_kwargs": { + "hidden_sizes": [300, 300, 300], + }, + "policy_kwargs": { + "hidden_sizes": [300, 300, 300], + }, + "trainer_kwargs": { + "soft_target_tau": 0.005, # for SAC target network update + "target_update_period": 1, + "policy_lr": 3E-4, + "qf_lr": 3E-4, + "context_lr": 3e-4, + "kl_lambda": .1, # weight on KL divergence term in encoder loss + "use_information_bottleneck": True, # False makes latent context deterministic + "use_next_obs_in_context": False, # use next obs if it is useful in distinguishing tasks + "sparse_rewards": False, # whether to sparsify rewards as determined in env + "recurrent": False, # recurrent or permutation-invariant encoder + "discount": 0.99, # RL discount factor + "reward_scale": 5.0, # scale rewards before constructing Bellman update, effectively controls weight on the entropy of the policy + "backprop_q_loss_into_encoder": True, + }, + "algo_kwargs": { + "num_iterations": 501, # number of data sampling / training iterates + "num_tasks_sample": 5, # number of randomly sampled tasks to collect data for each iteration + "batch_size": 256, # number of transitions in the RL batch + "embedding_mini_batch_size": 64, # number of context transitions to backprop through (should equal the arg above except in the recurrent encoder case) + "max_path_length": 200, # max path length for this environment + "update_post_train": 1, # how often to resample the context when collecting data during training (in trajectories) + "dump_eval_paths": False, # whether to save evaluation trajectories + "num_iterations_with_reward_supervision": None, + "save_extra_manual_epoch_list": [0, 50, 100, 200, 300, 400, 500], + "save_extra_manual_beginning_epoch_list": [0], + 
"save_replay_buffer": False, + "save_algorithm": True, + "exploration_resample_latent_period": 1, + + "freeze_encoder_buffer_in_unsupervised_phase": False, + "clear_encoder_buffer_before_every_update": True, + "meta_batch": 4, + "embedding_batch_size": 256, + "num_initial_steps": 2000, + "num_steps_prior": 400, + "num_steps_posterior": 0, + "num_extra_rl_steps_posterior": 600, + "num_train_steps_per_itr": 4000, + "num_evals": 4, + "num_steps_per_eval": 600, + "num_exp_traj_eval": 2, + }, + "latent_dim": 5, + "logger_config": { + "snapshot_mode": "gap_and_last", + "snapshot_gap": 25, + }, + "context_decoder_kwargs": { + "hidden_sizes": [64, 64], + }, + "env_params": {}, + "n_train_tasks": 100, + "n_eval_tasks": 20, +} diff --git a/rlkit/torch/smac/diagnostics.py b/rlkit/torch/smac/diagnostics.py new file mode 100644 index 000000000..3db7e2729 --- /dev/null +++ b/rlkit/torch/smac/diagnostics.py @@ -0,0 +1,27 @@ +from rlkit.envs.pearl_envs import ( + AntDirEnv, + HalfCheetahVelEnv, +) + + +def get_env_info_sizes(env): + info_sizes = {} + if isinstance(env.wrapped_env, AntDirEnv): + info_sizes = dict( + reward_forward=1, + reward_ctrl=1, + reward_contact=1, + reward_survive=1, + torso_velocity=3, + torso_xy=2, + ) + if isinstance(env.wrapped_env, HalfCheetahVelEnv): + info_sizes = dict( + reward_forward=1, + reward_ctrl=1, + goal_vel=1, + forward_vel=1, + xposbefore=1, + ) + + return info_sizes diff --git a/rlkit/torch/smac/launcher.py b/rlkit/torch/smac/launcher.py new file mode 100644 index 000000000..3efedc1f4 --- /dev/null +++ b/rlkit/torch/smac/launcher.py @@ -0,0 +1,264 @@ +import pickle + +import rlkit.torch.pytorch_util as ptu +from rlkit.core import logger +from rlkit.core.meta_rl_algorithm import MetaRLAlgorithm +from rlkit.core.simple_offline_rl_algorithm import ( + OfflineMetaRLAlgorithm, +) +from rlkit.data_management.env_replay_buffer import EnvReplayBuffer +from rlkit.demos.source.mdp_path_loader import MDPPathLoader +from rlkit.envs.pearl_envs import ENVS, register_pearl_envs +from rlkit.envs.wrappers import NormalizedBoxEnv +from rlkit.util.io import load_local_or_remote_file +from rlkit.torch.networks import ConcatMlp +from rlkit.torch.smac.agent import SmacAgent +from rlkit.torch.smac.diagnostics import get_env_info_sizes +from rlkit.torch.smac.networks import MlpEncoder, DummyMlpEncoder, MlpDecoder +from rlkit.torch.smac.launcher_util import ( + policy_class_from_str, + load_buffer_onto_algo, + EvalPearl, + load_macaw_buffer_onto_algo, + relabel_offline_data, +) +from rlkit.torch.smac.smac import SmacTrainer + + +def smac_experiment( + trainer_kwargs=None, + algo_kwargs=None, + qf_kwargs=None, + policy_kwargs=None, + context_encoder_kwargs=None, + context_decoder_kwargs=None, + env_name=None, + env_params=None, + path_loader_kwargs=None, + latent_dim=None, + policy_class="TanhGaussianPolicy", + # video/debug + debug=False, + use_dummy_encoder=False, + networks_ignore_context=False, + use_ground_truth_context=False, + save_video=False, + save_video_period=False, + # Pre-train params + pretrain_rl=False, + pretrain_offline_algo_kwargs=None, + pretrain_buffer_kwargs=None, + load_buffer_kwargs=None, + saved_tasks_path=None, + macaw_format_base_path=None, # overrides saved_tasks_path and load_buffer_kwargs + load_macaw_buffer_kwargs=None, + train_task_idxs=None, + eval_task_idxs=None, + relabel_offline_dataset=False, + skip_initial_data_collection_if_pretrained=False, + relabel_kwargs=None, + # PEARL + n_train_tasks=0, + n_eval_tasks=0, + use_next_obs_in_context=False, + 
tags=None, + online_trainer_kwargs=None, +): + if not skip_initial_data_collection_if_pretrained: + raise NotImplementedError("deprecated! make sure to skip it!") + if relabel_kwargs is None: + relabel_kwargs = {} + del tags + pretrain_buffer_kwargs = pretrain_buffer_kwargs or {} + context_decoder_kwargs = context_decoder_kwargs or {} + pretrain_offline_algo_kwargs = pretrain_offline_algo_kwargs or {} + online_trainer_kwargs = online_trainer_kwargs or {} + register_pearl_envs() + env_params = env_params or {} + context_encoder_kwargs = context_encoder_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + path_loader_kwargs = path_loader_kwargs or {} + load_macaw_buffer_kwargs = load_macaw_buffer_kwargs or {} + + base_env = ENVS[env_name](**env_params) + if saved_tasks_path: + task_data = load_local_or_remote_file( + saved_tasks_path, file_type='joblib') + tasks = task_data['tasks'] + train_task_idxs = task_data['train_task_indices'] + eval_task_idxs = task_data['eval_task_indices'] + base_env.tasks = tasks + elif macaw_format_base_path is not None: + tasks = pickle.load( + open('{}/tasks.pkl'.format(macaw_format_base_path), 'rb')) + base_env.tasks = tasks + else: + tasks = base_env.tasks + task_indices = base_env.get_all_task_idx() + train_task_idxs = list(task_indices[:n_train_tasks]) + eval_task_idxs = list(task_indices[-n_eval_tasks:]) + if hasattr(base_env, 'task_to_vec'): + train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs] + eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs] + else: + train_tasks = [tasks[i] for i in train_task_idxs] + eval_tasks = [tasks[i] for i in eval_task_idxs] + if use_ground_truth_context: + latent_dim = len(train_tasks[0]) + expl_env = NormalizedBoxEnv(base_env) + + reward_dim = 1 + + if debug: + algo_kwargs['max_path_length'] = 50 + algo_kwargs['batch_size'] = 5 + algo_kwargs['num_epochs'] = 5 + algo_kwargs['num_eval_steps_per_epoch'] = 100 + algo_kwargs['num_expl_steps_per_train_loop'] = 100 + algo_kwargs['num_trains_per_train_loop'] = 10 + algo_kwargs['min_num_steps_before_training'] = 100 + + obs_dim = expl_env.observation_space.low.size + action_dim = expl_env.action_space.low.size + + if use_next_obs_in_context: + context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim + else: + context_encoder_input_dim = obs_dim + action_dim + reward_dim + context_encoder_output_dim = latent_dim * 2 + + def create_qf(): + return ConcatMlp( + input_size=obs_dim + action_dim + latent_dim, + output_size=1, + **qf_kwargs + ) + + qf1 = create_qf() + qf2 = create_qf() + target_qf1 = create_qf() + target_qf2 = create_qf() + + if isinstance(policy_class, str): + policy_class = policy_class_from_str(policy_class) + policy = policy_class( + obs_dim=obs_dim + latent_dim, + action_dim=action_dim, + **policy_kwargs, + ) + encoder_class = DummyMlpEncoder if use_dummy_encoder else MlpEncoder + context_encoder = encoder_class( + input_size=context_encoder_input_dim, + output_size=context_encoder_output_dim, + hidden_sizes=[200, 200, 200], + use_ground_truth_context=use_ground_truth_context, + **context_encoder_kwargs + ) + context_decoder = MlpDecoder( + input_size=obs_dim + action_dim + latent_dim, + output_size=1, + **context_decoder_kwargs + ) + reward_predictor = context_decoder + agent = SmacAgent( + latent_dim, + context_encoder, + policy, + reward_predictor, + use_next_obs_in_context=use_next_obs_in_context, + _debug_ignore_context=networks_ignore_context, + _debug_use_ground_truth_context=use_ground_truth_context, + ) + 
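+    # The agent bundles the context encoder, policy, and reward predictor
+    # (note that reward_predictor is the same module as context_decoder) and
+    # is used roughly as in the sampler:
+    #     z_dist = agent.latent_posterior(context)   # q(z | c)
+    #     z = ptu.get_numpy(z_dist.sample())
+    #     action, _ = agent.get_action(obs, z)       # pi(a | s, z)
+    # The trainer below holds the optimizers that update these modules.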
trainer = SmacTrainer( + agent=agent, + env=expl_env, + latent_dim=latent_dim, + qf1=qf1, + qf2=qf2, + target_qf1=target_qf1, + target_qf2=target_qf2, + reward_predictor=reward_predictor, + context_encoder=context_encoder, + context_decoder=context_decoder, + _debug_ignore_context=networks_ignore_context, + _debug_use_ground_truth_context=use_ground_truth_context, + **trainer_kwargs + ) + algorithm = MetaRLAlgorithm( + agent=agent, + env=expl_env, + trainer=trainer, + train_task_indices=train_task_idxs, + eval_task_indices=eval_task_idxs, + train_tasks=train_tasks, + eval_tasks=eval_tasks, + use_next_obs_in_context=use_next_obs_in_context, + use_ground_truth_context=use_ground_truth_context, + env_info_sizes=get_env_info_sizes(expl_env), + **algo_kwargs + ) + + if macaw_format_base_path: + load_macaw_buffer_onto_algo( + algo=algorithm, + base_directory=macaw_format_base_path, + train_task_idxs=train_task_idxs, + **load_macaw_buffer_kwargs + ) + elif load_buffer_kwargs: + load_buffer_onto_algo(algorithm, **load_buffer_kwargs) + if relabel_offline_dataset: + relabel_offline_data( + algorithm, + tasks=tasks, + env=expl_env.wrapped_env, + **relabel_kwargs + ) + if path_loader_kwargs: + replay_buffer = algorithm.replay_buffer.task_buffers[0] + enc_replay_buffer = algorithm.enc_replay_buffer.task_buffers[0] + demo_test_buffer = EnvReplayBuffer( + env=expl_env, **pretrain_buffer_kwargs) + path_loader = MDPPathLoader( + trainer, + replay_buffer=replay_buffer, + demo_train_buffer=enc_replay_buffer, + demo_test_buffer=demo_test_buffer, + **path_loader_kwargs + ) + path_loader.load_demos() + + if pretrain_rl: + eval_pearl_fn = EvalPearl(algorithm, train_task_idxs, eval_task_idxs) + pretrain_algo = OfflineMetaRLAlgorithm( + meta_replay_buffer=algorithm.meta_replay_buffer, + replay_buffer=algorithm.replay_buffer, + task_embedding_replay_buffer=algorithm.enc_replay_buffer, + trainer=trainer, + train_tasks=train_task_idxs, + extra_eval_fns=[eval_pearl_fn], + use_meta_learning_buffer=algorithm.use_meta_learning_buffer, + **pretrain_offline_algo_kwargs + ) + pretrain_algo.to(ptu.device) + logger.remove_tabular_output( + 'progress.csv', relative_to_snapshot_dir=True + ) + logger.add_tabular_output( + 'pretrain.csv', relative_to_snapshot_dir=True + ) + pretrain_algo.train() + logger.remove_tabular_output( + 'pretrain.csv', relative_to_snapshot_dir=True + ) + logger.add_tabular_output( + 'progress.csv', relative_to_snapshot_dir=True, + ) + if skip_initial_data_collection_if_pretrained: + algorithm.num_initial_steps = 0 + + algorithm.trainer.configure(**online_trainer_kwargs) + algorithm.to(ptu.device) + algorithm.train() + + diff --git a/rlkit/torch/smac/launcher_util.py b/rlkit/torch/smac/launcher_util.py new file mode 100644 index 000000000..34f4c5a7b --- /dev/null +++ b/rlkit/torch/smac/launcher_util.py @@ -0,0 +1,390 @@ +import glob +import re +from collections import OrderedDict +from pathlib import Path +from typing import List, Any + +import numpy as np + +import rlkit.torch.pytorch_util as ptu +from rlkit.core import eval_util +from rlkit.core.logging import append_log +from rlkit.core.meta_rl_algorithm import MetaRLAlgorithm +from rlkit.envs.pearl_envs import AntDirEnv, HalfCheetahVelEnv +from rlkit.util.io import load_local_or_remote_file +from rlkit.torch.sac.policies import GaussianPolicy, TanhGaussianPolicy + +ENV_PARAMS = { + 'HalfCheetah-v2': { + 'num_expl_steps_per_train_loop': 1000, + 'max_path_length': 1000, + 'env_demo_path': dict( + path="demos/icml2020/mujoco/hc_action_noise_15.npy", + 
obs_dict=False, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + path="demos/icml2020/mujoco/hc_off_policy_15_demos_100.npy", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, + 'Ant-v2': { + 'num_expl_steps_per_train_loop': 1000, + 'max_path_length': 1000, + 'env_demo_path': dict( + path="demos/icml2020/mujoco/ant_action_noise_15.npy", + obs_dict=False, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + path="demos/icml2020/mujoco/ant_off_policy_15_demos_100.npy", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, + 'Walker2d-v2': { + 'num_expl_steps_per_train_loop': 1000, + 'max_path_length': 1000, + 'env_demo_path': dict( + path="demos/icml2020/mujoco/walker_action_noise_15.npy", + obs_dict=False, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + path="demos/icml2020/mujoco/walker_off_policy_15_demos_100.npy", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, + + 'SawyerRigGrasp-v0': { + 'env_id': 'SawyerRigGrasp-v0', + # 'num_expl_steps_per_train_loop': 1000, + 'max_path_length': 50, + # 'num_epochs': 1000, + }, + + 'pen-binary-v0': { + 'env_id': 'pen-binary-v0', + 'max_path_length': 200, + 'sparse_reward': True, + 'env_demo_path': dict( + path="demos/icml2020/hand/pen2_sparse.npy", + # path="demos/icml2020/hand/sparsity/railrl_pen-binary-v0_demos.npy", + obs_dict=True, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + # path="demos/icml2020/hand/pen_bc_sparse1.npy", + # path="demos/icml2020/hand/pen_bc_sparse2.npy", + # path="demos/icml2020/hand/pen_bc_sparse3.npy", + # path="demos/icml2020/hand/pen_bc_sparse4.npy", + path="demos/icml2020/hand/pen_bc_sparse4.npy", + # path="ashvin/icml2020/hand/sparsity/bc/pen-binary1/run10/id*/video_*_*.p", + # sync_dir="ashvin/icml2020/hand/sparsity/bc/pen-binary1/run10", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, + 'door-binary-v0': { + 'env_id': 'door-binary-v0', + 'max_path_length': 200, + 'sparse_reward': True, + 'env_demo_path': dict( + path="demos/icml2020/hand/door2_sparse.npy", + # path="demos/icml2020/hand/sparsity/railrl_door-binary-v0_demos.npy", + obs_dict=True, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + # path="demos/icml2020/hand/door_bc_sparse1.npy", + # path="demos/icml2020/hand/door_bc_sparse3.npy", + path="demos/icml2020/hand/door_bc_sparse4.npy", + # path="ashvin/icml2020/hand/sparsity/bc/door-binary1/run10/id*/video_*_*.p", + # sync_dir="ashvin/icml2020/hand/sparsity/bc/door-binary1/run10", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, + 'relocate-binary-v0': { + 'env_id': 'relocate-binary-v0', + 'max_path_length': 200, + 'sparse_reward': True, + 'env_demo_path': dict( + path="demos/icml2020/hand/relocate2_sparse.npy", + # path="demos/icml2020/hand/sparsity/railrl_relocate-binary-v0_demos.npy", + obs_dict=True, + is_demo=True, + ), + 'env_offpolicy_data_path': dict( + # path="demos/icml2020/hand/relocate_bc_sparse1.npy", + path="demos/icml2020/hand/relocate_bc_sparse4.npy", + # path="ashvin/icml2020/hand/sparsity/bc/relocate-binary1/run10/id*/video_*_*.p", + # sync_dir="ashvin/icml2020/hand/sparsity/bc/relocate-binary1/run10", + obs_dict=False, + is_demo=False, + train_split=0.9, + ), + }, +} + + +def policy_class_from_str(policy_class): + if policy_class == 'GaussianPolicy': + return GaussianPolicy + elif policy_class == 'TanhGaussianPolicy': + return TanhGaussianPolicy + else: + raise ValueError(policy_class) + + +def relabel_data(source_data, task, env): + new_data = source_data.copy() # shallow copy + 
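+    # Recompute rewards as if these transitions had been collected for `task`,
+    # using the matching env's reward function, e.g. for HalfCheetahVelEnv:
+    #     r = -|forward_vel - goal_vel| - 0.5 * 1e-1 * ||action||^2
+    # Observations and actions are reused as-is; only 'rewards' is replaced.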
if isinstance(env, AntDirEnv): + ctrl_cost = - source_data['reward_ctrl'] + contact_cost = - source_data['reward_contact'] + survive_reward = source_data['reward_survive'] + torso_velocity = source_data['torso_velocity'] + + goal = task['goal'] + if env.direction_in_degrees: + goal = goal / 180 * np.pi + direct = (np.cos(goal), np.sin(goal)) + new_forward_reward = np.dot((torso_velocity[..., :2]/env.dt), direct).reshape(-1, 1) + new_rewards = new_forward_reward - ctrl_cost - contact_cost + survive_reward + elif isinstance(env, HalfCheetahVelEnv): + forward_vel = source_data['forward_vel'] + goal = task['velocity'] + action = source_data['actions'] + forward_reward = -1.0 * abs(forward_vel - goal) + ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action), axis=-1, keepdims=True) + new_rewards = forward_reward - ctrl_cost + else: + raise TypeError(str(env)) + new_data['rewards'] = new_rewards + return new_data + + +def relabel_offline_data( + algo: MetaRLAlgorithm, + tasks: List[Any], + env, + num_tasks_to_relabel='all', +): + key_to_original_data = { + k: buff.to_dict() for k, buff in algo.replay_buffer.task_buffers.items() + } + for source_task_idx in algo.replay_buffer.task_buffers: + source_data = key_to_original_data[source_task_idx] + if num_tasks_to_relabel == 'all': + target_tasks_to_relabel = list(algo.replay_buffer.task_buffers.keys()) + else: + target_tasks_to_relabel = np.random.choice( + list(algo.replay_buffer.task_buffers.keys()), + num_tasks_to_relabel + ) + for target_task_idx in target_tasks_to_relabel: + if source_task_idx == target_task_idx: + continue + target_task = tasks[target_task_idx] + relabeled_data = relabel_data(source_data, target_task, env) + target_buffer = algo.replay_buffer.task_buffers[target_task_idx] + target_buffer.add_from_dict(relabeled_data) + + key_to_original_enc_data = { + k: buff.to_dict() for k, buff in algo.enc_replay_buffer.task_buffers.items() + } + for source_task_idx in algo.enc_replay_buffer.task_buffers: + source_data = key_to_original_enc_data[source_task_idx] + for target_task_idx in algo.enc_replay_buffer.task_buffers: + if source_task_idx == target_task_idx: + continue + target_task = tasks[target_task_idx] + relabeled_data = relabel_data(source_data, target_task, env) + target_buffer = algo.enc_replay_buffer.task_buffers[target_task_idx] + target_buffer.add_from_dict(relabeled_data) + + +def load_macaw_buffer_onto_algo( + algo: MetaRLAlgorithm, + base_directory: str, + train_task_idxs: List[int], + # start_idx=0, + # end_idx=None, + rl_buffer_start_end_idxs=((0, None),), + encoder_buffer_start_end_idxs=((0, None),), + encoder_buffer_matches_rl_buffer=False, + # start_idx_enc=0, + # end_idx_enc=None, +): + base_dir = Path(base_directory) + task_idx_to_path = get_task_idx_to_path(base_dir, prefix='macaw_replay_buffer') + task_idx_to_enc_path = get_task_idx_to_path(base_dir, prefix='macaw_enc_replay_buffer') + + for task_idx in train_task_idxs: + dataset_path = task_idx_to_path[task_idx] + rl_data = np.load(dataset_path, allow_pickle=True).item() + for start_idx, end_idx in rl_buffer_start_end_idxs: + algo.replay_buffer.task_buffers[task_idx].add_from_dict( + rl_data, + start_idx=start_idx, + end_idx=end_idx, + ) + if algo.use_rl_buffer_for_enc_buffer: + return + for task_idx in train_task_idxs: + if encoder_buffer_matches_rl_buffer: + encoder_buffer_start_end_idxs = rl_buffer_start_end_idxs + dataset_path = task_idx_to_path[task_idx] + enc_data = np.load(dataset_path, allow_pickle=True).item() + else: + enc_dataset_path = 
task_idx_to_enc_path[task_idx] + enc_data = np.load(enc_dataset_path, allow_pickle=True).item() + for start_idx, end_idx in encoder_buffer_start_end_idxs: + algo.enc_replay_buffer.task_buffers[task_idx].add_from_dict( + enc_data, + start_idx=start_idx, + end_idx=end_idx, + ) + + +def get_task_idx_to_path(base_dir, prefix): + task_idx_to_path = {} + for buffer_path in glob.glob(str(base_dir / '{}*'.format(prefix))): + pattern = re.compile('{}_task_(\d+).npy'.format(prefix)) + match = pattern.search(buffer_path) + task_idx = int(match.group(1)) + task_idx_to_path[task_idx] = buffer_path + return task_idx_to_path + + +def load_buffer_onto_algo( + algo: MetaRLAlgorithm, + pretrain_buffer_path: str, + start_idx=0, + end_idx=None, + start_idx_enc=0, + end_idx_enc=None, + populate_enc_buffer_with_rl_data=False, +): + data = load_local_or_remote_file( + pretrain_buffer_path, + file_type='joblib', + ) + saved_replay_buffer = data['replay_buffer'] + saved_enc_replay_buffer = data['enc_replay_buffer'] + if algo.use_meta_learning_buffer: + for k in saved_replay_buffer.task_buffers: + if k not in saved_replay_buffer.task_buffers: + print("No saved buffer for task {}. Skipping.".format(k)) + continue + saved_buffer = saved_replay_buffer.task_buffers[k] + new_buffer = algo.meta_replay_buffer.create_buffer( + size=saved_buffer.num_steps_can_sample() + ) + new_buffer.copy_data( + saved_buffer, + start_idx=start_idx, + end_idx=end_idx, + ) + algo.meta_replay_buffer.append_buffer(new_buffer) + else: + rl_replay_buffer = algo.replay_buffer + encoder_replay_buffer = algo.enc_replay_buffer + for k in rl_replay_buffer.task_buffers: + if k not in saved_replay_buffer.task_buffers: + print("No saved buffer for task {}. Skipping.".format(k)) + continue + rl_replay_buffer.task_buffers[k].copy_data( + saved_replay_buffer.task_buffers[k], + start_idx=start_idx, + end_idx=end_idx, + ) + if algo.use_rl_buffer_for_enc_buffer: + return + if populate_enc_buffer_with_rl_data: + for k in encoder_replay_buffer.task_buffers: + if k not in saved_replay_buffer.task_buffers: + print("No saved buffer for task {}. Skipping.".format(k)) + continue + encoder_replay_buffer.task_buffers[k].copy_data( + saved_replay_buffer.task_buffers[k], + start_idx=start_idx, + end_idx=end_idx, + ) + else: + for k in encoder_replay_buffer.task_buffers: + if k not in saved_enc_replay_buffer.task_buffers: + print("No saved buffer for task {}. 
Skipping.".format(k)) + continue + encoder_replay_buffer.task_buffers[k].copy_data( + saved_enc_replay_buffer.task_buffers[k], + start_idx=start_idx_enc, + end_idx=end_idx_enc, + ) + + +class EvalPearl(object): + def __init__( + self, + algorithm: MetaRLAlgorithm, + train_task_indices: List[int], + test_task_indices: List[int], + ): + self.algorithm = algorithm + self.train_task_indices = train_task_indices + self.test_task_indices = test_task_indices + + def __call__(self): + results = OrderedDict() + for name, indices in [ + ('train_tasks', self.train_task_indices), + ('test_tasks', self.test_task_indices), + ]: + final_returns, online_returns, idx_to_final_context = self.algorithm._do_eval(indices, -1) + results['eval/adaptation/{}/final_returns Mean'.format(name)] = np.mean(final_returns) + results['eval/adaptation/{}/all_returns Mean'.format(name)] = np.mean(online_returns) + + if 'train' in name: + z_dist_log = self.algorithm._get_z_distribution_log( + idx_to_final_context + ) + append_log(results, z_dist_log, prefix='trainer/{}/'.format(name)) + + paths = [] + for idx in self.train_task_indices: + paths += self._get_init_from_buffer_path(idx) + results['eval/init_from_buffer/train_tasks/all_returns Mean'] = np.mean( + eval_util.get_average_returns(paths) + ) + return results + + def _get_init_from_buffer_path(self, idx): + if self.algorithm.use_meta_learning_buffer: + init_context = self.algorithm.meta_replay_buffer._sample_contexts( + [idx], + self.algorithm.embedding_batch_size + ) + else: + init_context = self.algorithm.enc_replay_buffer.sample_context( + idx, + self.algorithm.embedding_batch_size + ) + init_context = ptu.from_numpy(init_context) + p, _ = self.algorithm.sampler.obtain_samples( + deterministic=self.algorithm.eval_deterministic, + max_samples=self.algorithm.max_path_length, + accum_context=False, + max_trajs=1, + resample_latent_period=0, + update_posterior_period=0, + initial_context=init_context, + task_idx=idx, + ) + return p diff --git a/rlkit/torch/smac/networks.py b/rlkit/torch/smac/networks.py new file mode 100644 index 000000000..8a37656ef --- /dev/null +++ b/rlkit/torch/smac/networks.py @@ -0,0 +1,83 @@ +import torch +from torch import nn +import rlkit.torch.pytorch_util as ptu +from rlkit.torch.networks import ConcatMlp + + +class MlpEncoder(ConcatMlp): + ''' + encode context via MLP + ''' + def __init__(self, *args, use_ground_truth_context=False, **kwargs): + super().__init__(*args, **kwargs) + self.use_ground_truth_context = use_ground_truth_context + + def forward(self, context): + if self.use_ground_truth_context: + return context + else: + return super().forward(context) + + def reset(self, num_tasks=1): + pass + + +class MlpDecoder(ConcatMlp): + ''' + decoder context via MLP + ''' + pass + + +class DummyMlpEncoder(MlpEncoder): + def forward(self, *args, **kwargs): + output = super().forward(*args, **kwargs) + return 0 * output + # TODO: check if this caused issues + + +class RecurrentEncoder(ConcatMlp): + ''' + encode context via recurrent network + ''' + + def __init__(self, + *args, + **kwargs + ): + self.save_init_params(locals()) + super().__init__(*args, **kwargs) + self.hidden_dim = self.hidden_sizes[-1] + self.register_buffer('hidden', torch.zeros(1, 1, self.hidden_dim)) + + # input should be (task, seq, feat) and hidden should be (task, 1, feat) + + self.lstm = nn.LSTM(self.hidden_dim, self.hidden_dim, num_layers=1, batch_first=True) + + def forward(self, in_, return_preactivations=False): + # expects inputs of dimension (task, seq, feat) + 
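+        # Schematically: (task, seq, feat) -> per-step MLP embedding
+        # -> LSTM over seq -> last hidden state -> last_fc -> (task, output_size)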
task, seq, feat = in_.size() + out = in_.view(task * seq, feat) + + # embed with MLP + for i, fc in enumerate(self.fcs): + out = fc(out) + out = self.hidden_activation(out) + + out = out.view(task, seq, -1) + out, (hn, cn) = self.lstm(out, (self.hidden, torch.zeros(self.hidden.size()).to(ptu.device))) + self.hidden = hn + # take the last hidden state to predict z + out = out[:, -1, :] + + # output layer + preactivation = self.last_fc(out) + output = self.output_activation(preactivation) + if return_preactivations: + return output, preactivation + else: + return output + + def reset(self, num_tasks=1): + self.hidden = self.hidden.new_full((1, num_tasks, self.hidden_dim), 0) + diff --git a/rlkit/torch/smac/pearl.py b/rlkit/torch/smac/pearl.py new file mode 100644 index 000000000..82b083443 --- /dev/null +++ b/rlkit/torch/smac/pearl.py @@ -0,0 +1,347 @@ +from collections import OrderedDict +import copy +import numpy as np + +import torch +import torch.optim as optim +from torch import nn as nn +from torch.distributions import kl_divergence + +import rlkit.torch.pytorch_util as ptu + +from rlkit.core.eval_util import create_stats_ordered_dict +from rlkit.torch.torch_rl_algorithm import TorchTrainer +from itertools import chain + + +class PEARLSoftActorCriticTrainer(TorchTrainer): + def __init__( + self, + latent_dim, + agent, + qf1, + qf2, + vf, + context_encoder, + reward_predictor, + context_decoder, + + reward_scale=1., + discount=0.99, + policy_lr=1e-3, + qf_lr=1e-3, + vf_lr=1e-3, + context_lr=1e-3, + kl_lambda=1., + policy_mean_reg_weight=1e-3, + policy_std_reg_weight=1e-3, + policy_pre_activation_weight=0., + optimizer_class=optim.Adam, + recurrent=False, + use_information_bottleneck=True, + use_next_obs_in_context=False, + sparse_rewards=False, + train_context_decoder=False, + backprop_q_loss_into_encoder=True, + + train_reward_pred_in_unsupervised_phase=False, + use_encoder_snapshot_for_reward_pred_in_unsupervised_phase=False, + + soft_target_tau=1e-2, + target_update_period=1, + plotter=None, + render_eval_paths=False, + ): + super().__init__() + + self.train_agent = True + self.reward_scale = reward_scale + self.discount = discount + self.soft_target_tau = soft_target_tau + assert target_update_period == 1 + self.policy_mean_reg_weight = policy_mean_reg_weight + self.policy_std_reg_weight = policy_std_reg_weight + self.policy_pre_activation_weight = policy_pre_activation_weight + self.plotter = plotter + self.render_eval_paths = render_eval_paths + + self.train_reward_pred_in_unsupervised_phase = train_reward_pred_in_unsupervised_phase + self.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase = ( + use_encoder_snapshot_for_reward_pred_in_unsupervised_phase + ) + + self.recurrent = recurrent + self.latent_dim = latent_dim + self.qf_criterion = nn.MSELoss() + self.vf_criterion = nn.MSELoss() + self.vib_criterion = nn.MSELoss() + self.l2_reg_criterion = nn.MSELoss() + self.reward_pred_criterion = nn.MSELoss() + self.kl_lambda = kl_lambda + + self.use_information_bottleneck = use_information_bottleneck + self.sparse_rewards = sparse_rewards + self.use_next_obs_in_context = use_next_obs_in_context + self.train_encoder_decoder = True + self.train_context_decoder = train_context_decoder + self.backprop_q_loss_into_encoder = backprop_q_loss_into_encoder + + self.agent = agent + self.policy = agent.policy + self.qf1, self.qf2, self.vf = qf1, qf2, vf + self.target_vf = copy.deepcopy(self.vf) + self.context_encoder = context_encoder + self.context_decoder = context_decoder + 
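+        # (In the launchers the context decoder doubles as the reward
+        # predictor: it maps (obs, action, z) to a predicted reward.)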
self.reward_predictor = reward_predictor + + self.policy_optimizer = optimizer_class( + self.policy.parameters(), + lr=policy_lr, + ) + if train_context_decoder: + self.context_optimizer = optimizer_class( + chain( + self.context_encoder.parameters(), + self.context_decoder.parameters(), + ), + lr=context_lr, + ) + else: + self.context_optimizer = optimizer_class( + self.context_encoder.parameters(), + lr=context_lr, + ) + self.qf1_optimizer = optimizer_class( + self.qf1.parameters(), + lr=qf_lr, + ) + self.qf2_optimizer = optimizer_class( + self.qf2.parameters(), + lr=qf_lr, + ) + self.vf_optimizer = optimizer_class( + self.vf.parameters(), + lr=vf_lr, + ) + self.reward_predictor_optimizer = optimizer_class( + self.reward_predictor.parameters(), + lr=context_lr, + ) + + self.eval_statistics = None + self._need_to_update_eval_statistics = True + + ###### Torch stuff ##### + @property + def networks(self): + return [ + self.policy, + self.qf1, self.qf2, self.vf, self.target_vf, + self.context_encoder, + self.context_decoder, + self.reward_predictor, + ] + + def training_mode(self, mode): + for net in self.networks: + net.train(mode) + + def to(self, device=None): + if device == None: + device = ptu.device + for net in self.networks: + net.to(device) + + + ##### Training ##### + def _min_q(self, obs, actions, task_z): + q1 = self.qf1(obs, actions, task_z.detach()) + q2 = self.qf2(obs, actions, task_z.detach()) + min_q = torch.min(q1, q2) + return min_q + + def _update_target_network(self): + ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau) + + # def train_from_torch(self, indices, context, context_dict): + def train_from_torch(self, batch): + rewards = batch['rewards'] + terminals = batch['terminals'] + obs = batch['observations'] + actions = batch['actions'] + next_obs = batch['next_observations'] + context = batch['context'] + + # data is (task, batch, feat) + # obs, actions, rewards, next_obs, terms = self.sample_sac(indices) + + # run inference in networks + action_distrib, p_z, task_z_with_grad = self.agent( + obs, context, return_latent_posterior_and_task_z=True, + ) + task_z_detached = task_z_with_grad.detach() + new_actions, log_pi, pre_tanh_value = ( + action_distrib.rsample_logprob_and_pretanh() + ) + log_pi = log_pi.unsqueeze(1) + policy_mean = action_distrib.mean + policy_log_std = action_distrib.log_std + + # flattens out the task dimension + t, b, _ = obs.size() + obs = obs.view(t * b, -1) + actions = actions.view(t * b, -1) + next_obs = next_obs.view(t * b, -1) + unscaled_rewards_flat = rewards.view(t * b, 1) + rewards_flat = unscaled_rewards_flat * self.reward_scale + terms_flat = terminals.view(t * b, 1) + + # Q and V networks + # encoder will only get gradients from Q nets + if self.backprop_q_loss_into_encoder: + q1_pred = self.qf1(obs, actions, task_z_with_grad) + q2_pred = self.qf2(obs, actions, task_z_with_grad) + else: + q1_pred = self.qf1(obs, actions, task_z_detached) + q2_pred = self.qf2(obs, actions, task_z_detached) + v_pred = self.vf(obs, task_z_detached) + # get targets for use in V and Q updates + with torch.no_grad(): + target_v_values = self.target_vf(next_obs, task_z_detached) + + """ + QF, Encoder, and Decoder Loss + """ + # note: encoder/deocder do not get grads from policy or vf + q_target = rewards_flat + (1. 
- terms_flat) * self.discount * target_v_values + qf_loss = torch.mean((q1_pred - q_target) ** 2) + torch.mean((q2_pred - q_target) ** 2) + + # KL constraint on z if probabilistic + kl_div = kl_divergence(p_z, self.agent.latent_prior).sum() + kl_loss = self.kl_lambda * kl_div + if self.train_context_decoder: + # TODO: change to use a distribution + reward_pred = self.context_decoder(obs, actions, task_z_with_grad) + reward_prediction_loss = ((reward_pred - unscaled_rewards_flat)**2).mean() + context_loss = kl_loss + reward_prediction_loss + else: + context_loss = kl_loss + reward_prediction_loss = ptu.zeros(1) + + if self.train_encoder_decoder: + self.context_optimizer.zero_grad() + if self.train_agent: + self.qf1_optimizer.zero_grad() + self.qf2_optimizer.zero_grad() + context_loss.backward(retain_graph=True) + qf_loss.backward() + if self.train_agent: + self.qf1_optimizer.step() + self.qf2_optimizer.step() + if self.train_encoder_decoder: + self.context_optimizer.step() + + """ + VF update + """ + min_q_new_actions = self._min_q(obs, new_actions, task_z_detached) + v_target = min_q_new_actions - log_pi + vf_loss = self.vf_criterion(v_pred, v_target.detach()) + self.vf_optimizer.zero_grad() + vf_loss.backward() + self.vf_optimizer.step() + self._update_target_network() + + """ + Policy update + """ + # n.b. policy update includes dQ/da + log_policy_target = min_q_new_actions + policy_loss = ( + log_pi - log_policy_target + ).mean() + + mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean() + std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean() + pre_activation_reg_loss = self.policy_pre_activation_weight * ( + (pre_tanh_value**2).sum(dim=1).mean() + ) + policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss + policy_loss = policy_loss + policy_reg_loss + + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + + # save some statistics for eval + if self._need_to_update_eval_statistics: + self._need_to_update_eval_statistics = False + # eval should set this to None. + # this way, these statistics are only computed for one batch. 
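+            # (end_epoch() sets the flag back to True, so these statistics
+            # are refreshed once per epoch.)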
+ self.eval_statistics = OrderedDict() + if self.use_information_bottleneck: + z_mean = np.mean(np.abs(ptu.get_numpy(p_z.mean))) + z_sig = np.mean(ptu.get_numpy(p_z.stddev)) + self.eval_statistics['Z mean-abs train'] = z_mean + self.eval_statistics['Z variance train'] = z_sig + self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div) + self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss) + + self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) + self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss)) + self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( + policy_loss + )) + self.eval_statistics['task_embedding/kl_divergence'] = ( + ptu.get_numpy(kl_div) + ) + self.eval_statistics['task_embedding/kl_loss'] = ( + ptu.get_numpy(kl_loss) + ) + self.eval_statistics['task_embedding/reward_prediction_loss'] = ( + ptu.get_numpy(reward_prediction_loss) + ) + self.eval_statistics['task_embedding/context_loss'] = ( + ptu.get_numpy(context_loss) + ) + self.eval_statistics.update(create_stats_ordered_dict( + 'Q Predictions', + ptu.get_numpy(q1_pred), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'V Predictions', + ptu.get_numpy(v_pred), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Log Pis', + ptu.get_numpy(log_pi), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Policy mu', + ptu.get_numpy(policy_mean), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Policy log std', + ptu.get_numpy(policy_log_std), + )) + + def get_snapshot(self): + # NOTE: overriding parent method which also optionally saves the env + snapshot = OrderedDict( + qf1=self.qf1.state_dict(), + qf2=self.qf2.state_dict(), + policy=self.agent.policy.state_dict(), + vf=self.vf.state_dict(), + target_vf=self.target_vf.state_dict(), + context_encoder=self.agent.context_encoder.state_dict(), + context_decoder=self.context_decoder.state_dict(), + ) + return snapshot + + def end_epoch(self, epoch): + self._need_to_update_eval_statistics = True + + def get_diagnostics(self): + stats = super().get_diagnostics() + stats.update(self.eval_statistics) + return stats diff --git a/rlkit/torch/smac/pearl_launcher.py b/rlkit/torch/smac/pearl_launcher.py new file mode 100644 index 000000000..edde59413 --- /dev/null +++ b/rlkit/torch/smac/pearl_launcher.py @@ -0,0 +1,173 @@ +import pickle + +import rlkit.torch.pytorch_util as ptu +from rlkit.core import logger +from rlkit.core.meta_rl_algorithm import MetaRLAlgorithm +from rlkit.envs.pearl_envs import ENVS, register_pearl_envs +from rlkit.envs.wrappers import NormalizedBoxEnv +from rlkit.util.io import load_local_or_remote_file +from rlkit.torch.networks import ConcatMlp +from rlkit.torch.smac.agent import SmacAgent +from rlkit.torch.smac.diagnostics import ( + get_env_info_sizes, +) +from rlkit.torch.smac.networks import MlpEncoder, MlpDecoder +from rlkit.torch.smac.launcher_util import load_buffer_onto_algo +from rlkit.torch.smac.pearl import PEARLSoftActorCriticTrainer +from rlkit.torch.sac.policies import TanhGaussianPolicy + + +def pearl_experiment( + qf_kwargs=None, + vf_kwargs=None, + trainer_kwargs=None, + algo_kwargs=None, + context_encoder_kwargs=None, + context_decoder_kwargs=None, + policy_kwargs=None, + env_name=None, + env_params=None, + latent_dim=None, + # video/debug + debug=False, + _debug_do_not_sqrt=False, + # PEARL + n_train_tasks=0, + n_eval_tasks=0, + use_next_obs_in_context=False, + saved_tasks_path=None, + tags=None, +): + del tags + register_pearl_envs() + env_params = 
env_params or {} + context_encoder_kwargs = context_encoder_kwargs or {} + context_decoder_kwargs = context_decoder_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + base_env = ENVS[env_name](**env_params) + if saved_tasks_path: + task_data = load_local_or_remote_file( + saved_tasks_path, file_type='joblib') + tasks = task_data['tasks'] + train_task_idxs = task_data['train_task_indices'] + eval_task_idxs = task_data['eval_task_indices'] + base_env.tasks = tasks + else: + tasks = base_env.tasks + task_indices = base_env.get_all_task_idx() + train_task_idxs = list(task_indices[:n_train_tasks]) + eval_task_idxs = list(task_indices[-n_eval_tasks:]) + if hasattr(base_env, 'task_to_vec'): + train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs] + eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs] + else: + train_tasks = [tasks[i] for i in train_task_idxs] + eval_tasks = [tasks[i] for i in eval_task_idxs] + expl_env = NormalizedBoxEnv(base_env) + eval_env = NormalizedBoxEnv(ENVS[env_name](**env_params)) + eval_env.tasks = expl_env.tasks + reward_dim = 1 + + if debug: + algo_kwargs['max_path_length'] = 50 + algo_kwargs['batch_size'] = 5 + algo_kwargs['num_epochs'] = 5 + algo_kwargs['num_eval_steps_per_epoch'] = 100 + algo_kwargs['num_expl_steps_per_train_loop'] = 100 + algo_kwargs['num_trains_per_train_loop'] = 10 + algo_kwargs['min_num_steps_before_training'] = 100 + + obs_dim = expl_env.observation_space.low.size + action_dim = eval_env.action_space.low.size + + if use_next_obs_in_context: + context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim + else: + context_encoder_input_dim = obs_dim + action_dim + reward_dim + context_encoder_output_dim = latent_dim * 2 + + def create_qf(): + return ConcatMlp( + input_size=obs_dim + action_dim + latent_dim, + output_size=1, + **qf_kwargs + ) + + qf1 = create_qf() + qf2 = create_qf() + vf = ConcatMlp( + input_size=obs_dim + latent_dim, + output_size=1, + **vf_kwargs + ) + + policy = TanhGaussianPolicy( + obs_dim=obs_dim + latent_dim, + action_dim=action_dim, + **policy_kwargs, + ) + context_encoder = MlpEncoder( + input_size=context_encoder_input_dim, + output_size=context_encoder_output_dim, + hidden_sizes=[200, 200, 200], + **context_encoder_kwargs + ) + context_decoder = MlpDecoder( + input_size=obs_dim + action_dim + latent_dim, + output_size=1, + **context_decoder_kwargs + ) + reward_predictor = context_decoder + agent = SmacAgent( + latent_dim, + context_encoder, + policy, + reward_predictor, + use_next_obs_in_context=use_next_obs_in_context, + _debug_do_not_sqrt=_debug_do_not_sqrt, + ) + trainer = PEARLSoftActorCriticTrainer( + latent_dim=latent_dim, + agent=agent, + qf1=qf1, + qf2=qf2, + vf=vf, + reward_predictor=reward_predictor, + context_encoder=context_encoder, + context_decoder=context_decoder, + **trainer_kwargs + ) + algorithm = MetaRLAlgorithm( + agent=agent, + env=expl_env, + trainer=trainer, + train_task_indices=train_task_idxs, + eval_task_indices=eval_task_idxs, + train_tasks=train_tasks, + eval_tasks=eval_tasks, + use_next_obs_in_context=use_next_obs_in_context, + env_info_sizes=get_env_info_sizes(expl_env), + **algo_kwargs + ) + saved_path = logger.save_extra_data( + data=dict( + tasks=expl_env.tasks, + train_task_indices=train_task_idxs, + eval_task_indices=eval_task_idxs, + train_tasks=train_tasks, + eval_tasks=eval_tasks, + ), + file_name='tasks_description', + ) + print('saved tasks description to', saved_path) + saved_path = logger.save_extra_data( + expl_env.tasks, + 
file_name='tasks',
+        mode='pickle',
+    )
+    print('saved raw tasks to', saved_path)
+
+    algorithm.to(ptu.device)
+    algorithm.train()
diff --git a/rlkit/torch/smac/sampler.py b/rlkit/torch/smac/sampler.py
new file mode 100644
index 000000000..d8f24ac9d
--- /dev/null
+++ b/rlkit/torch/smac/sampler.py
@@ -0,0 +1,308 @@
+import numpy as np
+
+from rlkit.torch.smac.agent import MakeSMACAgentDeterministic
+import rlkit.torch.pytorch_util as ptu
+
+
+class SMACInPlacePathSampler(object):
+    """
+    A sampler that does not use serialization for sampling. Instead, it just
+    uses the current policy and environment as-is.
+
+    WARNING: This will affect the environment! So
+    ```
+    sampler = SMACInPlacePathSampler(env, ...)
+    sampler.obtain_samples()  # this has side-effects: env will change!
+    ```
+    """
+    def __init__(self, env, policy, max_path_length):
+        self.env = env
+        self.policy = policy
+
+        self.max_path_length = max_path_length
+
+    def start_worker(self):
+        pass
+
+    def shutdown_worker(self):
+        pass
+
+    def obtain_samples(
+            self,
+            deterministic=False,
+            max_trajs=np.inf,
+            max_samples=np.inf,
+            **kwargs
+    ):
+        """
+        Obtains samples in the environment until we reach either
+        `max_samples` transitions or `max_trajs` trajectories.
+        """
+        assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
+        policy = MakeSMACAgentDeterministic(self.policy) if deterministic else self.policy
+        paths = []
+        n_steps_total = 0
+        n_trajs = 0
+        while n_steps_total < max_samples and n_trajs < max_trajs:
+            path = rollout(
+                self.env,
+                policy,
+                max_path_length=self.max_path_length,
+                **kwargs
+            )
+            # save the latent context that generated this trajectory
+            # path['context'] = policy.z.detach().cpu().numpy()
+            paths.append(path)
+            n_steps_total += len(path['observations'])
+            n_trajs += 1
+            # don't we also want the option to resample z every transition?
+            # if n_trajs % resample == 0:
+            #     policy.sample_z()
+        return paths, n_steps_total
+
+
+def rollout(
+        env,
+        agent,
+        task_idx,
+        max_path_length=np.inf,
+        accum_context=True,
+        animated=False,
+        save_frames=False,
+        use_predicted_reward=False,
+        resample_latent_period=0,
+        update_posterior_period=0,
+        initial_context=None,
+        initial_reward_context=None,
+        infer_posterior_at_start=True,
+        initialized_z_reward=None,
+):
+    """
+    The values for the following keys will be 2D arrays, with the
+    first dimension corresponding to the time dimension.
+     - observations
+     - actions
+     - rewards
+     - next_observations
+     - terminals
+
+    The next two elements will be lists of dictionaries, with the index into
+    the list being the index into the time dimension.
+     - agent_infos
+     - env_infos
+
+    :param initial_context:
+    :param infer_posterior_at_start: If True, infer the posterior from `initial_context` if possible.
+    :param env:
+    :param agent:
+    :param task_idx: the index of the task inside the environment.
+    :param max_path_length:
+    :param accum_context: if True, accumulate the collected context
+    :param animated:
+    :param save_frames: if True, save video of rollout
+    :param resample_latent_period: How often to resample from the latent posterior, in units of env steps.
+        If zero, never resample after the first sample.
+    :param update_posterior_period: How often to update the latent posterior,
+        in units of env steps.
+        If zero, never update unless an initial context is provided, in which
+        case only update at the start using that initial context.
+ :return: + """ + observations = [] + actions = [] + rewards = [] + terminals = [] + agent_infos = [] + env_infos = [] + zs = [] + if initialized_z_reward is None: + env.reset_task(task_idx) + o = env.reset() + next_o = None + + if animated: + env.render() + if initial_context is not None and len(initial_context) == 0: + initial_context = None + + context = initial_context + + if infer_posterior_at_start and initial_context is not None: + z_dist = agent.latent_posterior(context, squeeze=True) + else: + z_dist = agent.latent_prior + + if use_predicted_reward: + if initialized_z_reward is None: + z_reward_dist = agent.latent_posterior( + initial_reward_context, squeeze=True, for_reward_prediction=True, + ) + z_reward = ptu.get_numpy(z_reward_dist.sample()) + else: + z_reward = initialized_z_reward + + z = ptu.get_numpy(z_dist.sample()) + for path_length in range(max_path_length): + if resample_latent_period != 0 and path_length % resample_latent_period == 0: + z = ptu.get_numpy(z_dist.sample()) + a, agent_info = agent.get_action(o, z) + next_o, r, d, env_info = env.step(a) + if use_predicted_reward: + r = agent.infer_reward(o, a, z_reward) + r = r[0] + if accum_context: + context = agent.update_context( + context, + [o, a, r, next_o, d, env_info], + ) + # TODO: remove "context is not None" check after fixing first-loop hack + if update_posterior_period != 0 and path_length % update_posterior_period == 0 and context is not None and len(context) > 0: + z_dist = agent.latent_posterior(context, squeeze=True) + zs.append(z) + observations.append(o) + rewards.append(r) + terminals.append(d) + actions.append(a) + agent_infos.append(agent_info) + o = next_o + if animated: + env.render() + if save_frames: + from PIL import Image + image = Image.fromarray(np.flipud(env.get_image())) + env_info['frame'] = image + env_infos.append(env_info) + if d: + break + + actions = np.array(actions) + if len(actions.shape) == 1: + actions = np.expand_dims(actions, 1) + observations = np.array(observations) + if len(observations.shape) == 1 and not isinstance(observations[0], dict): + observations = np.expand_dims(observations, 1) + next_o = np.array([next_o]) + next_observations = np.concatenate( + ( + observations[1:, ...], + np.expand_dims(next_o, 0) + ), + axis=0, + ) + return dict( + observations=observations, + actions=actions, + rewards=np.array(rewards).reshape(-1, 1), + next_observations=next_observations, + terminals=np.array(terminals).reshape(-1, 1), + agent_infos=agent_infos, + env_infos=env_infos, + latents=np.array(zs), + context=context, + ) + + +def rollout_multiple( + *args, + num_repeats=1, + initial_context=None, + accum_context=True, + **kwargs +): + """ + Do multiple rollouts and concatenate the paths + """ + assert num_repeats >= 1 + last_path = rollout( + *args, + accum_context=accum_context, + initial_context=initial_context, + **kwargs) + paths = [last_path] + for i in range(num_repeats-1): + if accum_context: + initial_context = last_path['context'] + new_path = rollout( + *args, + initial_context=initial_context, + accum_context=True, + **kwargs) + paths.append(new_path) + last_path = new_path + + return paths + + +def merge_paths(paths): + flat_path = paths[0] + for new_path in paths[1:]: + for k in [ + 'observations', + 'actions', + 'rewards', + 'next_observations', + 'terminals', + 'latents', + ]: + flat_path[k] = np.concatenate(( + flat_path[k], + new_path[k], + ), axis=0) + return flat_path + + +def rollout_multiple_and_flatten(*args, **kwargs): + paths = rollout_multiple(*args, 
**kwargs) + return merge_paths(paths) + + +def split_paths(paths): + """ + Stack multiples obs/actions/etc. from different paths + :param paths: List of paths, where one path is something returned from + the rollout functino above. + :return: Tuple. Every element will have shape batch_size X DIM, including + the rewards and terminal flags. + """ + rewards = [path["rewards"].reshape(-1, 1) for path in paths] + terminals = [path["terminals"].reshape(-1, 1) for path in paths] + actions = [path["actions"] for path in paths] + obs = [path["observations"] for path in paths] + next_obs = [path["next_observations"] for path in paths] + rewards = np.vstack(rewards) + terminals = np.vstack(terminals) + obs = np.vstack(obs) + actions = np.vstack(actions) + next_obs = np.vstack(next_obs) + assert len(rewards.shape) == 2 + assert len(terminals.shape) == 2 + assert len(obs.shape) == 2 + assert len(actions.shape) == 2 + assert len(next_obs.shape) == 2 + return rewards, terminals, obs, actions, next_obs + + +def split_paths_to_dict(paths): + rewards, terminals, obs, actions, next_obs = split_paths(paths) + return dict( + rewards=rewards, + terminals=terminals, + observations=obs, + actions=actions, + next_observations=next_obs, + ) + + +def get_stat_in_paths(paths, dict_name, scalar_name): + if len(paths) == 0: + return np.array([[]]) + + if type(paths[0][dict_name]) == dict: + # Support rllab interface + return [path[dict_name][scalar_name] for path in paths] + + return [ + [info[scalar_name] for info in path[dict_name]] + for path in paths + ] diff --git a/rlkit/torch/smac/smac.py b/rlkit/torch/smac/smac.py new file mode 100644 index 000000000..bedceffe1 --- /dev/null +++ b/rlkit/torch/smac/smac.py @@ -0,0 +1,668 @@ +from collections import OrderedDict +from itertools import chain + +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from torch import nn as nn +from torch.distributions import kl_divergence + +import rlkit.torch.pytorch_util as ptu +from rlkit.core.logging import add_prefix +from rlkit.util import ml_util +from rlkit.core.eval_util import create_stats_ordered_dict +from rlkit.torch.networks import LinearTransform +from rlkit.torch.smac.agent import SmacAgent +from rlkit.torch.torch_rl_algorithm import TorchTrainer + + +class SmacTrainer(TorchTrainer): + def __init__( + self, + agent: SmacAgent, + env, + latent_dim, + qf1, + qf2, + target_qf1, + target_qf2, + context_encoder, + reward_predictor, + context_decoder, + + train_context_decoder=False, + backprop_q_loss_into_encoder=True, + context_lr=1e-3, + kl_lambda=1., + policy_mean_reg_weight=1e-3, + policy_std_reg_weight=1e-3, + policy_pre_activation_weight=0., + recurrent=False, + use_information_bottleneck=True, + use_next_obs_in_context=False, + sparse_rewards=False, + + train_reward_pred_in_unsupervised_phase=False, + use_encoder_snapshot_for_reward_pred_in_unsupervised_phase=False, + + # from AWAC + buffer_policy=None, + + discount=0.99, + reward_scale=1.0, + beta=1.0, + beta_schedule_kwargs=None, + + policy_lr=1e-3, + qf_lr=1e-3, + policy_weight_decay=0, + q_weight_decay=0, + optimizer_class=optim.Adam, + + soft_target_tau=1e-2, + target_update_period=1, + plotter=None, + render_eval_paths=False, + + use_automatic_entropy_tuning=True, + target_entropy=None, + + bc_num_pretrain_steps=0, + q_num_pretrain1_steps=0, + q_num_pretrain2_steps=0, + bc_batch_size=128, + alpha=1.0, + + policy_update_period=1, + q_update_period=1, + + weight_loss=True, + compute_bc=True, + use_awr_update=True, + 
use_reparam_update=False, + + bc_weight=0.0, + rl_weight=1.0, + reparam_weight=1.0, + reparam_weight_schedule_kwargs=None, + awr_weight=1.0, + + awr_use_mle_for_vf=False, + vf_K=1, + awr_sample_actions=False, + buffer_policy_sample_actions=False, + awr_min_q=False, + brac=False, + + reward_transform_class=None, + reward_transform_kwargs=None, + terminal_transform_class=None, + terminal_transform_kwargs=None, + + pretraining_logging_period=1000, + + train_bc_on_rl_buffer=False, + use_automatic_beta_tuning=False, + beta_epsilon=1e-10, + normalize_over_batch=True, + Z_K=10, + clip_score=None, + validation_qlearning=False, + + mask_positive_advantage=False, + buffer_policy_reset_period=-1, + num_buffer_policy_train_steps_on_reset=100, + advantage_weighted_buffer_loss=True, + + # for debugging + _debug_ignore_context=False, + _debug_use_ground_truth_context=False, + ): + super().__init__() + + self.train_agent = True + self.train_context_decoder = train_context_decoder + self.train_encoder_decoder = True + self.backprop_q_loss_into_encoder = backprop_q_loss_into_encoder + self.reward_scale = reward_scale + self.discount = discount + self.soft_target_tau = soft_target_tau + self.policy_mean_reg_weight = policy_mean_reg_weight + self.policy_std_reg_weight = policy_std_reg_weight + self.policy_pre_activation_weight = policy_pre_activation_weight + self.plotter = plotter + self.render_eval_paths = render_eval_paths + + self.train_reward_pred_in_unsupervised_phase = train_reward_pred_in_unsupervised_phase + self.use_encoder_snapshot_for_reward_pred_in_unsupervised_phase = ( + use_encoder_snapshot_for_reward_pred_in_unsupervised_phase + ) + + self.recurrent = recurrent + self.latent_dim = latent_dim + self.qf_criterion = nn.MSELoss() + self.vf_criterion = nn.MSELoss() + self.vib_criterion = nn.MSELoss() + self.l2_reg_criterion = nn.MSELoss() + self.reward_pred_criterion = nn.MSELoss() + self.kl_lambda = kl_lambda + + self.use_information_bottleneck = use_information_bottleneck + self.sparse_rewards = sparse_rewards + self.use_next_obs_in_context = use_next_obs_in_context + + self._debug_ignore_context = _debug_ignore_context + + self.agent = agent + self.policy = agent.policy + self.qf1, self.qf2 = qf1, qf2 + self.context_encoder = context_encoder + self.context_decoder = context_decoder + self.reward_predictor = reward_predictor + + self.policy_optimizer = optimizer_class( + self.policy.parameters(), + lr=policy_lr, + ) + if train_context_decoder: + self.context_optimizer = optimizer_class( + chain( + self.context_encoder.parameters(), + self.context_decoder.parameters(), + ), + lr=context_lr, + ) + else: + self.context_optimizer = optimizer_class( + self.context_encoder.parameters(), + lr=context_lr, + ) + self.qf1_optimizer = optimizer_class( + self.qf1.parameters(), + lr=qf_lr, + ) + self.qf2_optimizer = optimizer_class( + self.qf2.parameters(), + lr=qf_lr, + ) + + self.eval_statistics = None + self._need_to_update_eval_statistics = True + + self.target_qf1 = target_qf1 + self.target_qf2 = target_qf2 + self.buffer_policy = buffer_policy + self.soft_target_tau = soft_target_tau + self.target_update_period = target_update_period + + self.use_awr_update = use_awr_update + self.use_automatic_entropy_tuning = use_automatic_entropy_tuning + if self.use_automatic_entropy_tuning: + if target_entropy: + self.target_entropy = target_entropy + else: + self.target_entropy = -np.prod( + env.action_space.shape).item() # heuristic value from Tuomas + self.log_alpha = ptu.zeros(1, requires_grad=True) + 
self.alpha_optimizer = optimizer_class( + [self.log_alpha], + lr=policy_lr, + ) + + self.awr_use_mle_for_vf = awr_use_mle_for_vf + self.vf_K = vf_K + self.awr_sample_actions = awr_sample_actions + self.awr_min_q = awr_min_q + + self.plotter = plotter + self.render_eval_paths = render_eval_paths + + self.qf_criterion = nn.MSELoss() + + self.optimizers = {} + + self.policy_optimizer = optimizer_class( + self.policy.parameters(), + weight_decay=policy_weight_decay, + lr=policy_lr, + ) + self.optimizers[self.policy] = self.policy_optimizer + self.qf1_optimizer = optimizer_class( + self.qf1.parameters(), + weight_decay=q_weight_decay, + lr=qf_lr, + ) + self.qf2_optimizer = optimizer_class( + self.qf2.parameters(), + weight_decay=q_weight_decay, + lr=qf_lr, + ) + + self.use_automatic_beta_tuning = use_automatic_beta_tuning and buffer_policy and train_bc_on_rl_buffer + self.beta_epsilon = beta_epsilon + if self.use_automatic_beta_tuning: + self.log_beta = ptu.zeros(1, requires_grad=True) + self.beta_optimizer = optimizer_class( + [self.log_beta], + lr=policy_lr, + ) + else: + self.beta = beta + self.beta_schedule_kwargs = beta_schedule_kwargs + if beta_schedule_kwargs is None: + self.beta_schedule = ml_util.ConstantSchedule(beta) + else: + schedule_class = beta_schedule_kwargs.pop("schedule_class", + ml_util.PiecewiseLinearSchedule) + self.beta_schedule = schedule_class(**beta_schedule_kwargs) + + self.discount = discount + self.reward_scale = reward_scale + self.eval_statistics = OrderedDict() + self._n_train_steps_total = 0 + self._need_to_update_eval_statistics = True + + self.bc_num_pretrain_steps = bc_num_pretrain_steps + self.q_num_pretrain1_steps = q_num_pretrain1_steps + self.q_num_pretrain2_steps = q_num_pretrain2_steps + self.bc_batch_size = bc_batch_size + self.rl_weight = rl_weight + self.bc_weight = bc_weight + self.compute_bc = compute_bc + self.alpha = alpha + self.q_update_period = q_update_period + self.policy_update_period = policy_update_period + self.weight_loss = weight_loss + + self.reparam_weight = reparam_weight + self.reparam_weight_schedule = None + self.reparam_weight_schedule_kwargs = reparam_weight_schedule_kwargs + self.awr_weight = awr_weight + self.update_policy = True + self.pretraining_logging_period = pretraining_logging_period + self.normalize_over_batch = normalize_over_batch + self.Z_K = Z_K + + self.reward_transform_class = reward_transform_class or LinearTransform + self.reward_transform_kwargs = reward_transform_kwargs or dict(m=1, b=0) + self.terminal_transform_class = terminal_transform_class or LinearTransform + self.terminal_transform_kwargs = terminal_transform_kwargs or dict(m=1, + b=0) + self.reward_transform = self.reward_transform_class( + **self.reward_transform_kwargs) + self.terminal_transform = self.terminal_transform_class( + **self.terminal_transform_kwargs) + self.use_reparam_update = use_reparam_update + self.clip_score = clip_score + self.buffer_policy_sample_actions = buffer_policy_sample_actions + + self.train_bc_on_rl_buffer = train_bc_on_rl_buffer and buffer_policy + self.validation_qlearning = validation_qlearning + self.brac = brac + self.mask_positive_advantage = mask_positive_advantage + self.buffer_policy_reset_period = buffer_policy_reset_period + self.num_buffer_policy_train_steps_on_reset = num_buffer_policy_train_steps_on_reset + self.advantage_weighted_buffer_loss = advantage_weighted_buffer_loss + self._debug_use_ground_truth_context = _debug_use_ground_truth_context + self._num_gradient_steps = 0 + + @property + def 
train_reparam_weight(self): + if self.reparam_weight_schedule_kwargs is not None and self.reparam_weight_schedule is None: + self.reparam_weight_schedule = ml_util.create_schedule( + **self.reparam_weight_schedule_kwargs + ) + if self.reparam_weight_schedule is None: + return self.reparam_weight + else: + return self.reparam_weight_schedule.get_value( + self._n_train_steps_total + ) + + ##### Training ##### + def train_from_torch(self, batch): + rewards = batch['rewards'] + terminals = batch['terminals'] + obs = batch['observations'] + actions = batch['actions'] + next_obs = batch['next_observations'] + context = batch['context'] + + if self.reward_transform: + rewards = self.reward_transform(rewards) + + if self.terminal_transform: + terminals = self.terminal_transform(terminals) + """ + Policy and Alpha Loss + """ + dist, p_z, task_z_with_grad = self.agent( + obs, context, return_latent_posterior_and_task_z=True, + ) + task_z_detached = task_z_with_grad.detach() + new_obs_actions, log_pi = dist.rsample_and_logprob() + log_pi = log_pi.unsqueeze(1) + next_dist = self.agent(next_obs, context) + + if self._debug_ignore_context: + task_z_with_grad = task_z_with_grad * 0 + + # flattens out the task dimension + t, b, _ = obs.size() + obs = obs.view(t * b, -1) + actions = actions.view(t * b, -1) + next_obs = next_obs.view(t * b, -1) + unscaled_rewards_flat = rewards.view(t * b, 1) + rewards_flat = unscaled_rewards_flat * self.reward_scale + terms_flat = terminals.view(t * b, 1) + + if self.use_automatic_entropy_tuning: + alpha_loss = -(self.log_alpha * ( + log_pi + self.target_entropy).detach()).mean() + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + alpha = self.log_alpha.exp() + else: + alpha_loss = 0 + alpha = self.alpha + + """ + QF Loss + """ + if self.backprop_q_loss_into_encoder: + q1_pred = self.qf1(obs, actions, task_z_with_grad) + q2_pred = self.qf2(obs, actions, task_z_with_grad) + else: + q1_pred = self.qf1(obs, actions, task_z_detached) + q2_pred = self.qf2(obs, actions, task_z_detached) + # Make sure policy accounts for squashing functions like tanh correctly! + new_next_actions, new_log_pi = next_dist.rsample_and_logprob() + new_log_pi = new_log_pi.unsqueeze(1) + with torch.no_grad(): + target_q_values = torch.min( + self.target_qf1(next_obs, new_next_actions, task_z_detached), + self.target_qf2(next_obs, new_next_actions, task_z_detached), + ) - alpha * new_log_pi + + q_target = rewards_flat + ( + 1. 
- terms_flat) * self.discount * target_q_values + qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) + qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) + + """ + Context Encoder Loss + """ + if self._debug_use_ground_truth_context: + kl_div = kl_loss = ptu.zeros(0) + else: + kl_div = kl_divergence(p_z, self.agent.latent_prior).mean(dim=0).sum() + kl_loss = self.kl_lambda * kl_div + + if self.train_context_decoder: + # TODO: change to use a distribution + reward_pred = self.context_decoder(obs, actions, task_z_with_grad) + reward_prediction_loss = ((reward_pred - unscaled_rewards_flat)**2).mean() + context_loss = kl_loss + reward_prediction_loss + else: + context_loss = kl_loss + reward_prediction_loss = ptu.zeros(1) + + """ + Policy Loss + """ + qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached) + qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached) + q_new_actions = torch.min( + qf1_new_actions, + qf2_new_actions, + ) + + # Advantage-weighted regression + if self.vf_K > 1: + vs = [] + for i in range(self.vf_K): + u = dist.sample() + q1 = self.qf1(obs, u, task_z_detached) + q2 = self.qf2(obs, u, task_z_detached) + v = torch.min(q1, q2) + # v = q1 + vs.append(v) + v_pi = torch.cat(vs, 1).mean(dim=1) + else: + # v_pi = self.qf1(obs, new_obs_actions) + v1_pi = self.qf1(obs, new_obs_actions, task_z_detached) + v2_pi = self.qf2(obs, new_obs_actions, task_z_detached) + v_pi = torch.min(v1_pi, v2_pi) + + u = actions + if self.awr_min_q: + q_adv = torch.min(q1_pred, q2_pred) + else: + q_adv = q1_pred + + policy_logpp = dist.log_prob(u) + + if self.use_automatic_beta_tuning: + buffer_dist = self.buffer_policy(obs) + beta = self.log_beta.exp() + kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist) + beta_loss = -1 * ( + beta * (kldiv - self.beta_epsilon).detach()).mean() + + self.beta_optimizer.zero_grad() + beta_loss.backward() + self.beta_optimizer.step() + else: + beta = self.beta_schedule.get_value(self._n_train_steps_total) + beta_loss = ptu.zeros(1) + + score = q_adv - v_pi + if self.mask_positive_advantage: + score = torch.sign(score) + + if self.clip_score is not None: + score = torch.clamp(score, max=self.clip_score) + + weights = batch.get('weights', None) + if self.weight_loss and weights is None: + if self.normalize_over_batch == True: + weights = F.softmax(score / beta, dim=0) + elif self.normalize_over_batch == "whiten": + adv_mean = torch.mean(score) + adv_std = torch.std(score) + 1e-5 + normalized_score = (score - adv_mean) / adv_std + weights = torch.exp(normalized_score / beta) + elif self.normalize_over_batch == "exp": + weights = torch.exp(score / beta) + elif self.normalize_over_batch == "step_fn": + weights = (score > 0).float() + elif self.normalize_over_batch == False: + weights = score + elif self.normalize_over_batch == 'uniform': + weights = F.softmax(ptu.ones_like(score) / beta, dim=0) + else: + raise ValueError(self.normalize_over_batch) + weights = weights[:, 0] + + policy_loss = alpha * log_pi.mean() + + if self.use_awr_update and self.weight_loss: + policy_loss = policy_loss + self.awr_weight * ( + -policy_logpp * len(weights) * weights.detach()).mean() + elif self.use_awr_update: + policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean() + + if self.use_reparam_update: + policy_loss = policy_loss + self.train_reparam_weight * ( + -q_new_actions).mean() + + policy_loss = self.rl_weight * policy_loss + + """ + Update networks + """ + if self._n_train_steps_total % self.q_update_period == 0: + if 
self.train_encoder_decoder: + self.context_optimizer.zero_grad() + if self.train_agent: + self.qf1_optimizer.zero_grad() + self.qf2_optimizer.zero_grad() + context_loss.backward(retain_graph=True) + # retain graph because the encoder is trained by both QF losses + qf1_loss.backward(retain_graph=True) + qf2_loss.backward() + if self.train_agent: + self.qf1_optimizer.step() + self.qf2_optimizer.step() + if self.train_encoder_decoder: + self.context_optimizer.step() + + if self.train_agent: + if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy: + self.policy_optimizer.zero_grad() + policy_loss.backward() + self.policy_optimizer.step() + self._num_gradient_steps += 1 + + """ + Soft Updates + """ + if self._n_train_steps_total % self.target_update_period == 0: + ptu.soft_update_from_to( + self.qf1, self.target_qf1, self.soft_target_tau + ) + ptu.soft_update_from_to( + self.qf2, self.target_qf2, self.soft_target_tau + ) + + """ + Save some statistics for eval + """ + if self._need_to_update_eval_statistics: + self._need_to_update_eval_statistics = False + """ + Eval should set this to None. + This way, these statistics are only computed for one batch. + """ + policy_loss = (log_pi - q_new_actions).mean() + + self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) + self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) + self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( + policy_loss + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Q1 Predictions', + ptu.get_numpy(q1_pred), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Q2 Predictions', + ptu.get_numpy(q2_pred), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Q Targets', + ptu.get_numpy(q_target), + )) + self.eval_statistics['task_embedding/kl_divergence'] = ( + ptu.get_numpy(kl_div) + ) + self.eval_statistics['task_embedding/kl_loss'] = ( + ptu.get_numpy(kl_loss) + ) + self.eval_statistics['task_embedding/reward_prediction_loss'] = ( + ptu.get_numpy(reward_prediction_loss) + ) + self.eval_statistics['task_embedding/context_loss'] = ( + ptu.get_numpy(context_loss) + ) + self.eval_statistics.update(create_stats_ordered_dict( + 'Log Pis', + ptu.get_numpy(log_pi), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'rewards', + ptu.get_numpy(rewards), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'terminals', + ptu.get_numpy(terminals), + )) + policy_statistics = add_prefix(dist.get_diagnostics(), "policy/") + self.eval_statistics.update(policy_statistics) + self.eval_statistics.update(create_stats_ordered_dict( + 'Advantage Weights', + ptu.get_numpy(weights), + )) + self.eval_statistics.update(create_stats_ordered_dict( + 'Advantage Score', + ptu.get_numpy(score), + )) + self.eval_statistics['reparam_weight'] = self.train_reparam_weight + self.eval_statistics['num_gradient_steps'] = ( + self._num_gradient_steps + ) + + if self.use_automatic_entropy_tuning: + self.eval_statistics['Alpha'] = alpha.item() + self.eval_statistics['Alpha Loss'] = alpha_loss.item() + + if self.use_automatic_beta_tuning: + self.eval_statistics.update({ + "adaptive_beta/beta": ptu.get_numpy(beta.mean()), + "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()), + }) + + self._n_train_steps_total += 1 + + def configure(self, **params): + for k, v in params.items(): + if k not in self.__dict__: + raise KeyError('Member {} is not in {}'.format(k, self)) + self.__dict__[k] = v + + #### Trainer #### + def get_snapshot(self): + 
snapshot = OrderedDict( + qf1=self.qf1.state_dict(), + qf2=self.qf2.state_dict(), + target_qf1=self.target_qf1.state_dict(), + target_qf2=self.target_qf2.state_dict(), + policy=self.agent.policy.state_dict(), + context_encoder=self.agent.context_encoder.state_dict(), + context_decoder=self.context_decoder.state_dict(), + ) + return snapshot + + def end_epoch(self, epoch): + self._need_to_update_eval_statistics = True + + def get_diagnostics(self): + stats = super().get_diagnostics() + stats.update(self.eval_statistics) + return stats + + ###### Torch stuff ##### + @property + def networks(self): + return [ + self.policy, + self.qf1, + self.qf2, + self.target_qf1, + self.target_qf2, + self.context_encoder, + self.context_decoder, + self.reward_predictor, + ] + + def training_mode(self, mode): + for net in self.networks: + net.train(mode) + + def to(self, device=None): + if device == None: + device = ptu.device + for net in self.networks: + net.to(device) diff --git a/rlkit/util/io.py b/rlkit/util/io.py index 649bb5784..96e9e46b8 100644 --- a/rlkit/util/io.py +++ b/rlkit/util/io.py @@ -101,10 +101,12 @@ def load_local_or_remote_file(filepath, file_type=None): extension = local_path.split('.')[-1] if extension == 'npy': file_type = NUMPY - else: + elif extension == 'pkl': file_type = PICKLE - else: - file_type = PICKLE + elif extension == 'joblib': + file_type = JOBLIB + else: + raise ValueError("Could not infer file type.") if file_type == NUMPY: object = np.load(open(local_path, "rb"), allow_pickle=True) elif file_type == JOBLIB: diff --git a/rlkit/util/wrapper.py b/rlkit/util/wrapper.py new file mode 100644 index 000000000..8d8930c24 --- /dev/null +++ b/rlkit/util/wrapper.py @@ -0,0 +1,43 @@ +import os +LOG_DIR = os.getcwd() + + +class Wrapper(object): + """ + Mixin for deferring attributes to a wrapped, inner object. + """ + + def __init__(self, inner): + self.inner = inner + + def __getattr__(self, attr): + """ + Dispatch attributes by their status as magic, members, or missing. + - magic is handled by the standard getattr + - existing attributes are returned + - missing attributes are deferred to the inner object. + """ + # don't make magic any more magical + is_magic = attr.startswith('__') and attr.endswith('__') + if is_magic: + return super().__getattr__(attr) + try: + # try to return the attribute... + return self.__dict__[attr] + except: + # ...and defer to the inner dataset if it's not here + return getattr(self.inner, attr) + + +class SimpleWrapper(object): + """ + Mixin for deferring attributes to a wrapped, inner object. + """ + + def __init__(self, inner): + self._inner = inner + + def __getattr__(self, attr): + if attr == '_inner': + raise AttributeError() + return getattr(self._inner, attr)
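
For reference, a minimal sketch of how the `SimpleWrapper` mixin from `rlkit/util/wrapper.py` might be used. `CountingEnv` is a hypothetical example and is not part of this patch; only `SimpleWrapper` and its attribute-deferral behavior come from the code above.

```python
from rlkit.util.wrapper import SimpleWrapper


class CountingEnv(SimpleWrapper):
    """Hypothetical wrapper that counts steps while deferring everything
    else (reset, action_space, ...) to the wrapped inner environment."""

    def __init__(self, inner):
        super().__init__(inner)  # stores the wrapped env as self._inner
        self.num_steps = 0

    def step(self, action):
        self.num_steps += 1
        return self._inner.step(action)


# wrapped = CountingEnv(some_env)
# wrapped.num_steps      # found on the wrapper itself
# wrapped.action_space   # missing here, so deferred to the inner env via __getattr__
```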
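
Similarly, a short sketch of the updated extension-based inference in `rlkit/util/io.py`: `.npy` maps to numpy, `.pkl` to pickle, `.joblib` to joblib, and anything else now raises `ValueError` instead of silently falling back to pickle. The dictionary keys shown are an assumption based on how `pearl_launcher.py` reads the saved tasks file.

```python
from rlkit.util.io import load_local_or_remote_file

# File type inferred from the .joblib extension.
task_data = load_local_or_remote_file('examples/smac/ant_tasks.joblib')
tasks = task_data['tasks']
train_task_idxs = task_data['train_task_indices']
eval_task_idxs = task_data['eval_task_indices']

# The file type can still be forced explicitly, as pearl_launcher.py does:
task_data = load_local_or_remote_file(
    'examples/smac/ant_tasks.joblib', file_type='joblib',
)
```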
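
Finally, a rough sketch of how the sampler utilities in `rlkit/torch/smac/sampler.py` fit together. `env` and `agent` are assumed to already exist (e.g. built as in `pearl_launcher.py`), the keyword choices are illustrative, and the sketch assumes the agent tolerates starting from an empty context; treat it as a usage outline rather than a tested snippet.

```python
from rlkit.torch.smac.sampler import (
    SMACInPlacePathSampler, merge_paths, split_paths_to_dict,
)

# env: a multitask environment; agent: a SmacAgent (both assumed to exist).
sampler = SMACInPlacePathSampler(env, agent, max_path_length=200)

paths, n_steps = sampler.obtain_samples(
    max_trajs=2,         # stop after two trajectories
    task_idx=0,          # forwarded to rollout() via **kwargs
    accum_context=True,  # keep accumulating context within each rollout
)

flat_path = merge_paths(paths)      # concatenate the trajectories for one task
batch = split_paths_to_dict(paths)  # batch_size x dim arrays for training code
```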