diff --git a/tests/algos/transformer/algo_test.py b/tests/algos/transformer/algo_test.py index f3ce91c1..0f2dfc18 100644 --- a/tests/algos/transformer/algo_test.py +++ b/tests/algos/transformer/algo_test.py @@ -295,7 +295,7 @@ def save_policy_tester( if algo.get_action_type() == ActionSpace.DISCRETE: assert action == algo.predict(inpt).argmax() else: - assert np.allclose(action, algo.predict(inpt), atol=1e-4) + assert np.allclose(action, algo.predict(inpt), atol=1e-3) # check save_policy as ONNX algo.save_policy(os.path.join("test_data", "model.onnx")) @@ -319,4 +319,4 @@ def save_policy_tester( if algo.get_action_type() == ActionSpace.DISCRETE: assert action == algo.predict(inpt).argmax() else: - assert np.allclose(action, algo.predict(inpt), atol=1e-4) + assert np.allclose(action, algo.predict(inpt), atol=1e-3)