diff --git a/examples/iql/mujoco_finetune.py b/examples/iql/mujoco_finetune.py index f6e7dd44b..37156c25b 100644 --- a/examples/iql/mujoco_finetune.py +++ b/examples/iql/mujoco_finetune.py @@ -84,7 +84,8 @@ def main(): variant=variant, exp_prefix='iql-halfcheetah-medium-v2', mode="here_no_doodad", - unpack_variant=False + unpack_variant=False, + use_gpu=False, ) if __name__ == "__main__": diff --git a/rlkit/testing/debug_util.py b/rlkit/testing/debug_util.py new file mode 100644 index 000000000..31f2cd689 --- /dev/null +++ b/rlkit/testing/debug_util.py @@ -0,0 +1,24 @@ +"""For tracing programs and comparing outputs""" + +import torch + +i = 0 + +def save(x): + torch.save(x, "../tmp.pt") + return x + +def load(): + return torch.load("../tmp.pt") + +def savei(x): + global i + torch.save(x, "../tmp/%d.pt" % i) + i = i + 1 + return x + +def loadi(): + global i + x = torch.load("../tmp/%d.pt" % i) + i = i + 1 + return x diff --git a/rlkit/torch/sac/iql_trainer.py b/rlkit/torch/sac/iql_trainer.py index 1d598f047..4571829ac 100644 --- a/rlkit/torch/sac/iql_trainer.py +++ b/rlkit/torch/sac/iql_trainer.py @@ -139,7 +139,6 @@ def train_from_torch(self, batch, train=True, pretrain=False,): Policy and Alpha Loss """ dist = self.policy(obs) - new_obs_actions, log_pi = dist.rsample_and_logprob() """ QF Loss @@ -237,10 +236,6 @@ def train_from_torch(self, batch, train=True, pretrain=False,): 'Q Targets', ptu.get_numpy(q_target), )) - self.eval_statistics.update(create_stats_ordered_dict( - 'Log Pis', - ptu.get_numpy(log_pi), - )) self.eval_statistics.update(create_stats_ordered_dict( 'rewards', ptu.get_numpy(rewards), diff --git a/tests/regression/iql/halfcheetah_offline_progress.csv b/tests/regression/iql/halfcheetah_offline_progress.csv index 657bc32cf..6ca1d38fd 100644 --- a/tests/regression/iql/halfcheetah_offline_progress.csv +++ b/tests/regression/iql/halfcheetah_offline_progress.csv @@ -1,3 +1,3 @@ epoch,eval/num paths total,eval/num steps total,expl/Actions Max,expl/Actions Mean,expl/Actions Min,expl/Actions Std,expl/Average Returns,expl/Num Paths,expl/Returns Max,expl/Returns Mean,expl/Returns Min,expl/Returns Std,expl/Rewards Max,expl/Rewards Mean,expl/Rewards Min,expl/Rewards Std,expl/num paths total,expl/num steps total,expl/path length Max,expl/path length Mean,expl/path length Min,expl/path length Std,replay_buffer/size,time/epoch_time (s),time/evaluation sampling (s),time/exploration sampling (s),time/global_time (s),time/replay buffer data storing (s),time/saving (s),time/training (s),trainer/Advantage Score Max,trainer/Advantage Score Mean,trainer/Advantage Score Min,trainer/Advantage Score Std,trainer/Advantage Weights Max,trainer/Advantage Weights Mean,trainer/Advantage Weights Min,trainer/Advantage Weights Std,trainer/Policy Loss,trainer/Q Targets Max,trainer/Q Targets Mean,trainer/Q Targets Min,trainer/Q Targets Std,trainer/Q1 Predictions Max,trainer/Q1 Predictions Mean,trainer/Q1 Predictions Min,trainer/Q1 Predictions Std,trainer/Q2 Predictions Max,trainer/Q2 Predictions Mean,trainer/Q2 Predictions Min,trainer/Q2 Predictions Std,trainer/QF1 Loss,trainer/QF2 Loss,trainer/V1 Predictions Max,trainer/V1 Predictions Mean,trainer/V1 Predictions Min,trainer/V1 Predictions Std,trainer/VF Loss,trainer/num train calls,trainer/policy/mean Max,trainer/policy/mean Mean,trainer/policy/mean Min,trainer/policy/mean Std,trainer/policy/std Max,trainer/policy/std Mean,trainer/policy/std Min,trainer/policy/std Std,trainer/replay_buffer_len,trainer/rewards Max,trainer/rewards Mean,trainer/rewards Min,trainer/rewards Std,trainer/terminals Max,trainer/terminals Mean,trainer/terminals Min,trainer/terminals Std --2,0,0,0.13717124,0.03626487,-0.049319427,0.045948837,-0.19655459402248504,1,-0.19655459402248504,-0.19655459402248504,-0.19655459402248504,0.0,-0.09325349516403507,-0.09827729701124252,-0.10330109885844999,0.005023801847207458,1,2,2,2.0,2,0.0,998999,0.2963569164276123,0.047545433044433594,0.017406940460205078,32.431710720062256,2.86102294921875e-06,0.05810999870300293,0.1723041534423828,-0.0036168185,-0.008904904,-0.01419299,0.005288086,0.9892082,0.9737615,0.9583148,0.0154467225,857.7699,1.0095485,0.91235673,0.8151649,0.09719181,-3.695485e-05,-0.0023042315,-0.004571508,0.0022672766,0.0,0.0,0.0,0.0,0.8464968,0.84184104,0.01419299,0.007483647,0.0007743044,0.006709343,3.2178352e-05,2,0.004211932,0.0006116896,-0.0027325656,0.0019366256,0.049787067,0.049787063,0.049787067,3.7252903e-09,998999,1.0072821,0.90650094,0.80571973,0.1007812,0.0,0.0,0.0,0.0 --1,0,0,0.05840967,-0.0072007906,-0.09415653,0.046789993,0.5074781763310496,1,0.5074781763310496,0.5074781763310496,0.5074781763310496,0.0,0.2553473427685612,0.2537390881655248,0.2521308335624885,0.0016082546030363465,2,4,2,2.0,2,0.0,998999,0.14186596870422363,0.0059931278228759766,0.013265132904052734,32.5822057723999,2.1457672119140625e-06,0.012963533401489258,0.10884451866149902,0.0034906738,0.0020773786,0.0006640834,0.0014132953,1.010527,1.0062606,1.0019943,0.0042663813,874.27075,0.8171325,0.81321764,0.80930287,0.0039148033,0.00060014235,-0.0017158713,-0.004031885,0.0023160137,0.0006001451,0.0006001451,0.0006001451,0.0,0.66411924,0.66036254,-0.0006595902,-0.0020728854,-0.0034861807,0.0014132953,4.4190338e-06,4,0.008797872,-0.01306281,-0.01957312,0.0073613864,0.049831923,0.049831625,0.049831126,3.505173e-07,998999,0.8163113,0.8157928,0.81527424,0.0005185306,0.0,0.0,0.0,0.0 +-2,0,0,0.13717124,0.03626487,-0.049319427,0.045948837,-0.19655459402248504,1,-0.19655459402248504,-0.19655459402248504,-0.19655459402248504,0.0,-0.09325349516403507,-0.09827729701124252,-0.10330109885844999,0.005023801847207458,1,2,2,2.0,2,0.0,998999,0.7398281097412109,0.022743940353393555,0.0016701221466064453,20.234392166137695,1.9073486328125e-06,0.008835077285766602,0.7058250904083252,-0.0036168185,-0.008904904,-0.01419299,0.005288086,0.9892082,0.9737615,0.9583148,0.0154467225,857.7699,1.0095485,0.91235673,0.8151649,0.09719181,-3.695485e-05,-0.0023042315,-0.004571508,0.0022672766,0.0,0.0,0.0,0.0,0.8464968,0.84184104,0.01419299,0.007483647,0.0007743044,0.006709343,3.2178352e-05,100,0.004211932,0.0006116896,-0.0027325656,0.0019366256,0.049787067,0.049787063,0.049787067,3.7252903e-09,998999,1.0072821,0.90650094,0.80571973,0.1007812,0.0,0.0,0.0,0.0 +-1,0,0,0.049955986,-0.051808134,-0.24854904,0.08398724,0.4338117202694962,1,0.4338117202694962,0.4338117202694962,0.4338117202694962,0.0,0.24300404877931653,0.2169058601347481,0.19080767149017966,0.026098188644568435,2,4,2,2.0,2,0.0,998999,0.6178200244903564,0.0024394989013671875,0.002116680145263672,20.855294704437256,1.6689300537109375e-06,0.008217334747314453,0.6044015884399414,-0.004432007,-0.0058829254,-0.0073338435,0.0014509181,0.98679197,0.98251534,0.97823876,0.0042766035,440.564,1.0878772,0.8795528,0.6712284,0.20832437,0.07119025,0.055561554,0.039932854,0.0156287,0.06366714,0.046736993,0.029806845,0.016930148,0.7160932,0.7302138,0.013697105,0.0122417575,0.0107864095,0.001455348,1.1014193e-05,200,-0.17424563,-0.48487774,-0.8831331,0.2023349,0.051678017,0.05148844,0.051258426,0.00012519717,998999,1.0813023,0.8714981,0.661694,0.20980415,0.0,0.0,0.0,0.0 diff --git a/tests/regression/iql/halfcheetah_online_progress.csv b/tests/regression/iql/halfcheetah_online_progress.csv index e5b768efb..f8b1df354 100644 --- a/tests/regression/iql/halfcheetah_online_progress.csv +++ b/tests/regression/iql/halfcheetah_online_progress.csv @@ -1,5 +1,5 @@ -Epoch,epoch,eval/num paths total,eval/num steps total,expl/Actions Max,expl/Actions Mean,expl/Actions Min,expl/Actions Std,expl/Average Returns,expl/Num Paths,expl/Returns Max,expl/Returns Mean,expl/Returns Min,expl/Returns Std,expl/Rewards Max,expl/Rewards Mean,expl/Rewards Min,expl/Rewards Std,expl/env_infos/final/reward_ctrl Max,expl/env_infos/final/reward_ctrl Mean,expl/env_infos/final/reward_ctrl Min,expl/env_infos/final/reward_ctrl Std,expl/env_infos/final/reward_run Max,expl/env_infos/final/reward_run Mean,expl/env_infos/final/reward_run Min,expl/env_infos/final/reward_run Std,expl/env_infos/initial/reward_ctrl Max,expl/env_infos/initial/reward_ctrl Mean,expl/env_infos/initial/reward_ctrl Min,expl/env_infos/initial/reward_ctrl Std,expl/env_infos/initial/reward_run Max,expl/env_infos/initial/reward_run Mean,expl/env_infos/initial/reward_run Min,expl/env_infos/initial/reward_run Std,expl/env_infos/reward_ctrl Max,expl/env_infos/reward_ctrl Mean,expl/env_infos/reward_ctrl Min,expl/env_infos/reward_ctrl Std,expl/env_infos/reward_run Max,expl/env_infos/reward_run Mean,expl/env_infos/reward_run Min,expl/env_infos/reward_run Std,expl/num paths total,expl/num steps total,expl/path length Max,expl/path length Mean,expl/path length Min,expl/path length Std,replay_buffer/size,time/data storing (s),time/epoch (s),time/evaluation sampling (s),time/exploration sampling (s),time/logging (s),time/saving (s),time/total (s),time/training (s),trainer/Advantage Score Max,trainer/Advantage Score Mean,trainer/Advantage Score Min,trainer/Advantage Score Std,trainer/Advantage Weights Max,trainer/Advantage Weights Mean,trainer/Advantage Weights Min,trainer/Advantage Weights Std,trainer/Log Pis Max,trainer/Log Pis Mean,trainer/Log Pis Min,trainer/Log Pis Std,trainer/Policy Loss,trainer/Q Targets Max,trainer/Q Targets Mean,trainer/Q Targets Min,trainer/Q Targets Std,trainer/Q1 Predictions Max,trainer/Q1 Predictions Mean,trainer/Q1 Predictions Min,trainer/Q1 Predictions Std,trainer/Q2 Predictions Max,trainer/Q2 Predictions Mean,trainer/Q2 Predictions Min,trainer/Q2 Predictions Std,trainer/QF1 Loss,trainer/QF2 Loss,trainer/V1 Predictions Max,trainer/V1 Predictions Mean,trainer/V1 Predictions Min,trainer/V1 Predictions Std,trainer/VF Loss,trainer/num train calls,trainer/policy/mean Max,trainer/policy/mean Mean,trainer/policy/mean Min,trainer/policy/mean Std,trainer/policy/std Max,trainer/policy/std Mean,trainer/policy/std Min,trainer/policy/std Std,trainer/replay_buffer_len,trainer/rewards Max,trainer/rewards Mean,trainer/rewards Min,trainer/rewards Std,trainer/terminals Max,trainer/terminals Mean,trainer/terminals Min,trainer/terminals Std --2,-2,0,0,0.13717124,0.03626487,-0.049319427,0.045948837,-0.19655459402248504,1,-0.19655459402248504,-0.19655459402248504,-0.19655459402248504,0.0,-0.09325349516403507,-0.09827729701124252,-0.10330109885844999,0.005023801847207458,-0.0011511136777698995,-0.0011511136777698995,-0.0011511136777698995,0.0,-0.10214998518068008,-0.10214998518068008,-0.10214998518068008,0.0,-0.0029606109485030177,-0.0029606109485030177,-0.0029606109485030177,0.0,-0.09029288421553205,-0.09029288421553205,-0.09029288421553205,0.0,-0.0011511136777698995,-0.0020558623131364585,-0.0029606109485030177,0.0009047486353665591,-0.09029288421553205,-0.09622143469810607,-0.10214998518068008,0.005928550482574013,1,2,2,2.0,2,0.0,998999,2.0239967852830887e-06,0.14877367997542024,0.06630018298164941,0.0016109530115500093,0.002452037006150931,0.011688223981764168,29.101119285012828,0.06672025899752043,-0.0036168185,-0.008904904,-0.01419299,0.005288086,0.9892082,0.9737615,0.9583148,0.0154467225,10.170204,9.836491,9.502777,0.33371353,857.7699,1.0095485,0.91235673,0.8151649,0.09719181,-3.695485e-05,-0.0023042315,-0.004571508,0.0022672766,0.0,0.0,0.0,0.0,0.8464968,0.84184104,0.01419299,0.007483647,0.0007743044,0.006709343,3.2178352e-05,2,0.004211932,0.0006116896,-0.0027325656,0.0019366256,0.049787067,0.049787063,0.049787067,3.7252903e-09,998999,1.0072821,0.90650094,0.80571973,0.1007812,0.0,0.0,0.0,0.0 --1,-1,0,0,0.052264355,-0.0057546324,-0.06404259,0.028885532,0.4706610319223018,1,0.4706610319223018,0.4706610319223018,0.4706610319223018,0.0,0.26311454589158256,0.2353305159611509,0.20754648603071926,0.02778402993043165,-0.0005197156686335803,-0.0005197156686335803,-0.0005197156686335803,0.0,0.26363426156021613,0.26363426156021613,0.26363426156021613,0.0,-0.000521271862089634,-0.000521271862089634,-0.000521271862089634,0.0,0.2080677578928089,0.2080677578928089,0.2080677578928089,0.0,-0.0005197156686335803,-0.0005204937653616071,-0.000521271862089634,7.780967280268452e-07,0.26363426156021613,0.23585100972651252,0.2080677578928089,0.027783251833703615,2,4,2,2.0,2,0.0,998999,1.9890139810740948e-06,0.045404493022942916,0.0032403480145148933,0.0024896099930629134,0.0013084060046821833,0.011234880017582327,29.15082311502192,0.027129259979119524,0.0034906738,0.0020773786,0.0006640834,0.0014132953,1.010527,1.0062606,1.0019943,0.0042663813,9.529525,8.669807,7.810091,0.8597169,874.27075,0.8171325,0.81321764,0.80930287,0.0039148033,0.00060014235,-0.0017158713,-0.004031885,0.0023160137,0.0006001451,0.0006001451,0.0006001451,0.0,0.66411924,0.66036254,-0.0006595902,-0.0020728854,-0.0034861807,0.0014132953,4.4190338e-06,4,0.008797872,-0.01306281,-0.01957312,0.0073613864,0.049831923,0.049831625,0.049831126,3.505173e-07,998999,0.8163113,0.8157928,0.81527424,0.0005185306,0.0,0.0,0.0,0.0 -0,0,0,0,0.08386825,-0.0078114443,-0.08241301,0.049048368,-0.13273116819990327,1,-0.13273116819990327,-0.13273116819990327,-0.13273116819990327,0.0,-0.04573715315183351,-0.06636558409995164,-0.08699401504806978,0.020628430948118136,-0.0011722307652235033,-0.0011722307652235033,-0.0011722307652235033,0.0,-0.04456492238661,-0.04456492238661,-0.04456492238661,0.0,-0.0017878822982311249,-0.0017878822982311249,-0.0017878822982311249,0.0,-0.08520613274983865,-0.08520613274983865,-0.08520613274983865,0.0,-0.0011722307652235033,-0.001480056531727314,-0.0017878822982311249,0.0003078257665038108,-0.04456492238661,-0.06488552756822433,-0.08520613274983865,0.020320605181614326,4,8,2,2.0,2,0.0,999003,5.4203992476686835e-05,0.052629252983024344,0.003310469997813925,0.0013926989922765642,0.0019216470245737582,0.009531075978884473,29.206655308022164,0.036419156996998936,0.0002089516,-0.0074295807,-0.015068113,0.007638532,1.000627,0.9782146,0.95580214,0.02241245,10.203518,9.326001,8.448484,0.87751675,749.7348,0.8889162,0.6575853,0.42625445,0.23133087,0.0011964651,-0.0006333184,-0.0024631019,0.0018297834,0.0011964955,0.0011964955,0.0011964955,0.0,0.48761567,0.48436025,0.015083013,0.007444481,-0.00019405119,0.007638532,3.4072487e-05,6,0.00570123,-0.021549514,-0.041000422,0.013357697,0.049876567,0.04987557,0.049874462,7.828489e-07,999003,0.882376,0.65433013,0.4262843,0.22804585,0.0,0.0,0.0,0.0 -1,1,0,0,0.076388456,0.0012027482,-0.039100315,0.034910962,0.02090291014717436,1,0.02090291014717436,0.02090291014717436,0.02090291014717436,0.0,0.012212730127035254,0.01045145507358718,0.008690180020139104,0.0017612750534480746,-0.0009377333335578442,-0.0009377333335578442,-0.0009377333335578442,0.0,0.009627913353696949,0.009627913353696949,0.009627913353696949,0.0,-0.0005265331361442805,-0.0005265331361442805,-0.0005265331361442805,0.0,0.012739263263179534,0.012739263263179534,0.012739263263179534,0.0,-0.0005265331361442805,-0.0007321332348510623,-0.0009377333335578442,0.00020560009870678185,0.012739263263179534,0.011183588308438241,0.009627913353696949,0.001555674954741293,5,10,2,2.0,2,0.0,999005,5.293299909681082e-05,0.04207111700088717,0.001639438996789977,0.0014825899852439761,0.0018372160266153514,0.012364558991976082,29.253537037031492,0.024694380001164973,-0.0025933185,-0.008260839,-0.013928359,0.00566752,0.9922502,0.97566307,0.95907587,0.016587168,9.920778,9.850875,9.780972,0.06990337,735.0451,0.9276143,0.85773814,0.78786194,0.069876164,0.0059791873,0.0038828503,0.0017865131,0.002096337,0.0017865759,0.0017865759,0.0017865759,0.0,0.7342488,0.73753566,0.010155345,0.00638988,0.0026244146,0.0037654655,3.0108675e-05,8,0.026928315,-0.027511992,-0.05553205,0.02637993,0.049919788,0.049918402,0.049915932,1.4599646e-06,999005,0.92533284,0.8552668,0.7852007,0.070066065,0.0,0.0,0.0,0.0 +epoch,eval/num paths total,eval/num steps total,expl/Actions Max,expl/Actions Mean,expl/Actions Min,expl/Actions Std,expl/Average Returns,expl/Num Paths,expl/Returns Max,expl/Returns Mean,expl/Returns Min,expl/Returns Std,expl/Rewards Max,expl/Rewards Mean,expl/Rewards Min,expl/Rewards Std,expl/num paths total,expl/num steps total,expl/path length Max,expl/path length Mean,expl/path length Min,expl/path length Std,replay_buffer/size,time/epoch_time (s),time/evaluation sampling (s),time/exploration sampling (s),time/global_time (s),time/replay buffer data storing (s),time/saving (s),time/training (s),trainer/Advantage Score Max,trainer/Advantage Score Mean,trainer/Advantage Score Min,trainer/Advantage Score Std,trainer/Advantage Weights Max,trainer/Advantage Weights Mean,trainer/Advantage Weights Min,trainer/Advantage Weights Std,trainer/Policy Loss,trainer/Q Targets Max,trainer/Q Targets Mean,trainer/Q Targets Min,trainer/Q Targets Std,trainer/Q1 Predictions Max,trainer/Q1 Predictions Mean,trainer/Q1 Predictions Min,trainer/Q1 Predictions Std,trainer/Q2 Predictions Max,trainer/Q2 Predictions Mean,trainer/Q2 Predictions Min,trainer/Q2 Predictions Std,trainer/QF1 Loss,trainer/QF2 Loss,trainer/V1 Predictions Max,trainer/V1 Predictions Mean,trainer/V1 Predictions Min,trainer/V1 Predictions Std,trainer/VF Loss,trainer/num train calls,trainer/policy/mean Max,trainer/policy/mean Mean,trainer/policy/mean Min,trainer/policy/mean Std,trainer/policy/std Max,trainer/policy/std Mean,trainer/policy/std Min,trainer/policy/std Std,trainer/replay_buffer_len,trainer/rewards Max,trainer/rewards Mean,trainer/rewards Min,trainer/rewards Std,trainer/terminals Max,trainer/terminals Mean,trainer/terminals Min,trainer/terminals Std +-2,0,0,0.13717124,0.03626487,-0.049319427,0.045948837,-0.19655459402248504,1,-0.19655459402248504,-0.19655459402248504,-0.19655459402248504,0.0,-0.09325349516403507,-0.09827729701124252,-0.10330109885844999,0.005023801847207458,1,2,2,2.0,2,0.0,998999,0.7535736560821533,0.01971721649169922,0.0017578601837158203,20.90360975265503,2.384185791015625e-06,0.008637666702270508,0.7226138114929199,-0.0036168185,-0.008904904,-0.01419299,0.005288086,0.9892082,0.9737615,0.9583148,0.0154467225,857.7699,1.0095485,0.91235673,0.8151649,0.09719181,-3.695485e-05,-0.0023042315,-0.004571508,0.0022672766,0.0,0.0,0.0,0.0,0.8464968,0.84184104,0.01419299,0.007483647,0.0007743044,0.006709343,3.2178352e-05,100,0.004211932,0.0006116896,-0.0027325656,0.0019366256,0.049787067,0.049787063,0.049787067,3.7252903e-09,998999,1.0072821,0.90650094,0.80571973,0.1007812,0.0,0.0,0.0,0.0 +-1,0,0,0.049955986,-0.051808134,-0.24854904,0.08398724,0.4338117202694962,1,0.4338117202694962,0.4338117202694962,0.4338117202694962,0.0,0.24300404877931653,0.2169058601347481,0.19080767149017966,0.026098188644568435,2,4,2,2.0,2,0.0,998999,0.6934354305267334,0.0024938583374023438,0.002553224563598633,21.599952220916748,1.9073486328125e-06,0.00806879997253418,0.6797149181365967,-0.004432007,-0.0058829254,-0.0073338435,0.0014509181,0.98679197,0.98251534,0.97823876,0.0042766035,440.564,1.0878772,0.8795528,0.6712284,0.20832437,0.07119025,0.055561554,0.039932854,0.0156287,0.06366714,0.046736993,0.029806845,0.016930148,0.7160932,0.7302138,0.013697105,0.0122417575,0.0107864095,0.001455348,1.1014193e-05,200,-0.17424563,-0.48487774,-0.8831331,0.2023349,0.051678017,0.05148844,0.051258426,0.00012519717,998999,1.0813023,0.8714981,0.661694,0.20980415,0.0,0.0,0.0,0.0 +0,0,0,0.12498924,-0.07945812,-0.43564224,0.1438326,-0.22196420100855407,1,-0.22196420100855407,-0.22196420100855407,-0.22196420100855407,0.0,-0.06626331815569395,-0.11098210050427704,-0.15570088285286013,0.04471878234858309,4,8,2,2.0,2,0.0,999003,0.6045346260070801,0.0011963844299316406,0.0011255741119384766,22.207528352737427,3.814697265625e-05,0.010853290557861328,0.5892090797424316,-0.0010308158,-0.0048863282,-0.008741841,0.0038555125,0.9969123,0.9855138,0.9741154,0.011398464,180.96556,0.9284907,0.8786598,0.8288289,0.049830914,0.1675013,0.12049564,0.07348998,0.04700566,0.06015433,0.059967212,0.0597801,0.00018711574,0.5748209,0.67272204,0.030508472,0.026754959,0.023001445,0.003753513,1.1622354e-05,300,0.3537424,-0.43678454,-0.9513559,0.35695845,0.053244345,0.052891567,0.05234169,0.0002728698,999003,0.91071784,0.85639596,0.80207413,0.054321855,0.0,0.0,0.0,0.0 +1,0,0,0.11573008,-0.044602025,-0.21003142,0.09092906,-0.10424991824240865,1,-0.10424991824240865,-0.10424991824240865,-0.10424991824240865,0.0,-0.010003530794350048,-0.05212495912120432,-0.0942463874480586,0.04212142832685428,5,10,2,2.0,2,0.0,999005,0.6860871315002441,0.0013425350189208984,0.0012538433074951172,22.896247386932373,3.790855407714844e-05,0.007896184921264648,0.6749513149261475,0.004231114,0.0035969168,0.0029627196,0.00063419715,1.0127742,1.010851,1.0089278,0.0019232035,123.68506,0.8583293,0.8315314,0.80473346,0.02679792,0.33825877,0.21806143,0.097864084,0.12019734,0.09187969,0.089366615,0.086853534,0.002513077,0.39795297,0.5513983,0.040352434,0.039215915,0.038079392,0.0011365209,9.338011e-06,400,0.40095848,-0.35091817,-0.9745383,0.52953124,0.054846372,0.054164585,0.053079303,0.00054456014,999005,0.81591684,0.78952396,0.7631311,0.026392877,0.0,0.0,0.0,0.0 diff --git a/tests/regression/iql/test_iql_offline.py b/tests/regression/iql/test_iql_offline.py index b78954ea6..74f9eb303 100644 --- a/tests/regression/iql/test_iql_offline.py +++ b/tests/regression/iql/test_iql_offline.py @@ -15,7 +15,7 @@ def test_iql(): iql.variant["algo_kwargs"]["batch_size"] = 2 iql.variant["algo_kwargs"]["num_eval_steps_per_epoch"] = 2 iql.variant["algo_kwargs"]["num_expl_steps_per_train_loop"] = 2 - iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 2 + iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 100 iql.variant["algo_kwargs"]["min_num_steps_before_training"] = 2 iql.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2]) diff --git a/tests/regression/iql/test_iql_online.py b/tests/regression/iql/test_iql_online.py index 99a4d5be3..68d85af19 100644 --- a/tests/regression/iql/test_iql_online.py +++ b/tests/regression/iql/test_iql_online.py @@ -15,7 +15,7 @@ def test_iql(): iql.variant["algo_kwargs"]["batch_size"] = 2 iql.variant["algo_kwargs"]["num_eval_steps_per_epoch"] = 2 iql.variant["algo_kwargs"]["num_expl_steps_per_train_loop"] = 2 - iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 2 + iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 100 iql.variant["algo_kwargs"]["min_num_steps_before_training"] = 2 iql.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2])