From 440ff68c06c1aaabd45cbc0c69589666928c6934 Mon Sep 17 00:00:00 2001 From: arnab39 Date: Wed, 13 Mar 2024 13:00:54 -0400 Subject: [PATCH] added better docstring for pointcloud --- .../canonicalization/continuous_group.py | 58 +++++++++++---- .../equivariant_networks.py | 13 ++++ .../vector_neuron_layers.py | 42 +++++++++-- examples/images/classification/README.md | 2 +- examples/images/classification/train.py | 72 ++++++++++++------ examples/images/classification/train_utils.py | 14 ++-- examples/images/segmentation/README.md | 6 +- .../images/segmentation/inference_utils.py | 26 ++++--- examples/images/segmentation/train.py | 74 +++++++++++++------ examples/images/segmentation/train_utils.py | 13 ++-- examples/pointcloud/classification/train.py | 7 +- .../pointcloud/part_segmentation/train.py | 7 +- 12 files changed, 232 insertions(+), 102 deletions(-) diff --git a/equiadapt/pointcloud/canonicalization/continuous_group.py b/equiadapt/pointcloud/canonicalization/continuous_group.py index 2299bf9..31fe138 100644 --- a/equiadapt/pointcloud/canonicalization/continuous_group.py +++ b/equiadapt/pointcloud/canonicalization/continuous_group.py @@ -11,6 +11,21 @@ class ContinuousGroupPointcloudCanonicalization(ContinuousGroupCanonicalization): + """ + This class serves as the base class for continuous group point cloud canonicalization. + + Args: + canonicalization_network (torch.nn.Module): The canonicalization network. + canonicalization_hyperparams (DictConfig): The hyperparameters for canonicalization. + + Attributes: + device: The device on which the operations are performed. + + Methods: + get_groupelement: Maps the input point cloud to the group element. + canonicalize: Returns the canonicalized point cloud. + """ + def __init__( self, canonicalization_network: torch.nn.Module, @@ -20,14 +35,16 @@ def __init__( def get_groupelement(self, x: torch.Tensor) -> dict: """ - This method takes the input image and - maps it to the group element + This method takes the input point cloud and maps it to the group element. Args: - x: input image + x (torch.Tensor): The input point cloud. Returns: - group_element: group element + dict: The group element. + + Raises: + NotImplementedError: If the method is not implemented in the subclass. """ raise NotImplementedError("get_groupelement method is not implemented") @@ -35,14 +52,16 @@ def canonicalize( self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ - This method takes an image as input and - returns the canonicalized image + This method takes a point cloud as input and returns the canonicalized point cloud. Args: - x: input point cloud + x (torch.Tensor): The input point cloud. + targets (Optional[List]): The list of targets (optional). + **kwargs (Any): Additional keyword arguments. Returns: - x_canonicalized: canonicalized point cloud + Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized point cloud. + """ self.device = x.device @@ -63,6 +82,21 @@ class EquivariantPointcloudCanonicalization(ContinuousGroupPointcloudCanonicalization): + """ + This class represents the equivariant point cloud canonicalization module. + + It inherits from the ContinuousGroupPointcloudCanonicalization class. + + Args: + canonicalization_network (torch.nn.Module): The canonicalization network module. + canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization. + + Attributes: + canonicalization_network (torch.nn.Module): The canonicalization network module.
+ canonicalization_hyperparams (DictConfig): The hyperparameters for the canonicalization. + canonicalization_info_dict (dict): A dictionary to store the canonicalization information. + """ + def __init__( self, canonicalization_network: torch.nn.Module, @@ -72,16 +106,14 @@ def __init__( def get_groupelement(self, x: torch.Tensor) -> dict[str, torch.Tensor]: """ - This method takes the input image and - maps it to the group element + This method takes the input point cloud and maps it to the group element. Args: - x: input point cloud + x (torch.Tensor): The input point cloud. Returns: - group_element: group element + dict[str, torch.Tensor]: A dictionary containing the group element. """ - group_element_dict = {} # convert the group activations to one hot encoding of group element diff --git a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py index 98ec49f..3a8bb65 100644 --- a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py +++ b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py @@ -93,9 +93,22 @@ class VNSmall(torch.nn.Module): dropout (nn.Dropout): Dropout layer. pool (Union[VNMaxPool, mean_pool]): Pooling layer. + Methods: + __init__: Initializes the VNSmall network. + forward: Forward pass of the VNSmall network. + """ def __init__(self, hyperparams: DictConfig): + """ + Initialize the VNSmall network. + + Args: + hyperparams (DictConfig): A dictionary-like object containing hyperparameters. + + Raises: + ValueError: If the specified pooling type is not supported. + """ super().__init__() self.n_knn = hyperparams.n_knn self.pooling = hyperparams.pooling diff --git a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py index 19dcbf2..eacc486 100644 --- a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py +++ b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py @@ -18,6 +18,10 @@ class VNLinear(nn.Module): Vector Neuron Linear layer. This layer applies a linear transformation to the input tensor. + + Methods: + __init__: Initializes the VNLinear layer. + forward: Performs forward pass of the VNLinear layer. """ def __init__(self, in_channels: int, out_channels: int): @@ -50,6 +54,10 @@ class VNBilinear(nn.Module): Vector Neuron Bilinear layer. VNBilinear applies a bilinear transformation to the input features. + + Methods: + __init__: Initializes the VNBilinear layer. + forward: Performs forward pass of the VNBilinear layer. """ def __init__(self, in_channels1: int, in_channels2: int, out_channels: int): @@ -87,6 +95,10 @@ class VNSoftplus(nn.Module): Vector Neuron Softplus layer. VNSoftplus applies a softplus activation to the input features. + + Methods: + __init__: Initializes the VNSoftplus layer. + forward: Performs forward pass of the VNSoftplus layer. """ def __init__( @@ -144,6 +156,10 @@ class VNLeakyReLU(nn.Module): Vector Neuron Leaky ReLU layer. VNLeakyReLU applies a LeakyReLU activation to the input features. + + Methods: + __init__: Initializes the VNLeakyReLU layer. + forward: Performs forward pass of the VNLeakyReLU layer. """ def __init__( @@ -196,6 +212,10 @@ class VNLinearLeakyReLU(nn.Module): Vector Neuron Linear Leaky ReLU layer. VNLinearLeakyReLU applies a linear transformation followed by a LeakyReLU activation to the input features. + + Methods: + __init__: Initializes the VNLinearLeakyReLU layer.
+ forward: Performs forward pass of the VNLinearLeakyReLU layer. """ def __init__( @@ -258,6 +278,10 @@ class VNBatchNorm(nn.Module): Vector Neuron Batch Normalization layer. VNBatchNorm applies batch normalization to the input features. + + Methods: + __init__: Initializes the VNBatchNorm layer. + forward: Performs forward pass of the VNBatchNorm layer. """ def __init__(self, num_features: int, dim: int): @@ -305,6 +329,10 @@ class VNMaxPool(nn.Module): Vector Neuron Max Pooling layer. VNMaxPool applies max pooling to the input features. + + Methods: + __init__: Initializes the VNMaxPool layer. + forward: Performs forward pass of the VNMaxPool layer. """ def __init__(self, in_channels: int): @@ -360,15 +388,13 @@ class VNStdFeature(nn.Module): It takes point features as input and applies a series of VNLinearLeakyReLU layers followed by a linear layer to produce the standard features. - Args: - in_channels (int): Number of input channels. - dim (int, optional): Dimension of the output features. Defaults to 4. - normalize_frame (bool, optional): Whether to normalize the frame. Defaults to False. - share_nonlinearity (bool, optional): Whether to share the nonlinearity across layers. Defaults to False. - negative_slope (float, optional): Negative slope of the LeakyReLU activation function. Defaults to 0.2. + Attributes: + dim (int): Dimension of the input features. + normalize_frame (bool): Whether to normalize the frame. - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple containing the standard features and the frame vectors. + Methods: + __init__: Initializes the VNStdFeature module. + forward: Performs forward pass of the VNStdFeature module. Shape: - Input: (B, N_feat, 3, N_samples, ...) diff --git a/examples/images/classification/README.md b/examples/images/classification/README.md index 0fd39ce..26d1430 100644 --- a/examples/images/classification/README.md +++ b/examples/images/classification/README.md @@ -21,7 +21,7 @@ checkpoint.checkpoint_path=/path/of/checkpoint/dir checkpoint.checkpoint_name=
diff --git a/examples/images/classification/train.py b/examples/images/classification/train.py --- a/examples/images/classification/train.py +++ b/examples/images/classification/train.py +def train_images(hyperparams: DictConfig) -> None: + + if hyperparams["experiment"]["run_mode"] == "test": + assert ( + len(hyperparams["checkpoint"]["checkpoint_name"]) > 0 + ), "checkpoint_name must be provided for test mode" -def train_images(hyperparams: DictConfig): - - if hyperparams['experiment']['run_mode'] == "test": - assert len(hyperparams['checkpoint']['checkpoint_name']) > 0, "checkpoint_name must be provided for test mode" - - existing_ckpt_path = hyperparams['checkpoint']['checkpoint_path'] + "/" + hyperparams['checkpoint']['checkpoint_name'] + ".ckpt" existing_ckpt = torch.load(existing_ckpt_path) - conf = OmegaConf.create(existing_ckpt['hyper_parameters']['hyperparams']) - - hyperparams['canonicalization_type'] = conf['canonicalization_type'] - hyperparams['canonicalization'] = conf['canonicalization'] - hyperparams['prediction'] = conf['prediction'] - - else: - hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] - hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ + "/" + 
hyperparams['prediction']['prediction_network_architecture'] + conf = OmegaConf.create(existing_ckpt["hyper_parameters"]["hyperparams"]) + + hyperparams["canonicalization_type"] = conf["canonicalization_type"] + hyperparams["canonicalization"] = conf["canonicalization"] + hyperparams["prediction"] = conf["prediction"] + + else: + hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ + "canonicalization_type" + ] + hyperparams["device"] = "cuda" if torch.cuda.is_available() else "cpu" + hyperparams["dataset"]["data_path"] = ( + hyperparams["dataset"]["data_path"] + + "/" + + hyperparams["dataset"]["dataset_name"] + ) + hyperparams["checkpoint"]["checkpoint_path"] = ( + hyperparams["checkpoint"]["checkpoint_path"] + + "/" + + hyperparams["dataset"]["dataset_name"] + + "/" + + hyperparams["canonicalization_type"] + + "/" + + hyperparams["prediction"]["prediction_network_architecture"] + ) # set system environment variables for wandb if hyperparams["wandb"]["use_wandb"]: @@ -53,8 +71,16 @@ def train_images(hyperparams: DictConfig): project=hyperparams["wandb"]["wandb_project"], log_model="all" ) - if not hyperparams['experiment']['run_mode'] == "test": - hyperparams['checkpoint']['checkpoint_name'] = wandb_run.id + "_" + wandb_run.name + "_" + wandb_run.sweep_id + "_" + wandb_run.group + if not hyperparams["experiment"]["run_mode"] == "test": + hyperparams["checkpoint"]["checkpoint_name"] = ( + wandb_run.id + + "_" + + wandb_run.name + + "_" + + wandb_run.sweep_id + + "_" + + wandb_run.group + ) # set seed pl.seed_everything(hyperparams.experiment.seed) @@ -86,7 +112,7 @@ def train_images(hyperparams: DictConfig): @hydra.main(config_path=str("./configs/"), config_name="default") -def main(cfg: omegaconf.DictConfig): +def main(cfg: omegaconf.DictConfig) -> None: train_images(cfg) diff --git a/examples/images/classification/train_utils.py b/examples/images/classification/train_utils.py index 3d76edb..c25e1b4 100644 --- a/examples/images/classification/train_utils.py +++ b/examples/images/classification/train_utils.py @@ -14,11 +14,11 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -def get_model_data_and_callbacks(hyperparams: DictConfig): +def get_model_data_and_callbacks(hyperparams: DictConfig) -> tuple: # get image data image_data = get_image_data(hyperparams.dataset) - + # checkpoint callbacks callbacks = get_callbacks(hyperparams) @@ -28,7 +28,7 @@ def get_model_data_and_callbacks(hyperparams: DictConfig): return model, image_data, callbacks -def get_model_pipeline(hyperparams: DictConfig): +def get_model_pipeline(hyperparams: DictConfig) -> pl.LightningModule: if hyperparams.experiment.run_mode == "test": model = ImageClassifierPipeline.load_from_checkpoint( @@ -48,7 +48,7 @@ def get_model_pipeline(hyperparams: DictConfig): def get_trainer( hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger -): +) -> pl.Trainer: if hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( fast_dev_run=5, @@ -75,7 +75,7 @@ def get_trainer( return trainer -def get_callbacks(hyperparams: DictConfig): +def get_callbacks(hyperparams: DictConfig) -> list: checkpoint_callback = ModelCheckpoint( dirpath=hyperparams.checkpoint.checkpoint_path, @@ -93,9 +93,9 @@ def get_callbacks(hyperparams: DictConfig): ) return [checkpoint_callback, early_stop_metric_callback] - -def get_image_data(dataset_hyperparams: DictConfig): + +def get_image_data(dataset_hyperparams: DictConfig) -> pl.LightningDataModule: dataset_classes = { "rotated_mnist": 
RotatedMNISTDataModule, diff --git a/examples/images/segmentation/README.md b/examples/images/segmentation/README.md index be9b4f3..9c2bd50 100644 --- a/examples/images/segmentation/README.md +++ b/examples/images/segmentation/README.md @@ -26,7 +26,7 @@ checkpoint.checkpoint_path=/path/of/checkpoint/dir checkpoint.checkpoint_name=
diff --git a/examples/images/segmentation/inference_utils.py b/examples/images/segmentation/inference_utils.py --- a/examples/images/segmentation/inference_utils.py +++ b/examples/images/segmentation/inference_utils.py ) -> Union[VanillaInference, GroupInference]: if inference_hyperparams.method == "vanilla": return VanillaInference(canonicalizer, prediction_network) elif inference_hyperparams.method == "group": @@ -33,7 +33,9 @@ def __init__( self.canonicalizer = canonicalizer self.prediction_network = prediction_network - def forward(self, x, targets): + def forward( + self, x: torch.Tensor, targets: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) @@ -44,9 +46,9 @@ def forward(self, x, targets): # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model return self.prediction_network(x_canonicalized, targets_canonicalized) - def get_inference_metrics(self, x: torch.Tensor, targets: torch.Tensor): + def get_inference_metrics(self, x: torch.Tensor, targets: torch.Tensor) -> dict: # Forward pass through the prediction network - _, _, _, outputs = self.forward(x) + _, _, _, outputs = self.forward(x, targets) _map = MeanAveragePrecision(iou_type="segm") targets = [ @@ -75,7 +77,7 @@ def __init__( self, canonicalizer: torch.nn.Module, prediction_network: torch.nn.Module, - inference_hyperparams: Union[Dict, wandb.Config], + inference_hyperparams: DictConfig, in_shape: tuple = (3, 32, 32), ): @@ -90,7 +92,9 @@ def __init__( self.pad = transforms.Pad(math.ceil(in_shape[-2] * 0.4), padding_mode="edge") self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) - def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tensor): + def get_group_element_wise_maps( + self, images: torch.Tensor, targets: torch.Tensor + ) -> dict: map_dict = dict() image_width = images[0].shape[1] @@ -192,7 +196,9 @@ def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tenso return map_dict - def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): + def get_inference_metrics( + self, images: torch.Tensor, targets: torch.Tensor + ) -> dict: metrics = {} map_dict = self.get_group_element_wise_maps(images, targets) diff --git a/examples/images/segmentation/train.py b/examples/images/segmentation/train.py index 069e229..52d3364 100644 --- a/examples/images/segmentation/train.py +++ b/examples/images/segmentation/train.py @@ -4,34 +4,54 @@ import omegaconf import pytorch_lightning as pl import torch +import wandb from omegaconf import DictConfig, OmegaConf from pytorch_lightning.loggers import WandbLogger from train_utils import get_model_data_and_callbacks, get_trainer, load_envs -import wandb +def train_images(hyperparams: DictConfig) -> None: + + if hyperparams["experiment"]["run_mode"] == "test": + assert ( + len(hyperparams["checkpoint"]["checkpoint_name"]) > 0 + ), "checkpoint_name must be provided for test mode" -def train_images(hyperparams: DictConfig): - - if hyperparams['experiment']['run_mode'] == "test": - assert len(hyperparams['checkpoint']['checkpoint_name']) > 0, "checkpoint_name must be provided for test mode" - - existing_ckpt_path = 
hyperparams['checkpoint']['checkpoint_path'] + "/" + hyperparams['checkpoint']['checkpoint_name'] + ".ckpt" + existing_ckpt_path = ( + hyperparams["checkpoint"]["checkpoint_path"] + + "/" + + hyperparams["checkpoint"]["checkpoint_name"] + + ".ckpt" + ) existing_ckpt = torch.load(existing_ckpt_path) - conf = OmegaConf.create(existing_ckpt['hyper_parameters']['hyperparams']) - - hyperparams['canonicalization_type'] = conf['canonicalization_type'] - hyperparams['canonicalization'] = conf['canonicalization'] - hyperparams['prediction'] = conf['prediction'] - + conf = OmegaConf.create(existing_ckpt["hyper_parameters"]["hyperparams"]) + + hyperparams["canonicalization_type"] = conf["canonicalization_type"] + hyperparams["canonicalization"] = conf["canonicalization"] + hyperparams["prediction"] = conf["prediction"] + else: - hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] - hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['root_dir'] = hyperparams['dataset']['root_dir'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['dataset']['ann_dir'] = hyperparams['dataset']['root_dir'] + "/" + "annotations" - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ - hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ - + "/" + hyperparams['prediction']['prediction_network_architecture'] + hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ + "canonicalization_type" + ] + hyperparams["device"] = "cuda" if torch.cuda.is_available() else "cpu" + hyperparams["dataset"]["root_dir"] = ( + hyperparams["dataset"]["root_dir"] + + "/" + + hyperparams["dataset"]["dataset_name"] + ) + hyperparams["dataset"]["ann_dir"] = ( + hyperparams["dataset"]["root_dir"] + "/" + "annotations" + ) + hyperparams["checkpoint"]["checkpoint_path"] = ( + hyperparams["checkpoint"]["checkpoint_path"] + + "/" + + hyperparams["dataset"]["dataset_name"] + + "/" + + hyperparams["canonicalization_type"] + + "/" + + hyperparams["prediction"]["prediction_network_architecture"] + ) # set system environment variables for wandb if hyperparams["wandb"]["use_wandb"]: @@ -54,8 +74,16 @@ def train_images(hyperparams: DictConfig): project=hyperparams["wandb"]["wandb_project"], log_model="all" ) - if not hyperparams['experiment']['run_mode'] == "test": - hyperparams['checkpoint']['checkpoint_name'] = wandb_run.id + "_" + wandb_run.name + "_" + wandb_run.sweep_id + "_" + wandb_run.group + if not hyperparams["experiment"]["run_mode"] == "test": + hyperparams["checkpoint"]["checkpoint_name"] = ( + wandb_run.id + + "_" + + wandb_run.name + + "_" + + wandb_run.sweep_id + + "_" + + wandb_run.group + ) # set seed pl.seed_everything(hyperparams.experiment.seed) @@ -87,7 +115,7 @@ def train_images(hyperparams: DictConfig): @hydra.main(config_path=str("./configs/"), config_name="default") -def main(cfg: omegaconf.DictConfig): +def main(cfg: omegaconf.DictConfig) -> None: train_images(cfg) diff --git a/examples/images/segmentation/train_utils.py b/examples/images/segmentation/train_utils.py index edbc689..b2df8f9 100644 --- a/examples/images/segmentation/train_utils.py +++ b/examples/images/segmentation/train_utils.py @@ -8,11 +8,11 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -def get_model_data_and_callbacks(hyperparams: DictConfig): +def get_model_data_and_callbacks(hyperparams: DictConfig) -> tuple: # get image data image_data = 
get_image_data(hyperparams.dataset) - + # checkpoint callbacks callbacks = get_callbacks(hyperparams) @@ -22,7 +22,7 @@ def get_model_data_and_callbacks(hyperparams: DictConfig): return model, image_data, callbacks -def get_model_pipeline(hyperparams: DictConfig): +def get_model_pipeline(hyperparams: DictConfig) -> pl.LightningModule: if hyperparams.experiment.run_mode == "test": model = ImageSegmentationPipeline.load_from_checkpoint( @@ -42,7 +42,7 @@ def get_model_pipeline(hyperparams: DictConfig): def get_trainer( hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger -): +) -> pl.Trainer: if hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( fast_dev_run=5, @@ -76,7 +76,7 @@ def get_trainer( return trainer -def get_callbacks(hyperparams: DictConfig): +def get_callbacks(hyperparams: DictConfig) -> list: checkpoint_callback = ModelCheckpoint( dirpath=hyperparams.checkpoint.checkpoint_path, @@ -95,7 +95,8 @@ def get_callbacks(hyperparams: DictConfig): return [checkpoint_callback, early_stop_metric_callback] -def get_image_data(dataset_hyperparams: DictConfig): + +def get_image_data(dataset_hyperparams: DictConfig) -> pl.LightningDataModule: dataset_classes = {"coco": COCODataModule} diff --git a/examples/pointcloud/classification/train.py b/examples/pointcloud/classification/train.py index 2a6ac9b..e78eaea 100644 --- a/examples/pointcloud/classification/train.py +++ b/examples/pointcloud/classification/train.py @@ -4,6 +4,7 @@ import omegaconf import pytorch_lightning as pl import torch +import wandb from omegaconf import DictConfig, OmegaConf from prepare import ModelNetDataModule from pytorch_lightning.loggers import WandbLogger @@ -15,10 +16,8 @@ load_envs, ) -import wandb - -def train_pointcloud(hyperparams: DictConfig): +def train_pointcloud(hyperparams: DictConfig) -> None: hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ "canonicalization_type" ] @@ -94,7 +93,7 @@ def train_pointcloud(hyperparams: DictConfig): @hydra.main(config_path=str("./configs/"), config_name="default") -def main(cfg: omegaconf.DictConfig): +def main(cfg: omegaconf.DictConfig) -> None: train_pointcloud(cfg) diff --git a/examples/pointcloud/part_segmentation/train.py b/examples/pointcloud/part_segmentation/train.py index 4baad98..e9f6093 100644 --- a/examples/pointcloud/part_segmentation/train.py +++ b/examples/pointcloud/part_segmentation/train.py @@ -4,6 +4,7 @@ import omegaconf import pytorch_lightning as pl import torch +import wandb from omegaconf import DictConfig, OmegaConf from prepare import ShapeNetDataModule from pytorch_lightning.loggers import WandbLogger @@ -15,10 +16,8 @@ load_envs, ) -import wandb - -def train_pointcloud(hyperparams: DictConfig): +def train_pointcloud(hyperparams: DictConfig) -> None: hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ "canonicalization_type" ] @@ -96,7 +95,7 @@ def train_pointcloud(hyperparams: DictConfig): @hydra.main(config_path=str("./configs/"), config_name="default") -def main(cfg: omegaconf.DictConfig): +def main(cfg: omegaconf.DictConfig) -> None: train_pointcloud(cfg)
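
Reviewer note, not part of the diff: a minimal usage sketch of the point cloud canonicalization API documented above. It assumes only the signatures visible in this patch -- VNSmall(hyperparams) reading n_knn and pooling, and EquivariantPointcloudCanonicalization(canonicalization_network, canonicalization_hyperparams) -- while the hyperparameter values, the empty wrapper config, and the (B, 3, N) input layout are illustrative assumptions, not tested repository code.

import torch
from omegaconf import OmegaConf

from equiadapt.pointcloud.canonicalization.continuous_group import (
    EquivariantPointcloudCanonicalization,
)
from equiadapt.pointcloud.canonicalization_networks.equivariant_networks import VNSmall

# VNSmall reads `n_knn` and `pooling` from its hyperparams (see the diff above);
# the values here are illustrative assumptions.
network_hyperparams = OmegaConf.create({"n_knn": 20, "pooling": "mean"})
canonicalization_network = VNSmall(network_hyperparams)

# Hyperparameters for the canonicalization wrapper; the keys it expects are not
# shown in this patch, so an empty config is used as a placeholder.
canonicalization_hyperparams = OmegaConf.create({})

canonicalizer = EquivariantPointcloudCanonicalization(
    canonicalization_network, canonicalization_hyperparams
)

# A batch of 8 point clouds with 3 coordinates and 1024 points each (assumed layout).
x = torch.randn(8, 3, 1024)

# Per the docstring, canonicalize() with targets=None returns only the
# canonicalized point cloud, oriented by the predicted group element.
x_canonicalized = canonicalizer.canonicalize(x)
print(x_canonicalized.shape)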