Skip to content

Commit

Permalink
added better docstring for pointcloud
Browse files Browse the repository at this point in the history
  • Loading branch information
arnab39 committed Mar 13, 2024
1 parent 6b76fb3 commit 440ff68
Show file tree
Hide file tree
Showing 12 changed files with 232 additions and 102 deletions.
58 changes: 45 additions & 13 deletions equiadapt/pointcloud/canonicalization/continuous_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@


class ContinuousGroupPointcloudCanonicalization(ContinuousGroupCanonicalization):
"""
This class represents a continuous group point cloud canonicalization.
Args:
canonicalization_network (torch.nn.Module): The canonicalization network.
canonicalization_hyperparams (DictConfig): The hyperparameters for canonicalization.
Attributes:
device: The device on which the operations are performed.
Methods:
get_groupelement: Maps the input point cloud to the group element.
canonicalize: Returns the canonicalized point cloud.
"""

def __init__(
self,
canonicalization_network: torch.nn.Module,
Expand All @@ -20,29 +35,33 @@ def __init__(

def get_groupelement(self, x: torch.Tensor) -> dict:
"""
This method takes the input image and
maps it to the group element
This method takes the input image and maps it to the group element.
Args:
x: input image
x (torch.Tensor): The input image.
Returns:
group_element: group element
dict: The group element.
Raises:
NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError("get_groupelement method is not implemented")

def canonicalize(
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an image as input and
returns the canonicalized image
This method takes an image as input and returns the canonicalized image.
Args:
x: input point cloud
x (torch.Tensor): The input point cloud.
targets (Optional[List]): The list of targets (optional).
**kwargs (Any): Additional keyword arguments.
Returns:
x_canonicalized: canonicalized point cloud
Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized point cloud.
"""
self.device = x.device

Expand All @@ -63,6 +82,21 @@ def canonicalize(


class EquivariantPointcloudCanonicalization(ContinuousGroupPointcloudCanonicalization):
"""
This class represents the equivariant point cloud canonicalization module.
It inherits from the ContinuousGroupPointcloudCanonicalization class.
Args:
canonicalization_network (torch.nn.Module): The canonicalization network module.
canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization.
Attributes:
canonicalization_network (torch.nn.Module): The canonicalization network module.
canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization.
canonicalization_info_dict (dict): A dictionary to store the canonicalization information.
"""

def __init__(
self,
canonicalization_network: torch.nn.Module,
Expand All @@ -72,16 +106,14 @@ def __init__(

def get_groupelement(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
"""
This method takes the input image and
maps it to the group element
This method takes the input image and maps it to the group element.
Args:
x: input point cloud
x (torch.Tensor): The input point cloud.
Returns:
group_element: group element
dict[str, torch.Tensor]: A dictionary containing the group element.
"""

group_element_dict = {}

# convert the group activations to one hot encoding of group element
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,22 @@ class VNSmall(torch.nn.Module):
dropout (nn.Dropout): Dropout layer.
pool (Union[VNMaxPool, mean_pool]): Pooling layer.
Methods:
__init__: Initializes the VNSmall network.
forward: Forward pass of the VNSmall network.
"""

def __init__(self, hyperparams: DictConfig):
"""
Initialize the VN Small network.
Args:
hyperparams (DictConfig): A dictionary-like object containing hyperparameters.
Raises:
ValueError: If the specified pooling type is not supported.
"""
super().__init__()
self.n_knn = hyperparams.n_knn
self.pooling = hyperparams.pooling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class VNLinear(nn.Module):
Vector Neuron Linear layer.
This layer applies a linear transformation to the input tensor.
Methods:
__init__: Initializes the VNLinear layer.
forward: Performs forward pass of the VNLinear layer.
"""

def __init__(self, in_channels: int, out_channels: int):
Expand Down Expand Up @@ -50,6 +54,10 @@ class VNBilinear(nn.Module):
Vector Neuron Bilinear layer.
VNBilinear applies a bilinear layer to the input features.
Methods:
__init__: Initializes the VNBilinear layer.
forward: Performs forward pass of the VNBilinear layer.
"""

def __init__(self, in_channels1: int, in_channels2: int, out_channels: int):
Expand Down Expand Up @@ -87,6 +95,10 @@ class VNSoftplus(nn.Module):
Vector Neuron Softplus layer.
VNSoftplus applies a softplus activation to the input features.
Methods:
__init__: Initializes the VNSoftplus layer.
forward: Performs forward pass of the VNSoftplus layer.
"""

def __init__(
Expand Down Expand Up @@ -144,6 +156,10 @@ class VNLeakyReLU(nn.Module):
Vector Neuron Leaky ReLU layer.
VNLLeakyReLU applies a LeakyReLU activation to the input features.
Methods:
__init__: Initializes the VNLeakyReLU layer.
forward: Performs forward pass of the VNLeakyReLU layer.
"""

def __init__(
Expand Down Expand Up @@ -196,6 +212,10 @@ class VNLinearLeakyReLU(nn.Module):
Vector Neuron Linear Leaky ReLU layer.
VNLinearLeakyReLU applies a linear transformation followed by a LeakyReLU activation to the input features.
Methods:
__init__: Initializes the VNLinearLeakyReLU layer.
forward: Performs forward pass of the VNLinearLeakyReLU layer.
"""

def __init__(
Expand Down Expand Up @@ -258,6 +278,10 @@ class VNBatchNorm(nn.Module):
Vector Neuron Batch Normalization layer.
VNBatchNorm applies batch normalization to the input features.
Methods:
__init__: Initializes the VNBatchNorm layer.
forward: Performs forward pass of the VNBatchNorm layer.
"""

def __init__(self, num_features: int, dim: int):
Expand Down Expand Up @@ -305,6 +329,10 @@ class VNMaxPool(nn.Module):
Vector Neuron Max Pooling layer.
VNMaxPool applies max pooling to the input features.
Methods:
__init__: Initializes the VNMaxPool layer.
forward: Performs forward pass of the VNMaxPool layer.
"""

def __init__(self, in_channels: int):
Expand Down Expand Up @@ -360,15 +388,13 @@ class VNStdFeature(nn.Module):
It takes point features as input and applies a series of VNLinearLeakyReLU layers
followed by a linear layer to produce the standard features.
Args:
in_channels (int): Number of input channels.
dim (int, optional): Dimension of the output features. Defaults to 4.
normalize_frame (bool, optional): Whether to normalize the frame. Defaults to False.
share_nonlinearity (bool, optional): Whether to share the nonlinearity across layers. Defaults to False.
negative_slope (float, optional): Negative slope of the LeakyReLU activation function. Defaults to 0.2.
Attributes:
dim (int): Dimension of the input features.
normalize_frame (bool): Whether to normalize the frame.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple containing the standard features and the frame vectors.
Methods:
__init__: Initializes the VNStdFeature module.
forward: Performs forward pass of the VNStdFeature module.
Shape:
- Input: (B, N_feat, 3, N_samples, ...)
Expand Down
2 changes: 1 addition & 1 deletion examples/images/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ checkpoint.checkpoint_path=/path/of/checkpoint/dir checkpoint.checkpoint_name=<n
```

**Note**:
**Note**:
The final checkpoint that will be loaded during evaluation as follows, hence ensure that the combination:
```
model = ImageClassifierPipeline.load_from_checkpoint(
Expand Down
72 changes: 49 additions & 23 deletions examples/images/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,51 @@
import omegaconf
import pytorch_lightning as pl
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers import WandbLogger
from train_utils import get_model_data_and_callbacks, get_trainer, load_envs

import wandb

def train_images(hyperparams: DictConfig) -> None:

if hyperparams["experiment"]["run_mode"] == "test":
assert (
len(hyperparams["checkpoint"]["checkpoint_name"]) > 0
), "checkpoint_name must be provided for test mode"

def train_images(hyperparams: DictConfig):

if hyperparams['experiment']['run_mode'] == "test":
assert len(hyperparams['checkpoint']['checkpoint_name']) > 0, "checkpoint_name must be provided for test mode"

existing_ckpt_path = hyperparams['checkpoint']['checkpoint_path'] + "/" + hyperparams['checkpoint']['checkpoint_name'] + ".ckpt"
existing_ckpt_path = (
hyperparams["checkpoint"]["checkpoint_path"]
+ "/"
+ hyperparams["checkpoint"]["checkpoint_name"]
+ ".ckpt"
)
existing_ckpt = torch.load(existing_ckpt_path)
conf = OmegaConf.create(existing_ckpt['hyper_parameters']['hyperparams'])

hyperparams['canonicalization_type'] = conf['canonicalization_type']
hyperparams['canonicalization'] = conf['canonicalization']
hyperparams['prediction'] = conf['prediction']

else:
hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type']
hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name']
hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \
hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \
+ "/" + hyperparams['prediction']['prediction_network_architecture']
conf = OmegaConf.create(existing_ckpt["hyper_parameters"]["hyperparams"])

hyperparams["canonicalization_type"] = conf["canonicalization_type"]
hyperparams["canonicalization"] = conf["canonicalization"]
hyperparams["prediction"] = conf["prediction"]

else:
hyperparams["canonicalization_type"] = hyperparams["canonicalization"][
"canonicalization_type"
]
hyperparams["device"] = "cuda" if torch.cuda.is_available() else "cpu"
hyperparams["dataset"]["data_path"] = (
hyperparams["dataset"]["data_path"]
+ "/"
+ hyperparams["dataset"]["dataset_name"]
)
hyperparams["checkpoint"]["checkpoint_path"] = (
hyperparams["checkpoint"]["checkpoint_path"]
+ "/"
+ hyperparams["dataset"]["dataset_name"]
+ "/"
+ hyperparams["canonicalization_type"]
+ "/"
+ hyperparams["prediction"]["prediction_network_architecture"]
)

# set system environment variables for wandb
if hyperparams["wandb"]["use_wandb"]:
Expand All @@ -53,8 +71,16 @@ def train_images(hyperparams: DictConfig):
project=hyperparams["wandb"]["wandb_project"], log_model="all"
)

if not hyperparams['experiment']['run_mode'] == "test":
hyperparams['checkpoint']['checkpoint_name'] = wandb_run.id + "_" + wandb_run.name + "_" + wandb_run.sweep_id + "_" + wandb_run.group
if not hyperparams["experiment"]["run_mode"] == "test":
hyperparams["checkpoint"]["checkpoint_name"] = (
wandb_run.id
+ "_"
+ wandb_run.name
+ "_"
+ wandb_run.sweep_id
+ "_"
+ wandb_run.group
)

# set seed
pl.seed_everything(hyperparams.experiment.seed)
Expand Down Expand Up @@ -86,7 +112,7 @@ def train_images(hyperparams: DictConfig):


@hydra.main(config_path=str("./configs/"), config_name="default")
def main(cfg: omegaconf.DictConfig):
def main(cfg: omegaconf.DictConfig) -> None:
train_images(cfg)


Expand Down
14 changes: 7 additions & 7 deletions examples/images/classification/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


def get_model_data_and_callbacks(hyperparams: DictConfig):
def get_model_data_and_callbacks(hyperparams: DictConfig) -> tuple:

# get image data
image_data = get_image_data(hyperparams.dataset)

# checkpoint callbacks
callbacks = get_callbacks(hyperparams)

Expand All @@ -28,7 +28,7 @@ def get_model_data_and_callbacks(hyperparams: DictConfig):
return model, image_data, callbacks


def get_model_pipeline(hyperparams: DictConfig):
def get_model_pipeline(hyperparams: DictConfig) -> pl.LightningModule:

if hyperparams.experiment.run_mode == "test":
model = ImageClassifierPipeline.load_from_checkpoint(
Expand All @@ -48,7 +48,7 @@ def get_model_pipeline(hyperparams: DictConfig):

def get_trainer(
hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger
):
) -> pl.Trainer:
if hyperparams.experiment.run_mode == "dryrun":
trainer = pl.Trainer(
fast_dev_run=5,
Expand All @@ -75,7 +75,7 @@ def get_trainer(
return trainer


def get_callbacks(hyperparams: DictConfig):
def get_callbacks(hyperparams: DictConfig) -> list:

checkpoint_callback = ModelCheckpoint(
dirpath=hyperparams.checkpoint.checkpoint_path,
Expand All @@ -93,9 +93,9 @@ def get_callbacks(hyperparams: DictConfig):
)

return [checkpoint_callback, early_stop_metric_callback]


def get_image_data(dataset_hyperparams: DictConfig):

def get_image_data(dataset_hyperparams: DictConfig) -> pl.LightningDataModule:

dataset_classes = {
"rotated_mnist": RotatedMNISTDataModule,
Expand Down
Loading

0 comments on commit 440ff68

Please sign in to comment.