Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sibasmarak committed Mar 13, 2024
2 parents abf38d5 + e219bdf commit d69dab9
Show file tree
Hide file tree
Showing 22 changed files with 647 additions and 183 deletions.
11 changes: 10 additions & 1 deletion equiadapt/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The orthogonalized vectors of the same shape as the input.
Examples:
>>> vectors = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]])
>>> result = gram_schmidt(vectors)
>>> print(result)
tensor([[[1.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000]]])
"""
v1 = vectors[:, 0]
v1 = v1 / torch.norm(v1, dim=1, keepdim=True)
Expand Down Expand Up @@ -188,7 +196,8 @@ def get_en_rep(
return en_rep

def get_group_rep(self, params: torch.Tensor) -> torch.Tensor:
"""Computes the representation for the specified Lie group.
"""
Computes the representation for the specified Lie group.
Args:
params (torch.Tensor): Input parameters of shape (batch_size, param_dim).
Expand Down
2 changes: 2 additions & 0 deletions equiadapt/pointcloud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This package contains modules for the equiadapt pointcloud canonicalization."""

from equiadapt.pointcloud import canonicalization, canonicalization_networks
from equiadapt.pointcloud.canonicalization import (
ContinuousGroupPointcloudCanonicalization,
Expand Down
4 changes: 2 additions & 2 deletions equiadapt/pointcloud/canonicalization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from equiadapt.pointcloud.canonicalization import continuous_group
"""This module contains the pointcloud canonicalization methods."""

from equiadapt.pointcloud.canonicalization.continuous_group import (
ContinuousGroupPointcloudCanonicalization,
EquivariantPointcloudCanonicalization,
Expand All @@ -7,5 +8,4 @@
__all__ = [
"ContinuousGroupPointcloudCanonicalization",
"EquivariantPointcloudCanonicalization",
"continuous_group",
]
58 changes: 45 additions & 13 deletions equiadapt/pointcloud/canonicalization/continuous_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@


class ContinuousGroupPointcloudCanonicalization(ContinuousGroupCanonicalization):
"""
This class represents a continuous group point cloud canonicalization.
Args:
canonicalization_network (torch.nn.Module): The canonicalization network.
canonicalization_hyperparams (DictConfig): The hyperparameters for canonicalization.
Attributes:
device: The device on which the operations are performed.
Methods:
get_groupelement: Maps the input point cloud to the group element.
canonicalize: Returns the canonicalized point cloud.
"""

def __init__(
self,
canonicalization_network: torch.nn.Module,
Expand All @@ -20,29 +35,33 @@ def __init__(

def get_groupelement(self, x: torch.Tensor) -> dict:
"""
This method takes the input image and
maps it to the group element
This method takes the input image and maps it to the group element.
Args:
x: input image
x (torch.Tensor): The input image.
Returns:
group_element: group element
dict: The group element.
Raises:
NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError("get_groupelement method is not implemented")

def canonicalize(
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an image as input and
returns the canonicalized image
This method takes an image as input and returns the canonicalized image.
Args:
x: input point cloud
x (torch.Tensor): The input point cloud.
targets (Optional[List]): The list of targets (optional).
**kwargs (Any): Additional keyword arguments.
Returns:
x_canonicalized: canonicalized point cloud
Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized point cloud.
"""
self.device = x.device

Expand All @@ -63,6 +82,21 @@ def canonicalize(


class EquivariantPointcloudCanonicalization(ContinuousGroupPointcloudCanonicalization):
"""
This class represents the equivariant point cloud canonicalization module.
It inherits from the ContinuousGroupPointcloudCanonicalization class.
Args:
canonicalization_network (torch.nn.Module): The canonicalization network module.
canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization.
Attributes:
canonicalization_network (torch.nn.Module): The canonicalization network module.
canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization.
canonicalization_info_dict (dict): A dictionary to store the canonicalization information.
"""

def __init__(
self,
canonicalization_network: torch.nn.Module,
Expand All @@ -72,16 +106,14 @@ def __init__(

def get_groupelement(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
"""
This method takes the input image and
maps it to the group element
This method takes the input image and maps it to the group element.
Args:
x: input point cloud
x (torch.Tensor): The input point cloud.
Returns:
group_element: group element
dict[str, torch.Tensor]: A dictionary containing the group element.
"""

group_element_dict = {}

# convert the group activations to one hot encoding of group element
Expand Down
2 changes: 2 additions & 0 deletions equiadapt/pointcloud/canonicalization_networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This package contains equivariant modules and networks for the equiadapt pointcloud canonicalization."""

from equiadapt.pointcloud.canonicalization_networks import (
equivariant_networks,
vector_neuron_layers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@


def knn(x: torch.Tensor, k: int) -> torch.Tensor:
"""
Performs k-nearest neighbors search on a given set of points.
Args:
x (torch.Tensor): The input tensor representing a set of points.
Shape: (batch_size, num_points, num_dimensions).
k (int): The number of nearest neighbors to find.
Returns:
torch.Tensor: The indices of the k nearest neighbors for each point in x.
Shape: (batch_size, num_points, k).
"""
inner = -2 * torch.matmul(x.transpose(2, 1), x)
xx = torch.sum(x**2, dim=1, keepdim=True)
pairwise_distance = -xx - inner - xx.transpose(2, 1)
Expand All @@ -24,6 +36,18 @@ def knn(x: torch.Tensor, k: int) -> torch.Tensor:
def get_graph_feature_cross(
x: torch.Tensor, k: int = 20, idx: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Computes the graph feature cross for a given input tensor.
Args:
x (torch.Tensor): The input tensor of shape (batch_size, num_dims, num_points).
k (int, optional): The number of nearest neighbors to consider. Defaults to 20.
idx (torch.Tensor, optional): The indices of the nearest neighbors. Defaults to None.
Returns:
torch.Tensor: The computed graph feature cross tensor of shape (batch_size, num_dims*3, num_points, k).
"""
batch_size = x.size(0)
num_points = x.size(3)
x = x.view(batch_size, -1, num_points)
Expand Down Expand Up @@ -53,7 +77,38 @@ def get_graph_feature_cross(


class VNSmall(torch.nn.Module):
"""
VNSmall is a small variant of the vector neuron equivariant network used for canonicalization of point clouds.
Args:
hyperparams (DictConfig): Hyperparameters for the network.
Attributes:
n_knn (int): Number of nearest neighbors to consider.
pooling (str): Pooling type to use, either "max" or "mean".
conv_pos (VNLinearLeakyReLU): Convolutional layer for positional encoding.
conv1 (VNLinearLeakyReLU): First convolutional layer.
bn1 (VNBatchNorm): Batch normalization layer.
conv2 (VNLinearLeakyReLU): Second convolutional layer.
dropout (nn.Dropout): Dropout layer.
pool (Union[VNMaxPool, mean_pool]): Pooling layer.
Methods:
__init__: Initializes the VNSmall network.
forward: Forward pass of the VNSmall network.
"""

def __init__(self, hyperparams: DictConfig):
"""
Initialize the VN Small network.
Args:
hyperparams (DictConfig): A dictionary-like object containing hyperparameters.
Raises:
ValueError: If the specified pooling type is not supported.
"""
super().__init__()
self.n_knn = hyperparams.n_knn
self.pooling = hyperparams.pooling
Expand All @@ -70,11 +125,19 @@ def __init__(self, hyperparams: DictConfig):
else:
raise ValueError(f"Pooling type {self.pooling} not supported")

# Wild idea -- Just use a linear layer to predict the output
# self.conv = VNLinear(3, 12 // 3)

def forward(self, point_cloud: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the VNSmall network.
For every pointcloud in the batch, the network outputs three vectors that transform equivariantly with respect to SO3 group.
Args:
point_cloud (torch.Tensor): Input point cloud tensor of shape (batch_size, num_points, 3).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 3, 3).
"""
point_cloud = point_cloud.unsqueeze(1)
feat = get_graph_feature_cross(point_cloud, k=self.n_knn)
out = self.conv_pos(feat)
Expand All @@ -84,7 +147,4 @@ def forward(self, point_cloud: torch.Tensor) -> torch.Tensor:
out = self.conv2(out)
out = self.dropout(out)

# out = self.pool(self.conv(feat))
# out = self.dropout(out)

return out.mean(dim=-1)[:, :3]
Loading

0 comments on commit d69dab9

Please sign in to comment.