Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
arnab39 committed Mar 13, 2024
2 parents b42a401 + d23067c commit 1fa46b6
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 22 deletions.
4 changes: 2 additions & 2 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# .coveragerc to control coverage.py
[run]
branch = True
source = equiadapt
source = /src/equiadapt/
# omit = bad_file.py

[paths]
source =
equiadapt/
/src/equiadapt/
*/site-packages/

[report]
Expand Down
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ repos:
rev: 7.0.0
hooks:
- id: flake8
## You can add flake8 plugins via `additional_dependencies`:
# additional_dependencies: [flake8-bugbear]
additional_dependencies: [flake8-docstrings]

## Check for misspells in documentation files:
# - repo: https://github.com/codespell-project/codespell
Expand Down
7 changes: 2 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,7 @@
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
"sidebar_width": "300px",
"page_width": "1200px"
}
html_theme_options = {"sidebar_width": "300px", "page_width": "1200px"}

# Add any paths that contain custom themes here, relative to this directory.
# html_theme_path = []
Expand Down Expand Up @@ -301,4 +298,4 @@
"pyscaffold": ("https://pyscaffold.org/en/stable", None),
}

print(f"loading configurations for {project} {version} ...", file=sys.stderr)
print(f"loading configurations for {project} {version} ...", file=sys.stderr)
16 changes: 9 additions & 7 deletions equiadapt/images/canonicalization/continuous_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from omegaconf import DictConfig
from typing import Any, List, Tuple, Union, Optional, Dict
from typing import Optional, Dict, Any, Union, Tuple, List

import kornia as K
import torch
Expand Down Expand Up @@ -250,9 +250,10 @@ def get_groupelement(self, x: torch.Tensor) -> dict:
if not hasattr(self, "canonicalization_info_dict"):
self.canonicalization_info_dict = {}

group_element_dict, group_element_representation = (
self.get_group_from_out_vectors(out_vectors)
)
(
group_element_dict,
group_element_representation,
) = self.get_group_from_out_vectors(out_vectors)
self.canonicalization_info_dict["group_element_matrix_representation"] = (
group_element_representation
)
Expand Down Expand Up @@ -384,9 +385,10 @@ def get_groupelement(self, x: torch.Tensor) -> dict:
if not hasattr(self, "canonicalization_info_dict"):
self.canonicalization_info_dict = {}

group_element_dict, group_element_representations = (
self.get_group_from_out_vectors(out_vectors)
)
(
group_element_dict,
group_element_representations,
) = self.get_group_from_out_vectors(out_vectors)
# Store the matrix representation of the group element for regularization and identity metric
self.canonicalization_info_dict["group_element_matrix_representation"] = (
group_element_representations
Expand Down
1 change: 1 addition & 0 deletions equiadapt/pointcloud/canonicalization/continuous_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from omegaconf import DictConfig
import torch

from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization
from equiadapt.common.utils import gram_schmidt
from typing import Any, List, Tuple, Union, Optional
Expand Down
3 changes: 0 additions & 3 deletions examples/images/segmentation/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def forward(self, images, targets):


class SAMModel(nn.Module):

def __init__(
self,
architecture_type: str,
Expand Down Expand Up @@ -148,7 +147,6 @@ def forward(self, images, targets):


class FocalLoss(nn.Module):

def __init__(self, weight=None, size_average=True):
super().__init__()
self.name = "focal_loss"
Expand All @@ -169,7 +167,6 @@ def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):


class DiceLoss(nn.Module):

def __init__(self, weight=None, size_average=True):
super().__init__()
self.name = "dice_loss"
Expand Down
2 changes: 0 additions & 2 deletions examples/images/segmentation/prepare/coco_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class ResizeAndPad:

def __init__(self, target_size):
self.target_size = target_size
self.transform = ResizeLongestSide(target_size)
Expand Down Expand Up @@ -126,7 +125,6 @@ def test_dataloader(self):


class COCODataset(Dataset):

def __init__(self, root_dir, annotation_file, transform=None, sam_transform=None):
self.root_dir = root_dir
self.transform = transform
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ install_requires =
torchvision
kornia
escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep
hydra-core

[options.packages.find]
exclude =
Expand Down Expand Up @@ -95,7 +96,7 @@ formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
max_line_length = 130
extend_ignore = E203, W503, E401, E501, E741, E266
extend_ignore = E203, W503, E401, E501, E741, E266, D100, D107, D400, D401
# ^ Black-compatible
# E203 and W503 have edge cases handled by black
exclude =
Expand Down
Empty file added tests/common/__init__.py
Empty file.
File renamed without changes.
Empty file added tests/images/__init__.py
Empty file.
Empty file.
75 changes: 75 additions & 0 deletions tests/images/canonicalization/test_continuous_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from unittest.mock import Mock

import pytest
import torch
from omegaconf import DictConfig

from equiadapt import (
ContinuousGroupImageCanonicalization,
) # Update with your actual import path


@pytest.fixture
def sample_input():
# Create a sample input tensor
return torch.rand((1, 3, 64, 64)) # A batch with one color image of size 64x64


@pytest.fixture
def grayscale_input():
# Create a grayscale input tensor
return torch.rand((1, 1, 64, 64)) # A batch with one grayscale image of size 64x64


@pytest.fixture
def init_args():
# Mock initialization arguments
canonicalization_hyperparams = DictConfig(
{
"input_crop_ratio": 0.9,
"resize_shape": (32, 32),
}
)
return {
"canonicalization_network": torch.nn.Identity(), # Placeholder
"canonicalization_hyperparams": canonicalization_hyperparams,
"in_shape": (3, 64, 64),
}


def test_initialization(init_args):
cgic = ContinuousGroupImageCanonicalization(**init_args)
assert cgic.pad is not None, "Pad should be initialized."
assert cgic.crop is not None, "Crop should be initialized."


def test_transformation_before_canonicalization_network_forward(
sample_input, init_args
):
cgic = ContinuousGroupImageCanonicalization(**init_args)
transformed = cgic.transformations_before_canonicalization_network_forward(
sample_input
)
assert transformed.size() == torch.Size(
[1, 3, 32, 32]
), "The transformed image should be resized to (32, 32)."


@pytest.fixture
def canonicalization_instance():
instance = ContinuousGroupImageCanonicalization(
canonicalization_network=Mock(),
canonicalization_hyperparams={
"input_crop_ratio": 0.9,
"resize_shape": (32, 32),
},
in_shape=(3, 64, 64),
)
# Mocking the get_groupelement method to return a fixed group element
instance.get_groupelement = Mock(
return_value={
"rotation": torch.eye(2).unsqueeze(0),
"reflection": torch.tensor([[[[0]]]]),
}
)
return instance
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extras =
testing
commands =
pytest {posargs}
usedevelop=True


# # To run `tox -e lint` you need to make sure you have a
Expand Down

0 comments on commit 1fa46b6

Please sign in to comment.