From 7033892155167c510378a1e401a246baa800d4b0 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 28 Jun 2024 11:20:14 +0800 Subject: [PATCH] fix #7832 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_pad_collation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 61c7d9c720..9d5012c9a3 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -89,7 +89,7 @@ def tearDown(self) -> None: @parameterized.expand(TESTS) def test_pad_collation(self, t_type, collate_method, transform): - if isinstance(t_type, dict): + if t_type is dict: dataset = CacheDataset(self.dict_data, transform, progress=False) else: dataset = _Dataset(self.list_data, self.list_labels, transform) @@ -104,7 +104,7 @@ def test_pad_collation(self, t_type, collate_method, transform): loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) # check collation in forward direction for data in loader: - if isinstance(t_type, dict): + if t_type is dict: shapes = [] decollated_data = decollate_batch(data) for d in decollated_data: @@ -113,7 +113,7 @@ def test_pad_collation(self, t_type, collate_method, transform): self.assertTrue(len(output["image"].applied_operations), len(dataset.transform.transforms)) self.assertTrue(len(set(shapes)) > 1) # inverted shapes must be different because of random xforms - if isinstance(t_type, dict): + if t_type is dict: batch_inverse = BatchInverseTransform(dataset.transform, loader) for data in loader: output = batch_inverse(data)