From c95811b7f1ff336b660d7e9149f669bdbe58d487 Mon Sep 17 00:00:00 2001 From: LutingWang <2457348692@qq.com> Date: Tue, 13 Aug 2024 23:02:18 +0800 Subject: [PATCH] fix(datasets): lvis and coco annotations can be empty --- todd/configs/serialize.py | 10 +++++++--- todd/datasets/coco.py | 8 +++++--- todd/datasets/lvis.py | 8 +++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/todd/configs/serialize.py b/todd/configs/serialize.py index 93643e8..091ce22 100644 --- a/todd/configs/serialize.py +++ b/todd/configs/serialize.py @@ -9,7 +9,7 @@ from typing_extensions import Self from ..bases.configs import Config -from ..loggers import logger +from ..loggers import master_logger class SerializeMixin(Config): @@ -27,9 +27,13 @@ def loads(cls, s: str, **kwargs) -> Self: def load(cls, file: str | pathlib.Path, **kwargs) -> Self: if len(kwargs) > 0: kwargs_str = ', '.join(f'{k}={v}' for k, v in kwargs.items()) - logger.debug("Loading config from %s with %s", file, kwargs_str) + master_logger.debug( + "Loading config from %s with %s", + file, + kwargs_str, + ) else: - logger.debug("Loading config from %s", file) + master_logger.debug("Loading config from %s", file) if isinstance(file, str): file = pathlib.Path(file) diff --git a/todd/datasets/coco.py b/todd/datasets/coco.py index 4645160..812e1b1 100644 --- a/todd/datasets/coco.py +++ b/todd/datasets/coco.py @@ -122,9 +122,11 @@ def is_crowd(self) -> torch.Tensor: @property def bboxes(self) -> 'FlattenBBoxesXYWH': from todd.tasks.object_detection import FlattenBBoxesXYWH - return FlattenBBoxesXYWH( - torch.tensor([annotation.bbox for annotation in self]), - ) + if len(self) > 0: + bboxes = torch.tensor([annotation.bbox for annotation in self]) + else: + bboxes = torch.zeros(0, 4) + return FlattenBBoxesXYWH(bboxes) @property def categories(self) -> torch.Tensor: diff --git a/todd/datasets/lvis.py b/todd/datasets/lvis.py index c9a3d4a..96ab7fd 100644 --- a/todd/datasets/lvis.py +++ b/todd/datasets/lvis.py @@ -86,9 +86,11 @@ def masks(self) -> torch.Tensor: @property def bboxes(self) -> 'FlattenBBoxesXYWH': from todd.tasks.object_detection import FlattenBBoxesXYWH - return FlattenBBoxesXYWH( - torch.tensor([annotation.bbox for annotation in self]), - ) + if len(self) > 0: + bboxes = torch.tensor([annotation.bbox for annotation in self]) + else: + bboxes = torch.zeros(0, 4) + return FlattenBBoxesXYWH(bboxes) @property def categories(self) -> torch.Tensor: