From 985c7d4d79b01e7cfcc2938df62953ad44c2d0b8 Mon Sep 17 00:00:00 2001 From: jikael Date: Thu, 14 Mar 2024 14:06:50 -0400 Subject: [PATCH] Changed some type annotations and cleaned up some imports. --- .../nbody/canonicalization/euclidean_group.py | 2 - .../custom_equivariant_networks.py | 2 +- examples/nbody/model.py | 5 +- tutorials/nbody/nbody.ipynb | 122 +++++++++--------- 4 files changed, 64 insertions(+), 67 deletions(-) diff --git a/equiadapt/nbody/canonicalization/euclidean_group.py b/equiadapt/nbody/canonicalization/euclidean_group.py index 883c07c..cda25fc 100644 --- a/equiadapt/nbody/canonicalization/euclidean_group.py +++ b/equiadapt/nbody/canonicalization/euclidean_group.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from omegaconf import DictConfig from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization @@ -22,7 +21,6 @@ class EuclideanGroupNBody(ContinuousGroupCanonicalization): def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: DictConfig, ) -> None: super().__init__(canonicalization_network) diff --git a/equiadapt/nbody/canonicalization_networks/custom_equivariant_networks.py b/equiadapt/nbody/canonicalization_networks/custom_equivariant_networks.py index 412bd48..dd983d8 100644 --- a/equiadapt/nbody/canonicalization_networks/custom_equivariant_networks.py +++ b/equiadapt/nbody/canonicalization_networks/custom_equivariant_networks.py @@ -15,7 +15,7 @@ class VNDeepSets(nn.Module): A class representing the VNDeepSets model. Args: - hyperparams: An object containing hyperparameters for the model. + hyperparams: A dictionary containing hyperparameters for the model. device (str): The device to run the model on. Defaults to "cuda" if available, otherwise "cpu". Attributes: diff --git a/examples/nbody/model.py b/examples/nbody/model.py index f8b86df..b65c308 100644 --- a/examples/nbody/model.py +++ b/examples/nbody/model.py @@ -1,7 +1,8 @@ +from typing import Any + import pytorch_lightning as pl import torch import torch.nn as nn -from omegaconf import DictConfig from equiadapt.nbody.canonicalization.euclidean_group import EuclideanGroupNBody from examples.nbody.model_utils import ( @@ -12,7 +13,7 @@ class NBodyPipeline(pl.LightningModule): - def __init__(self, hyperparams: DictConfig): + def __init__(self, hyperparams: Any): super().__init__() self.hyperparams = hyperparams self.prediction_network = get_prediction_network(hyperparams.pred_hyperparams) diff --git a/tutorials/nbody/nbody.ipynb b/tutorials/nbody/nbody.ipynb index e39103e..56e80ba 100644 --- a/tutorials/nbody/nbody.ipynb +++ b/tutorials/nbody/nbody.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -28,11 +28,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n" + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -119,18 +119,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 36, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/jikael/mcgill/comp396/equivariant-adaptation/EquivariantAdaptation/examples/nbody/prepare/nbody_data.py:114: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1670076225403/work/torch/csrc/utils/tensor_new.cpp:230.)\n", - " edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)\n" - ] - } - ], + "outputs": [], "source": [ "nbody_data = NBodyDataModule(hyperparams)\n", "nbody_data.setup()\n", @@ -142,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -186,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -196,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -214,33 +205,40 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.67it/s, task_loss=1.89, loss=1.89]\n", - "Epoch 1: 100%|██████████| 30/30 [00:01<00:00, 17.58it/s, task_loss=0.135, loss=0.135]\n", - "Epoch 2: 100%|██████████| 30/30 [00:01<00:00, 16.49it/s, task_loss=0.0769, loss=0.0769]\n", - "Epoch 3: 100%|██████████| 30/30 [00:01<00:00, 16.13it/s, task_loss=0.0711, loss=0.0711]\n", - "Epoch 4: 100%|██████████| 30/30 [00:01<00:00, 15.57it/s, task_loss=0.0677, loss=0.0677]\n", - "Epoch 5: 100%|██████████| 30/30 [00:01<00:00, 16.25it/s, task_loss=0.065, loss=0.065] \n", - "Epoch 6: 100%|██████████| 30/30 [00:02<00:00, 10.09it/s, task_loss=0.0626, loss=0.0626]\n", - "Epoch 7: 100%|██████████| 30/30 [00:02<00:00, 14.06it/s, task_loss=0.0619, loss=0.0619]\n", - "Epoch 8: 100%|██████████| 30/30 [00:03<00:00, 9.61it/s, task_loss=0.0571, loss=0.0571]\n", - "Epoch 9: 100%|██████████| 30/30 [00:01<00:00, 16.93it/s, task_loss=0.0527, loss=0.0527]\n", - "Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 11.15it/s, task_loss=0.0502, loss=0.0502]\n", - "Epoch 11: 100%|██████████| 30/30 [00:02<00:00, 12.74it/s, task_loss=0.0451, loss=0.0451]\n", - "Epoch 12: 100%|██████████| 30/30 [00:01<00:00, 16.91it/s, task_loss=0.0431, loss=0.0431]\n", - "Epoch 13: 100%|██████████| 30/30 [00:01<00:00, 16.54it/s, task_loss=0.0418, loss=0.0418]\n", - "Epoch 14: 100%|██████████| 30/30 [00:01<00:00, 15.43it/s, task_loss=0.0373, loss=0.0373]\n", - "Epoch 15: 100%|██████████| 30/30 [00:02<00:00, 14.84it/s, task_loss=0.0333, loss=0.0333]\n", - "Epoch 16: 100%|██████████| 30/30 [00:02<00:00, 12.37it/s, task_loss=0.0323, loss=0.0323]\n", - "Epoch 17: 100%|██████████| 30/30 [00:02<00:00, 14.64it/s, task_loss=0.0305, loss=0.0305]\n", - "Epoch 18: 100%|██████████| 30/30 [00:02<00:00, 13.47it/s, task_loss=0.0292, loss=0.0292]\n", - "Epoch 19: 100%|██████████| 30/30 [00:01<00:00, 16.29it/s, task_loss=0.0301, loss=0.0301]\n" + "Epoch 0: 0%| | 0/30 [00:00