From 3434f3e9feab4fa5ae4364deb5ac679a79c64bee Mon Sep 17 00:00:00 2001 From: arnab39 Date: Wed, 13 Mar 2024 12:23:10 -0400 Subject: [PATCH] added docstring for pointcloud --- equiadapt/common/utils.py | 11 +- equiadapt/pointcloud/__init__.py | 2 + .../pointcloud/canonicalization/__init__.py | 4 +- .../canonicalization_networks/__init__.py | 2 + .../equivariant_networks.py | 59 ++++- .../vector_neuron_layers.py | 249 ++++++++++++++++-- 6 files changed, 299 insertions(+), 28 deletions(-) diff --git a/equiadapt/common/utils.py b/equiadapt/common/utils.py index c095690..bc9f6b0 100644 --- a/equiadapt/common/utils.py +++ b/equiadapt/common/utils.py @@ -29,6 +29,14 @@ def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The orthogonalized vectors of the same shape as the input. + + Examples: + >>> vectors = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]) + >>> result = gram_schmidt(vectors) + >>> print(result) + tensor([[[1.0000, 0.0000, 0.0000], + [0.0000, 1.0000, 0.0000], + [0.0000, 0.0000, 1.0000]]]) """ v1 = vectors[:, 0] v1 = v1 / torch.norm(v1, dim=1, keepdim=True) @@ -188,7 +196,8 @@ def get_en_rep( return en_rep def get_group_rep(self, params: torch.Tensor) -> torch.Tensor: - """Computes the representation for the specified Lie group. + """ + Computes the representation for the specified Lie group. Args: params (torch.Tensor): Input parameters of shape (batch_size, param_dim). diff --git a/equiadapt/pointcloud/__init__.py b/equiadapt/pointcloud/__init__.py index 34b2988..e476c6a 100644 --- a/equiadapt/pointcloud/__init__.py +++ b/equiadapt/pointcloud/__init__.py @@ -1,3 +1,5 @@ +"""This package contains modules for the equiadapt pointcloud canonicalization.""" + from equiadapt.pointcloud import canonicalization, canonicalization_networks from equiadapt.pointcloud.canonicalization import ( ContinuousGroupPointcloudCanonicalization, diff --git a/equiadapt/pointcloud/canonicalization/__init__.py b/equiadapt/pointcloud/canonicalization/__init__.py index e16eea1..1542e44 100644 --- a/equiadapt/pointcloud/canonicalization/__init__.py +++ b/equiadapt/pointcloud/canonicalization/__init__.py @@ -1,4 +1,5 @@ -from equiadapt.pointcloud.canonicalization import continuous_group +"""This module contains the pointcloud canonicalization methods.""" + from equiadapt.pointcloud.canonicalization.continuous_group import ( ContinuousGroupPointcloudCanonicalization, EquivariantPointcloudCanonicalization, @@ -7,5 +8,4 @@ __all__ = [ "ContinuousGroupPointcloudCanonicalization", "EquivariantPointcloudCanonicalization", - "continuous_group", ] diff --git a/equiadapt/pointcloud/canonicalization_networks/__init__.py b/equiadapt/pointcloud/canonicalization_networks/__init__.py index cf888f1..b70c9ee 100644 --- a/equiadapt/pointcloud/canonicalization_networks/__init__.py +++ b/equiadapt/pointcloud/canonicalization_networks/__init__.py @@ -1,3 +1,5 @@ +"""This package contains equivariant modules and networks for the equiadapt pointcloud canonicalization.""" + from equiadapt.pointcloud.canonicalization_networks import ( equivariant_networks, vector_neuron_layers, diff --git a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py index 54cb0c0..98ec49f 100644 --- a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py +++ b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py @@ -13,6 +13,18 @@ def knn(x: torch.Tensor, k: int) -> torch.Tensor: + """ + Performs k-nearest neighbors search on a given set of points. + + Args: + x (torch.Tensor): The input tensor representing a set of points. + Shape: (batch_size, num_points, num_dimensions). + k (int): The number of nearest neighbors to find. + + Returns: + torch.Tensor: The indices of the k nearest neighbors for each point in x. + Shape: (batch_size, num_points, k). + """ inner = -2 * torch.matmul(x.transpose(2, 1), x) xx = torch.sum(x**2, dim=1, keepdim=True) pairwise_distance = -xx - inner - xx.transpose(2, 1) @@ -24,6 +36,18 @@ def knn(x: torch.Tensor, k: int) -> torch.Tensor: def get_graph_feature_cross( x: torch.Tensor, k: int = 20, idx: Optional[torch.Tensor] = None ) -> torch.Tensor: + """ + Computes the graph feature cross for a given input tensor. + + Args: + x (torch.Tensor): The input tensor of shape (batch_size, num_dims, num_points). + k (int, optional): The number of nearest neighbors to consider. Defaults to 20. + idx (torch.Tensor, optional): The indices of the nearest neighbors. Defaults to None. + + Returns: + torch.Tensor: The computed graph feature cross tensor of shape (batch_size, num_dims*3, num_points, k). + + """ batch_size = x.size(0) num_points = x.size(3) x = x.view(batch_size, -1, num_points) @@ -53,6 +77,24 @@ def get_graph_feature_cross( class VNSmall(torch.nn.Module): + """ + VNSmall is a small variant of the vector neuron equivariant network used for canonicalization of point clouds. + + Args: + hyperparams (DictConfig): Hyperparameters for the network. + + Attributes: + n_knn (int): Number of nearest neighbors to consider. + pooling (str): Pooling type to use, either "max" or "mean". + conv_pos (VNLinearLeakyReLU): Convolutional layer for positional encoding. + conv1 (VNLinearLeakyReLU): First convolutional layer. + bn1 (VNBatchNorm): Batch normalization layer. + conv2 (VNLinearLeakyReLU): Second convolutional layer. + dropout (nn.Dropout): Dropout layer. + pool (Union[VNMaxPool, mean_pool]): Pooling layer. + + """ + def __init__(self, hyperparams: DictConfig): super().__init__() self.n_knn = hyperparams.n_knn @@ -70,11 +112,19 @@ def __init__(self, hyperparams: DictConfig): else: raise ValueError(f"Pooling type {self.pooling} not supported") - # Wild idea -- Just use a linear layer to predict the output - # self.conv = VNLinear(3, 12 // 3) - def forward(self, point_cloud: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the VNSmall network. + + For every pointcloud in the batch, the network outputs three vectors that transform equivariantly with respect to SO3 group. + Args: + point_cloud (torch.Tensor): Input point cloud tensor of shape (batch_size, num_points, 3). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, 3, 3). + + """ point_cloud = point_cloud.unsqueeze(1) feat = get_graph_feature_cross(point_cloud, k=self.n_knn) out = self.conv_pos(feat) @@ -84,7 +134,4 @@ def forward(self, point_cloud: torch.Tensor) -> torch.Tensor: out = self.conv2(out) out = self.dropout(out) - # out = self.pool(self.conv(feat)) - # out = self.dropout(out) - return out.mean(dim=-1)[:, :3] diff --git a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py index f42878b..19dcbf2 100644 --- a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py +++ b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py @@ -1,6 +1,9 @@ -# Layers for vector neuron networks -# Taken from Vector Neurons: A General Framework for SO(3)-Equivariant Networks (https://arxiv.org/abs/2104.12229) paper and -# their codebase https://github.com/FlyingGiraffe/vnn +""" +Layers for vector neuron networks + +Taken from Vector Neurons: A General Framework for SO(3)-Equivariant Networks (https://arxiv.org/abs/2104.12229) paper and +their codebase https://github.com/FlyingGiraffe/vnn +""" from typing import Tuple @@ -11,28 +14,68 @@ class VNLinear(nn.Module): + """ + Vector Neuron Linear layer. + + This layer applies a linear transformation to the input tensor. + """ + def __init__(self, in_channels: int, out_channels: int): - super(VNLinear, self).__init__() + """ + Initializes a VNLinear layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + super().__init__() self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Performs forward pass of the VNLinear layer. + + Args: + x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...]. + + Returns: + torch.Tensor: Output tensor of shape [B, N_feat, 3, N_samples, ...]. """ x_out = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1) return x_out class VNBilinear(nn.Module): + """ + Vector Neuron Bilinear layer. + + VNBilinear applies a bilinear layer to the input features. + """ + def __init__(self, in_channels1: int, in_channels2: int, out_channels: int): - super(VNBilinear, self).__init__() + """ + Initializes the VNBilinear layer. + + Args: + in_channels1 (int): Number of input channels for the first input. + in_channels2 (int): Number of input channels for the second input. + out_channels (int): Number of output channels. + """ + super().__init__() self.map_to_feat = nn.Bilinear( in_channels1, in_channels2, out_channels, bias=False ) def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Forward pass of the VNBilinear layer. + + Args: + x (torch.Tensor): Input features of shape [B, N_feat, 3, N_samples, ...]. + labels (torch.Tensor): Labels of shape [B, N_feat, N_samples]. + + Returns: + torch.Tensor: Output features after applying the bilinear transformation. """ labels = labels.repeat(1, x.shape[2], 1).float() x_out = self.map_to_feat(x.transpose(1, -1), labels).transpose(1, -1) @@ -40,13 +83,28 @@ def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: class VNSoftplus(nn.Module): + """ + Vector Neuron Softplus layer. + + VNSoftplus applies a softplus activation to the input features. + """ + def __init__( self, in_channels: int, share_nonlinearity: bool = False, negative_slope: float = 0.0, ): - super(VNSoftplus, self).__init__() + """ + Initializes a VNSoftplus layer. + + Args: + in_channels (int): Number of input channels. + share_nonlinearity (bool): Whether to share the nonlinearity across channels. + negative_slope (float): Negative slope parameter for the LeakyReLU activation. + + """ + super().__init__() if share_nonlinearity: self.map_to_dir = nn.Linear(in_channels, 1, bias=False) else: @@ -55,7 +113,14 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Performs forward pass of the VNSoftplus layer. + + Args: + x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...]. + + Returns: + torch.Tensor: Output tensor of shape [B, N_feat, 3, N_samples, ...]. + """ d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) dotprod = (x * d).sum(2, keepdim=True) @@ -75,13 +140,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VNLeakyReLU(nn.Module): + """ + Vector Neuron Leaky ReLU layer. + + VNLLeakyReLU applies a LeakyReLU activation to the input features. + """ + def __init__( self, in_channels: int, share_nonlinearity: bool = False, negative_slope: float = 0.2, ): - super(VNLeakyReLU, self).__init__() + """ + Vector Neuron Leaky ReLU (VNLeakyReLU) module. + + Args: + in_channels (int): Number of input channels. + share_nonlinearity (bool, optional): Whether to share the nonlinearity across channels. + If True, a single linear layer is used to compute the direction. + If False, a separate linear layer is used for each channel. + Defaults to False. + negative_slope (float, optional): Negative slope of the Leaky ReLU activation. + Defaults to 0.2. + """ + super().__init__() if share_nonlinearity: self.map_to_dir = nn.Linear(in_channels, 1, bias=False) else: @@ -90,7 +173,13 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Forward pass of the VNLeakyReLU module. + + Args: + x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...]. + + Returns: + torch.Tensor: Output tensor after applying VNLeakyReLU activation. """ d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) dotprod = (x * d).sum(2, keepdim=True) @@ -103,6 +192,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VNLinearLeakyReLU(nn.Module): + """ + Vector Neuron Linear Leaky ReLU layer. + + VNLinearLeakyReLU applies a linear transformation followed by a LeakyReLU activation to the input features. + """ + def __init__( self, in_channels: int, @@ -111,6 +206,16 @@ def __init__( share_nonlinearity: bool = False, negative_slope: float = 0.2, ): + """ + Vector Neuron Linear Leaky ReLU layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + dim (int, optional): Dimension of the input features. Defaults to 5. + share_nonlinearity (bool, optional): Whether to share the nonlinearity across channels. Defaults to False. + negative_slope (float, optional): Negative slope of the LeakyReLU activation. Defaults to 0.2. + """ super(VNLinearLeakyReLU, self).__init__() self.dim = dim self.negative_slope = negative_slope @@ -125,7 +230,13 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Forward pass of the VNLinearLeakyReLU layer. + + Args: + x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...] + + Returns: + torch.Tensor: Output tensor of shape [B, N_feat, 3, N_samples, ...] """ # Linear p = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1) @@ -143,7 +254,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VNBatchNorm(nn.Module): + """ + Vector Neuron Batch Normalization layer. + + VNBatchNorm applies batch normalization to the input features. + """ + def __init__(self, num_features: int, dim: int): + """ + Vector Neuron Batch Normalization layer. + + Args: + num_features (int): Number of input features. + dim (int): Dimensionality of the input tensor. + + """ super(VNBatchNorm, self).__init__() self.dim = dim if dim == 3 or dim == 4: @@ -153,7 +278,14 @@ def __init__(self, num_features: int, dim: int): def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Forward pass of the Vector Neuron Batch Normalization layer. + + Args: + x (torch.Tensor): Input tensor of shape [B, N_feat, 3, N_samples, ...]. + + Returns: + torch.Tensor: Output tensor after applying batch normalization. + """ # norm = torch.sqrt((x*x).sum(2)) norm = torch.norm(x, dim=2) + EPS @@ -169,13 +301,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VNMaxPool(nn.Module): + """ + Vector Neuron Max Pooling layer. + + VNMaxPool applies max pooling to the input features. + """ + def __init__(self, in_channels: int): + """ + Initializes a VNMaxPool layer. + + Args: + in_channels (int): The number of input channels. + + """ super(VNMaxPool, self).__init__() self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Performs vector neuron max pooling on the input tensor. + + Args: + x (torch.Tensor): Point features of shape [B, N_feat, 3, N_samples, ...]. + + Returns: + torch.Tensor: Max pooled tensor of shape [B, N_feat, 3, N_samples, ...]. """ d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) dotprod = (x * d).sum(2, keepdims=True) @@ -186,10 +337,51 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def mean_pool(x: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor: + """ + Compute the mean pooling of a tensor along a specified dimension. + + Args: + x (torch.Tensor): The input tensor. + dim (int, optional): The dimension along which to compute the mean pooling. Default is -1. + keepdim (bool, optional): Whether to keep the dimension of the input tensor. Default is False. + + Returns: + torch.Tensor: The mean pooled tensor. + + """ return x.mean(dim=dim, keepdim=keepdim) class VNStdFeature(nn.Module): + """ + Vector Neuron Standard Feature module. + + This module performs standard feature extraction using Vector Neuron layers. + It takes point features as input and applies a series of VNLinearLeakyReLU layers + followed by a linear layer to produce the standard features. + + Args: + in_channels (int): Number of input channels. + dim (int, optional): Dimension of the output features. Defaults to 4. + normalize_frame (bool, optional): Whether to normalize the frame. Defaults to False. + share_nonlinearity (bool, optional): Whether to share the nonlinearity across layers. Defaults to False. + negative_slope (float, optional): Negative slope of the LeakyReLU activation function. Defaults to 0.2. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing the standard features and the frame vectors. + + Shape: + - Input: (B, N_feat, 3, N_samples, ...) + - Output: + - x_std: (B, N_feat, dim, N_samples, ...) + - z0: (B, dim, 3) + + Example: + >>> model = VNStdFeature(in_channels=64, dim=4, normalize_frame=True) + >>> input = torch.randn(2, 64, 3, 100) + >>> output, frame_vectors = model(input) + """ + def __init__( self, in_channels: int, @@ -198,6 +390,16 @@ def __init__( share_nonlinearity: bool = False, negative_slope: float = 0.2, ): + """ + Initializes the VNStdFeature layer. + + Args: + in_channels (int): Number of input channels. + dim (int, optional): Dimension of the input feature. Defaults to 4. + normalize_frame (bool, optional): Whether to normalize the frame. Defaults to False. + share_nonlinearity (bool, optional): Whether to share the nonlinearity across layers. Defaults to False. + negative_slope (float, optional): Negative slope of the LeakyReLU activation function. Defaults to 0.2. + """ super(VNStdFeature, self).__init__() self.dim = dim self.normalize_frame = normalize_frame @@ -223,7 +425,17 @@ def __init__( def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - x: point features of shape [B, N_feat, 3, N_samples, ...] + Forward pass of the VNStdFeature module. + + Args: + x (torch.Tensor): Input point features of shape (B, N_feat, 3, N_samples, ...). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing the standard features and the frame vectors. + + Note: + - The frame vectors are computed only if `normalize_frame` is set to True. + - The shape of the standard features depends on the value of `dim`. """ z0 = x z0 = self.vn1(z0) @@ -232,14 +444,13 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.normalize_frame: v1 = z0[:, 0, :] - v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True)) # type: ignore + v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True)) u1 = v1 / (v1_norm + EPS) v2 = z0[:, 1, :] - v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1 # type: ignore - v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True)) # type: ignore + v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1 + v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True)) u2 = v2 / (v2_norm + EPS) - # compute the cross product of the two output vectors u3 = torch.cross(u1, u2) z0 = torch.stack([u1, u2, u3], dim=1).transpose(1, 2) else: