Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.1.3 #23

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions equiadapt/common/basecanonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

# Base skeleton for the canonicalization class
# DiscreteGroupCanonicalization and ContinuousGroupCanonicalization will inherit from this class
Expand Down Expand Up @@ -287,18 +288,41 @@ def invert_canonicalization(
"""
raise NotImplementedError()

def get_prior_regularization_loss(self) -> torch.Tensor:
# def get_prior_regularization_loss(self) -> torch.Tensor:
# """
# Gets the prior regularization loss.

# Returns:
# torch.Tensor: The prior regularization loss.
# """
# group_activations = self.canonicalization_info_dict["group_activations"]
# dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to(
# self.device
# )
# return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior)

def get_prior_regularization_loss(self, dataset_prior: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Gets the prior regularization loss.

    The loss is the KL divergence between ``dataset_prior`` and the softmax of the
    group activations stored in ``self.canonicalization_info_dict["group_activations"]``.

    Args:
        dataset_prior (torch.Tensor, optional): Target distribution over the group
            elements, with the same shape as the group activations. Defaults to the
            identity prior (all probability mass on index 0).

    Returns:
        torch.Tensor: The prior regularization loss (scalar, summed over group
            elements and averaged over the first dimension).
    """
    group_activations = self.canonicalization_info_dict["group_activations"]

    if dataset_prior is None:
        # Identity prior: one-hot on the first group element. zeros_like already
        # matches the dtype and device of the activations.
        dataset_prior = torch.zeros_like(group_activations)
        dataset_prior[:, 0] = 1.0
    else:
        # A caller-supplied prior must live on the same device as the activations.
        dataset_prior = dataset_prior.to(group_activations.device)

    # F.kl_div expects its input in log space and its target as probabilities.
    log_group_activations = F.log_softmax(group_activations, dim=1)

    # With a one-hot prior this equals the mean cross-entropy against class 0,
    # so the default behaviour matches the previous CrossEntropyLoss formulation.
    return F.kl_div(log_group_activations, dataset_prior, reduction="batchmean")

def get_identity_metric(self) -> torch.Tensor:
"""
Expand Down Expand Up @@ -406,6 +430,7 @@ def get_prior_regularization_loss(self) -> torch.Tensor:
.to(self.device)
)
return torch.nn.MSELoss()(group_elements_rep, dataset_prior)


def get_identity_metric(self) -> torch.Tensor:
"""
Expand Down
171 changes: 108 additions & 63 deletions equiadapt/images/canonicalization/discrete_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,64 @@ def __init__(
if is_grayscale
else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
)

# group augment specific cropping and padding (required for group_augment())
group_augment_in_shape = canonicalization_hyperparams.resize_shape
self.crop_group_augment = (
torch.nn.Identity()
if in_shape[0] == 1
else transforms.CenterCrop(group_augment_in_shape)
)
self.pad_group_augment = (
torch.nn.Identity()
if in_shape[0] == 1
else transforms.Pad(
math.ceil(group_augment_in_shape * 0.5), padding_mode="edge"
)
)

def rotate_and_maybe_reflect(
    self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False
) -> List[torch.Tensor]:
    """
    Build one transformed copy of the input per rotation angle.

    For every angle the input is padded, rotated by the negated angle,
    optionally flipped horizontally, and finally cropped.

    Args:
        x (torch.Tensor): The input image.
        degrees (torch.Tensor): The rotation angles, iterated one at a time.
        reflect (bool, optional): Whether to horizontally flip each rotated copy. Defaults to False.

    Returns:
        List[torch.Tensor]: One transformed image per entry of ``degrees``, in the same order.
    """

    def _transform(angle: torch.Tensor) -> torch.Tensor:
        # Pad first so the rotation does not cut off content; crop afterwards.
        padded = self.pad_group_augment(x)
        rotated = K.geometry.rotate(padded, -angle)
        oriented = K.geometry.hflip(rotated) if reflect else rotated
        return self.crop_group_augment(oriented)

    return [_transform(angle) for angle in degrees]

def group_augment(self, x: torch.Tensor) -> torch.Tensor:
    """
    Augment the input images by applying every group transformation (rotations and, optionally, reflections).

    The rotation angles are ``self.num_rotations`` evenly spaced values in [0, 360).
    When ``self.group_type`` is "roto-reflection", the reflected copies are appended
    after the purely rotated ones.

    Args:
        x (torch.Tensor): The input image.

    Returns:
        torch.Tensor: All transformed copies concatenated along dimension 0.
    """
    # linspace includes the 360 endpoint, which duplicates 0 degrees, so drop it.
    angles = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device)

    augmented = self.rotate_and_maybe_reflect(x, angles)
    if self.group_type == "roto-reflection":
        augmented.extend(self.rotate_and_maybe_reflect(x, angles, reflect=True))

    return torch.cat(augmented, dim=0)

def groupactivations_to_groupelement(self, group_activations: torch.Tensor) -> dict:
"""
Expand Down Expand Up @@ -257,6 +315,54 @@ def invert_canonicalization(
group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore
induced_rep_type=induced_rep_type,
)

def get_prior(
    self,
    x: torch.Tensor,
    model: torch.nn.Module,
    targets: torch.Tensor,
    metric_function: torch.nn.Module,
    tau: float = 1.0,
) -> torch.Tensor:
    """
    Compute a prior over group elements from how a model scores each transformed input.

    Every input is augmented with all group transformations, the model is run on
    the augmented batch, ``metric_function`` turns the outputs into one score per
    (group element, sample) pair, and a temperature-scaled softmax over the group
    dimension gives the prior. Everything runs under ``torch.no_grad()``.

    Args:
        x (torch.Tensor): The input images. shape = (batch_size, in_channels, height, width)
        model (torch.nn.Module): The prediction model which decides the prior.
        targets (torch.Tensor): The targets for the task. When no
            ``group_augment_target`` attribute exists, they are tiled ``num_group``
            times and flattened, which presumably expects one target per sample,
            i.e. shape (batch_size,) — TODO confirm against callers.
        metric_function (torch.nn.Module): Callable taking (model_output, targets) and
            returning the unnormalized score of each augmented sample.
        tau (float, optional): The temperature parameter. Defaults to 1.0. The scores are
            divided by it, so smaller values give a sharper prior.

    Returns:
        torch.Tensor: The prior for the model and x. shape = (batch_size, group_size)
    """
    with torch.no_grad():
        num_samples = x.shape[0]

        # size (group_size * batch_size, in_channels, height, width)
        augmented_inputs = self.group_augment(x)

        # Use the target-specific augmentation when the instance defines one;
        # otherwise tile the targets once per group element.
        if hasattr(self, "group_augment_target"):
            augmented_targets = self.group_augment_target(targets)
        else:
            # size (group_size * batch_size)
            augmented_targets = targets.repeat(self.num_group, 1).flatten()

        # size eg (group_size * batch_size, num_classes)
        predictions = model(augmented_inputs)

        # The augmented batch is laid out group-major, so reshape to
        # (group_size, batch_size) and swap axes to get (batch_size, group_size).
        scores = metric_function(predictions, augmented_targets)
        scores = scores.reshape(self.num_group, num_samples).permute(1, 0)

        # size (batch_size, group_size)
        prior = F.softmax(scores / tau, dim=-1)

    return prior






class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization):
Expand Down Expand Up @@ -360,21 +466,6 @@ def __init__(
)
self.out_vector_size = canonicalization_network.out_vector_size

# group optimization specific cropping and padding (required for group_augment())
group_augment_in_shape = canonicalization_hyperparams.resize_shape
self.crop_group_augment = (
torch.nn.Identity()
if in_shape[0] == 1
else transforms.CenterCrop(group_augment_in_shape)
)
self.pad_group_augment = (
torch.nn.Identity()
if in_shape[0] == 1
else transforms.Pad(
math.ceil(group_augment_in_shape * 0.5), padding_mode="edge"
)
)

self.reference_vector = torch.nn.Parameter(
torch.randn(1, self.out_vector_size),
requires_grad=canonicalization_hyperparams.learn_ref_vec,
Expand All @@ -384,48 +475,6 @@ def __init__(
"num_group": self.num_group,
}

def rotate_and_maybe_reflect(
    self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False
) -> List[torch.Tensor]:
    """
    Produce the rotated (and optionally reflected) versions of the input images.

    Args:
        x (torch.Tensor): The input image.
        degrees (torch.Tensor): The rotation angles to apply, one output per angle.
        reflect (bool, optional): Whether to horizontally flip after rotating. Defaults to False.

    Returns:
        List[torch.Tensor]: The transformed images, ordered like ``degrees``.
    """
    outputs: List[torch.Tensor] = []
    for angle in degrees:
        # Pad, then rotate by the negated angle.
        view = K.geometry.rotate(self.pad_group_augment(x), -angle)
        view = K.geometry.hflip(view) if reflect else view
        # Crop after the (optional) flip.
        outputs.append(self.crop_group_augment(view))
    return outputs

def group_augment(self, x: torch.Tensor) -> torch.Tensor:
    """
    Apply all group transformations (rotations and, for roto-reflection groups, reflections) to the input.

    Args:
        x (torch.Tensor): The input image.

    Returns:
        torch.Tensor: The transformed copies concatenated along dimension 0,
            rotated copies first and reflected copies (if any) after them.
    """
    # num_rotations evenly spaced angles in [0, 360); the 360 endpoint is dropped.
    rotation_angles = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device)

    plain = self.rotate_and_maybe_reflect(x, rotation_angles)
    if self.group_type != "roto-reflection":
        return torch.cat(plain, dim=0)

    mirrored = self.rotate_and_maybe_reflect(x, rotation_angles, reflect=True)
    return torch.cat(plain + mirrored, dim=0)

def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
"""
Gets the group activations for the input image.
Expand All @@ -437,12 +486,8 @@ def get_group_activations(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: The group activations.
"""
x = self.transformations_before_canonicalization_network_forward(x)
x_augmented = self.group_augment(
x
) # size (batch_size * group_size, in_channels, height, width)
vector_out = self.canonicalization_network(
x_augmented
) # size (batch_size * group_size, reference_vector_size)
x_augmented = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width)
vector_out = self.canonicalization_network(x_augmented) # size (batch_size * group_size, reference_vector_size)
self.canonicalization_info_dict = {"vector_out": vector_out}

if self.artifact_err_wt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ training:
loss:
task_weight: 1.0 # Weight of the task loss in the total loss
prior_weight: 100.0 # Weight of the prior regularization loss in the total loss; set to 0 to disable it
automated_prior: true # If true, compute the prior from the prediction network; if false, use the identity prior
group_contrast_weight: 0 # Weight of the group contrastive loss (set to 0 for group_equivariant, 0.0001 for opt_equivariant)
inference:
method: group # Type of inference options 1) vanilla 2) group
Expand Down
14 changes: 10 additions & 4 deletions examples/images/classification/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from math import tau
import pytorch_lightning as pl
import torch
from inference_utils import get_inference_method
from model_utils import get_dataset_specific_info, get_prediction_network
from omegaconf import DictConfig
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn import functional as F

from examples.images.common.utils import get_canonicalization_network, get_canonicalizer

Expand All @@ -13,9 +15,6 @@ class ImageClassifierPipeline(pl.LightningModule):
def __init__(self, hyperparams: DictConfig):
super().__init__()

self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(
hyperparams.dataset.dataset_name
)
self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(
hyperparams.dataset.dataset_name
)
Expand Down Expand Up @@ -104,7 +103,14 @@ def training_step(self, batch: torch.Tensor):

# Add prior regularization loss if the prior weight is non-zero
if self.hyperparams.experiment.training.loss.prior_weight:
prior_loss = self.canonicalizer.get_prior_regularization_loss()
if self.hyperparams.experiment.training.loss.automated_prior:
def metric_function(model_predictions, targets):
return -F.cross_entropy(model_predictions, targets, reduction='none')
prior = self.canonicalizer.get_prior(x, self.prediction_network, y, metric_function, tau=0.01)
prior_loss = self.canonicalizer.get_prior_regularization_loss(prior) # type: ignore
else:
prior_loss = self.canonicalizer.get_prior_regularization_loss()

loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight
metric_identity = self.canonicalizer.get_identity_metric()
training_metrics.update(
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
canonicalization_type: group_equivariant
network_type: e2cnn # Options for the canonicalization network: 1) e2cnn 2) custom 3) none
network_hyperparams:
kernel_size: 5 # Kernel size for the canonicalization network
out_channels: 32 # Number of output channels for the canonicalization network
num_layers: 3 # Number of layers in the canonicalization network
group_type: rotation # Type of group for the canonicalization network
num_rotations: 4 # Number of rotations for the canonicalization network
beta: 1.0 # Beta parameter for the canonicalization network
input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization network
resize_shape: 64 # Resize shape for the input
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
canonicalization_type: identity
23 changes: 23 additions & 0 deletions examples/images/reinforcementlearning/configs/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
canonicalization_type: group_equivariant # will be set later in training script
device: cuda

# metadata specialised for each experiment
core:
version: 0.0.1
tags:
- ${now:%Y-%m-%d}

hydra:
run:
dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/

sweep:
dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/
subdir: ${hydra.job.num}_${hydra.job.id}

defaults:
- _self_
- env: default
- experiment: default
- canonicalization: identity
- wandb: default
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
learning_rate: 0.001
batch_size: 128
gamma: 0.999
eps_start: 0.9
eps_end: 0.01
eps_decay: 3000
target_update: 10
replay_memory_size: 100000
end_score: 200
training_stop: 142
num_episodes: 50000
last_episodes_num: 20
Empty file.
Loading
Loading