diff --git a/.coveragerc b/.coveragerc index 103964b..d218808 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,12 +1,12 @@ # .coveragerc to control coverage.py [run] branch = True -source = equiadapt +source = /src/equiadapt/ # omit = bad_file.py [paths] source = - equiadapt/ + /src/equiadapt/ */site-packages/ [report] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c8a2..a8cb9a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,8 +56,7 @@ repos: rev: 7.0.0 hooks: - id: flake8 - ## You can add flake8 plugins via `additional_dependencies`: - # additional_dependencies: [flake8-bugbear] + additional_dependencies: [flake8-docstrings] ## Check for misspells in documentation files: # - repo: https://github.com/codespell-project/codespell diff --git a/docs/conf.py b/docs/conf.py index 7249aa9..758fcb3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -176,10 +176,7 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -html_theme_options = { - "sidebar_width": "300px", - "page_width": "1200px" -} +html_theme_options = {"sidebar_width": "300px", "page_width": "1200px"} # Add any paths that contain custom themes here, relative to this directory. # html_theme_path = [] @@ -301,4 +298,4 @@ "pyscaffold": ("https://pyscaffold.org/en/stable", None), } -print(f"loading configurations for {project} {version} ...", file=sys.stderr) \ No newline at end of file +print(f"loading configurations for {project} {version} ...", file=sys.stderr) diff --git a/equiadapt/images/canonicalization/continuous_group.py b/equiadapt/images/canonicalization/continuous_group.py index fb47ad7..03d6816 100644 --- a/equiadapt/images/canonicalization/continuous_group.py +++ b/equiadapt/images/canonicalization/continuous_group.py @@ -1,6 +1,6 @@ import math from omegaconf import DictConfig -from typing import Any, List, Tuple, Union, Optional, Dict +from typing import Optional, Dict, Any, Union, Tuple, List import kornia as K import torch @@ -250,9 +250,10 @@ def get_groupelement(self, x: torch.Tensor) -> dict: if not hasattr(self, "canonicalization_info_dict"): self.canonicalization_info_dict = {} - group_element_dict, group_element_representation = ( - self.get_group_from_out_vectors(out_vectors) - ) + ( + group_element_dict, + group_element_representation, + ) = self.get_group_from_out_vectors(out_vectors) self.canonicalization_info_dict["group_element_matrix_representation"] = ( group_element_representation ) @@ -384,9 +385,10 @@ def get_groupelement(self, x: torch.Tensor) -> dict: if not hasattr(self, "canonicalization_info_dict"): self.canonicalization_info_dict = {} - group_element_dict, group_element_representations = ( - self.get_group_from_out_vectors(out_vectors) - ) + ( + group_element_dict, + group_element_representations, + ) = self.get_group_from_out_vectors(out_vectors) # Store the matrix representation of the group element for regularization and identity metric self.canonicalization_info_dict["group_element_matrix_representation"] = ( group_element_representations diff --git a/equiadapt/pointcloud/canonicalization/continuous_group.py b/equiadapt/pointcloud/canonicalization/continuous_group.py index d47fd53..fb6146d 100644 --- a/equiadapt/pointcloud/canonicalization/continuous_group.py +++ b/equiadapt/pointcloud/canonicalization/continuous_group.py @@ -3,6 +3,7 @@ from omegaconf import DictConfig import torch + from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization from equiadapt.common.utils import gram_schmidt from typing import Any, List, Tuple, Union, Optional diff --git a/examples/images/segmentation/model_utils.py b/examples/images/segmentation/model_utils.py index 4fe1348..3aa1aa9 100644 --- a/examples/images/segmentation/model_utils.py +++ b/examples/images/segmentation/model_utils.py @@ -89,7 +89,6 @@ def forward(self, images, targets): class SAMModel(nn.Module): - def __init__( self, architecture_type: str, @@ -148,7 +147,6 @@ def forward(self, images, targets): class FocalLoss(nn.Module): - def __init__(self, weight=None, size_average=True): super().__init__() self.name = "focal_loss" @@ -169,7 +167,6 @@ def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): class DiceLoss(nn.Module): - def __init__(self, weight=None, size_average=True): super().__init__() self.name = "dice_loss" diff --git a/examples/images/segmentation/prepare/coco_data.py b/examples/images/segmentation/prepare/coco_data.py index 09f376b..8470aff 100644 --- a/examples/images/segmentation/prepare/coco_data.py +++ b/examples/images/segmentation/prepare/coco_data.py @@ -12,7 +12,6 @@ class ResizeAndPad: - def __init__(self, target_size): self.target_size = target_size self.transform = ResizeLongestSide(target_size) @@ -126,7 +125,6 @@ def test_dataloader(self): class COCODataset(Dataset): - def __init__(self, root_dir, annotation_file, transform=None, sam_transform=None): self.root_dir = root_dir self.transform = transform diff --git a/setup.cfg b/setup.cfg index 7f72280..6f245d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ install_requires = torchvision kornia escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep + hydra-core [options.packages.find] exclude = @@ -95,7 +96,7 @@ formats = bdist_wheel [flake8] # Some sane defaults for the code style checker flake8 max_line_length = 130 -extend_ignore = E203, W503, E401, E501, E741, E266 +extend_ignore = E203, W503, E401, E501, E741, E266, D100, D107, D400, D401 # ^ Black-compatible # E203 and W503 have edge cases handled by black exclude = diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils.py b/tests/common/test_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/common/test_utils.py diff --git a/tests/images/__init__.py b/tests/images/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/images/canonicalization/__init__.py b/tests/images/canonicalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/images/canonicalization/test_continuous_group.py b/tests/images/canonicalization/test_continuous_group.py new file mode 100644 index 0000000..9f22429 --- /dev/null +++ b/tests/images/canonicalization/test_continuous_group.py @@ -0,0 +1,75 @@ +from unittest.mock import Mock + +import pytest +import torch +from omegaconf import DictConfig + +from equiadapt import ( + ContinuousGroupImageCanonicalization, +) # Update with your actual import path + + +@pytest.fixture +def sample_input(): + # Create a sample input tensor + return torch.rand((1, 3, 64, 64)) # A batch with one color image of size 64x64 + + +@pytest.fixture +def grayscale_input(): + # Create a grayscale input tensor + return torch.rand((1, 1, 64, 64)) # A batch with one grayscale image of size 64x64 + + +@pytest.fixture +def init_args(): + # Mock initialization arguments + canonicalization_hyperparams = DictConfig( + { + "input_crop_ratio": 0.9, + "resize_shape": (32, 32), + } + ) + return { + "canonicalization_network": torch.nn.Identity(), # Placeholder + "canonicalization_hyperparams": canonicalization_hyperparams, + "in_shape": (3, 64, 64), + } + + +def test_initialization(init_args): + cgic = ContinuousGroupImageCanonicalization(**init_args) + assert cgic.pad is not None, "Pad should be initialized." + assert cgic.crop is not None, "Crop should be initialized." + + +def test_transformation_before_canonicalization_network_forward( + sample_input, init_args +): + cgic = ContinuousGroupImageCanonicalization(**init_args) + transformed = cgic.transformations_before_canonicalization_network_forward( + sample_input + ) + assert transformed.size() == torch.Size( + [1, 3, 32, 32] + ), "The transformed image should be resized to (32, 32)." + + +@pytest.fixture +def canonicalization_instance(): + instance = ContinuousGroupImageCanonicalization( + canonicalization_network=Mock(), + canonicalization_hyperparams={ + "input_crop_ratio": 0.9, + "resize_shape": (32, 32), + }, + in_shape=(3, 64, 64), + ) + # Mocking the get_groupelement method to return a fixed group element + instance.get_groupelement = Mock( + return_value={ + "rotation": torch.eye(2).unsqueeze(0), + "reflection": torch.tensor([[[[0]]]]), + } + ) + return instance diff --git a/tox.ini b/tox.ini index 69f8159..9fa8a21 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ extras = testing commands = pytest {posargs} +usedevelop=True # # To run `tox -e lint` you need to make sure you have a