
Commit

fixed the bug in the discrete optimization code
arnab39 committed Feb 20, 2024
1 parent 7121009 commit c42bb7b
Showing 5 changed files with 1,010 additions and 262 deletions.
3 changes: 2 additions & 1 deletion equiadapt/common/basecanonicalization.py
@@ -177,7 +177,8 @@ def get_prior_regularization_loss(self):

def get_identity_metric(self):
group_elements_rep = self.canonicalization_info_dict['group_element_matrix_representation']
- identity_element = torch.eye(group_elements_rep.shape[-1]).to(self.device)
+ identity_element = torch.eye(group_elements_rep.shape[-1]).repeat(
+     group_elements_rep.shape[0], 1, 1).to(self.device)
return 1.0 - torch.nn.functional.mse_loss(group_elements_rep, identity_element).mean()
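The fix here is the batched identity: repeating torch.eye across the batch dimension gives the target the same (batch, n, n) shape as the group-element representations, so mse_loss compares matching shapes instead of relying on broadcasting. A minimal sketch with hypothetical tensors (not the repository's test code):

import torch
import torch.nn.functional as F

# Pretend batch of 4 group elements that are all exactly the identity.
group_elements_rep = torch.eye(3).repeat(4, 1, 1)
# Batched identity target, mirroring the fixed line above.
identity_element = torch.eye(group_elements_rep.shape[-1]).repeat(
    group_elements_rep.shape[0], 1, 1)
metric = 1.0 - F.mse_loss(group_elements_rep, identity_element).mean()
print(metric)  # tensor(1.) when every group element is exactly the identity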


19 changes: 18 additions & 1 deletion equiadapt/images/canonicalization/continuous_group.py
@@ -15,8 +15,10 @@ def __init__(self,
):
super().__init__(canonicalization_network)

assert len(in_shape) == 3, 'Input shape should be in the format (channels, height, width)'

# pad and crop the input image if it is not rotated MNIST
- is_grayscale = in_shape[0] == 1
+ is_grayscale = (in_shape[0] == 1)
self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad(
math.ceil(in_shape[-1] * 0.5), padding_mode='edge'
)
@@ -25,6 +27,7 @@ def __init__(self,
math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio),
math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio)
))
self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
self.group_info_dict = {}

def get_groupelement(self, x: torch.Tensor):
@@ -40,6 +43,15 @@ def get_groupelement(self, x: torch.Tensor):
"""
raise NotImplementedError('get_groupelement method is not implemented')

def transformations_before_canonicalization_network_forward(self, x: torch.Tensor):
"""
This method takes an image as input and
returns the pre-canonicalized image
"""
x = self.crop_canonization(x)
x = self.resize_canonization(x)
return x
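The new helper crops and then resizes the input before it reaches the canonicalization network. A rough sketch of that path with hypothetical sizes, using torchvision's CenterCrop and Resize as stand-ins for the transforms configured in __init__ (both are Identity for grayscale inputs such as rotated MNIST):

import torch
from torchvision import transforms

crop_canonization = transforms.CenterCrop(48)   # stand-in for the ratio-based crop
resize_canonization = transforms.Resize(size=32)

x = torch.randn(8, 3, 64, 64)                   # hypothetical RGB batch
x = crop_canonization(x)
x = resize_canonization(x)
print(x.shape)                                  # torch.Size([8, 3, 32, 32])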

def get_group_from_out_vectors(self, out_vectors: torch.Tensor):
"""
This method takes the output of the canonicalization network and
@@ -100,6 +112,7 @@ def canonicalize(self, x: torch.Tensor):
group_element_dict = self.get_groupelement(x)

rotation_matrices = group_element_dict['rotation']
rotation_matrices[:, [0, 1], [1, 0]] *= -1

if 'reflection' in group_element_dict:
reflect_indicator = group_element_dict['reflection']
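The rotation-matrix line added in this hunk negates the two off-diagonal entries of each predicted rotation matrix, which turns R(theta) into its transpose R(-theta), i.e. the image is then warped by the inverse of the predicted rotation. A quick sanity-check sketch (hypothetical tensors, not repository code):

import torch

theta = torch.tensor(0.3)
R = torch.stack([
    torch.stack([torch.cos(theta), -torch.sin(theta)]),
    torch.stack([torch.sin(theta),  torch.cos(theta)]),
])[None]                                    # (1, 2, 2): a "batch" of one rotation matrix

R_flipped = R.clone()
R_flipped[:, [0, 1], [1, 0]] *= -1          # same indexing as the added line
print(torch.allclose(R_flipped, R.transpose(-1, -2)))  # True: equals R(-theta)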
@@ -179,6 +192,8 @@ def get_groupelement(self, x: torch.Tensor):

group_element_dict = {}

x = self.transformations_before_canonicalization_network_forward(x)

# convert the group activations to one hot encoding of group element
# this conversion is differentiable and will be used to select the group element
out_vectors = self.canonicalization_network(x)
@@ -282,6 +297,8 @@ def get_groupelement(self, x: torch.Tensor):

x_all = torch.cat([x, x_augmented], dim=0) # size (batch_size * 2, in_channels, height, width)

x_all = self.transformations_before_canonicalization_network_forward(x_all)

out_vectors_all = self.canonicalization_network(x_all) # size (batch_size * 2, out_vector_size)

out_vectors_all = out_vectors_all.reshape(2 * batch_size, -1, 2) # size (batch_size * 2, num_vectors, 2)
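Downstream, get_group_from_out_vectors turns these per-image 2-D vectors into group elements. Purely as an illustration of one common recipe (an assumption, not the repository's exact implementation), a pooled vector can be normalized and read off as (cos(theta), sin(theta)):

import torch

out_vectors = torch.randn(16, 4, 2)        # (batch, num_vectors, 2), hypothetical
v = out_vectors.mean(dim=1)                # pool the per-image reference vectors
v = v / v.norm(dim=-1, keepdim=True)       # unit vector: (cos(theta), sin(theta))
rotation = torch.stack([
    torch.stack([v[:, 0], -v[:, 1]], dim=-1),
    torch.stack([v[:, 1],  v[:, 0]], dim=-1),
], dim=-2)                                 # (batch, 2, 2) rotation matrices
print(rotation.shape)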
11 changes: 7 additions & 4 deletions equiadapt/images/canonicalization/discrete_group.py
@@ -13,11 +13,14 @@ def __init__(self,
in_shape: tuple
):
super().__init__(canonicalization_network)

self.beta = canonicalization_hyperparams.beta

assert len(in_shape) == 3, 'Input shape should be in the format (channels, height, width)'

# Define all the image transformations used during canonicalization here
# pad and crop the input image if it is not rotated MNIST
- is_grayscale = in_shape[0] == 1
+ is_grayscale = (in_shape[0] == 1)

self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad(
math.ceil(in_shape[-2] * 0.4), padding_mode='edge'
@@ -28,7 +31,7 @@ def __init__(self,
math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio)
))

- self.resize_canonization = torch.nn.Identity() if is_grayscale or canonicalization_hyperparams.resize_shape == in_shape[-1] else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
+ self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape)

def groupactivations_to_groupelement(self, group_activations: torch.Tensor):
"""
@@ -203,7 +206,7 @@ def rotate_and_maybe_reflect(self, x: torch.Tensor, degrees: torch.Tensor, refle
x_augmented_list = []
for degree in degrees:
x_rot = self.pad(x)
- x_rot = K.geometry.rotate(x_rot, degree)
+ x_rot = K.geometry.rotate(x_rot, -degree)
if reflect:
x_rot = K.geometry.hflip(x_rot)
x_rot = self.crop(x_rot)
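With the corrected sign, each augmented copy is produced by rotating the batch by -degree (kornia's rotate takes angles in degrees, one per image). A minimal sketch of the augmentation loop for the C4 rotation group, with the padding and cropping omitted and hypothetical shapes:

import torch
import kornia as K

x = torch.randn(8, 3, 32, 32)
degrees = torch.tensor([0.0, 90.0, 180.0, 270.0])   # C4 rotation group

x_augmented_list = []
for degree in degrees:
    angle = -degree.repeat(x.shape[0])              # one angle per image in the batch
    x_augmented_list.append(K.geometry.rotate(x, angle))

x_augmented = torch.cat(x_augmented_list, dim=0)    # (batch_size * num_group, C, H, W)
print(x_augmented.shape)                            # torch.Size([32, 3, 32, 32])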
@@ -245,7 +248,7 @@ def get_optimization_specific_loss(self):
vectors = vectors.reshape(self.num_group, -1, self.out_vector_size).permute((1, 0, 2)) # (batch_size, group_size, vector_out_size)
distances = vectors @ vectors.permute((0, 2, 1))
mask = 1.0 - torch.eye(self.num_group).to(self.device) # (group_size, group_size)
- return torch.abs(distances * mask).sum()
+ return torch.abs(distances * mask).mean()
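For reference, this loss can be reproduced in isolation: the canonicalization outputs for the group-transformed copies of each image are compared pairwise, the diagonal (self-similarity) is masked out, and the absolute off-diagonal dot products are penalized. Switching from .sum() to .mean() keeps the penalty's scale independent of batch size and group order. A sketch with hypothetical shapes:

import torch

num_group, batch_size, out_vector_size = 4, 8, 16
vectors = torch.randn(num_group * batch_size, out_vector_size)    # network outputs

vectors = vectors.reshape(num_group, batch_size, out_vector_size).permute(1, 0, 2)
distances = vectors @ vectors.permute(0, 2, 1)   # (batch, group, group) dot products
mask = 1.0 - torch.eye(num_group)                # zero out the diagonal
loss = torch.abs(distances * mask).mean()
print(loss)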



506 changes: 450 additions & 56 deletions tutorials/images/test_continuous_canonicalization.ipynb

Large diffs are not rendered by default.

733 changes: 533 additions & 200 deletions tutorials/images/test_discrete_canonicalization.ipynb

Large diffs are not rendered by default.
