From 805973e9683d93498447c84a573cdf509a14ecec Mon Sep 17 00:00:00 2001 From: carefree0910 Date: Fri, 22 Nov 2024 10:48:34 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=F0=9F=90=9BFixed=20`test=5Frecover`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_learn/test_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_learn/test_data.py b/tests/test_learn/test_data.py index f526a79..09e636c 100644 --- a/tests/test_learn/test_data.py +++ b/tests/test_learn/test_data.py @@ -179,7 +179,8 @@ def forward(self, batch): self.assertAlmostEqual(outputs.metric_outputs.final_score, 0.0) predictions = outputs.forward_results[cflearn.PREDICTIONS_KEY] np.testing.assert_allclose(predictions, np.ones_like(predictions)) - outputs = p.evaluate(loader, return_outputs=True, recover_labels=False) + raw_kw = dict(recover_labels=False, recover_predictions=False) + outputs = p.evaluate(loader, return_outputs=True, **raw_kw) self.assertAlmostEqual(outputs.metric_outputs.final_score, 0.0) predictions = outputs.forward_results[cflearn.PREDICTIONS_KEY] np.testing.assert_allclose(predictions, np.zeros_like(predictions)) @@ -190,7 +191,7 @@ def forward(self, batch): self.assertAlmostEqual(outputs.metric_outputs.final_score, 0.0) predictions = outputs.forward_results[cflearn.PREDICTIONS_KEY] np.testing.assert_allclose(predictions, np.ones_like(predictions)) - outputs = p.evaluate(loader, return_outputs=True, recover_labels=False) + outputs = p.evaluate(loader, return_outputs=True, **raw_kw) self.assertAlmostEqual(outputs.metric_outputs.final_score, 0.0) predictions = outputs.forward_results[cflearn.PREDICTIONS_KEY] np.testing.assert_allclose(predictions, np.zeros_like(predictions))