From b734d4ea99fe925d222b6a2b6a7214b0079b70e2 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 5 May 2024 22:29:58 +0900 Subject: [PATCH] Lower atol threshold for DT's save_policy test --- tests/algos/transformer/algo_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/algos/transformer/algo_test.py b/tests/algos/transformer/algo_test.py index f3ce91c1..0f2dfc18 100644 --- a/tests/algos/transformer/algo_test.py +++ b/tests/algos/transformer/algo_test.py @@ -295,7 +295,7 @@ def save_policy_tester( if algo.get_action_type() == ActionSpace.DISCRETE: assert action == algo.predict(inpt).argmax() else: - assert np.allclose(action, algo.predict(inpt), atol=1e-4) + assert np.allclose(action, algo.predict(inpt), atol=1e-3) # check save_policy as ONNX algo.save_policy(os.path.join("test_data", "model.onnx")) @@ -319,4 +319,4 @@ def save_policy_tester( if algo.get_action_type() == ActionSpace.DISCRETE: assert action == algo.predict(inpt).argmax() else: - assert np.allclose(action, algo.predict(inpt), atol=1e-4) + assert np.allclose(action, algo.predict(inpt), atol=1e-3)