Skip to content

Commit

Permalink
Changed some type annotations and cleaned up some
Browse files Browse the repository at this point in the history
imports.
  • Loading branch information
jikaelgagnon committed Mar 14, 2024
1 parent ca63da7 commit 985c7d4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 67 deletions.
2 changes: 0 additions & 2 deletions equiadapt/nbody/canonicalization/euclidean_group.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig

from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization

Expand All @@ -22,7 +21,6 @@ class EuclideanGroupNBody(ContinuousGroupCanonicalization):
def __init__(
self,
canonicalization_network: torch.nn.Module,
canonicalization_hyperparams: DictConfig,
) -> None:
super().__init__(canonicalization_network)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class VNDeepSets(nn.Module):
A class representing the VNDeepSets model.
Args:
hyperparams: An object containing hyperparameters for the model.
hyperparams: A dictionary containing hyperparameters for the model.
device (str): The device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
Attributes:
Expand Down
5 changes: 3 additions & 2 deletions examples/nbody/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn as nn
from omegaconf import DictConfig

from equiadapt.nbody.canonicalization.euclidean_group import EuclideanGroupNBody
from examples.nbody.model_utils import (
Expand All @@ -12,7 +13,7 @@


class NBodyPipeline(pl.LightningModule):
def __init__(self, hyperparams: DictConfig):
def __init__(self, hyperparams: Any):
super().__init__()
self.hyperparams = hyperparams
self.prediction_network = get_prediction_network(hyperparams.pred_hyperparams)
Expand Down
122 changes: 60 additions & 62 deletions tutorials/nbody/nbody.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -28,11 +28,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
Expand All @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -92,7 +92,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -119,18 +119,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jikael/mcgill/comp396/equivariant-adaptation/EquivariantAdaptation/examples/nbody/prepare/nbody_data.py:114: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1670076225403/work/torch/csrc/utils/tensor_new.cpp:230.)\n",
" edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)\n"
]
}
],
"outputs": [],
"source": [
"nbody_data = NBodyDataModule(hyperparams)\n",
"nbody_data.setup()\n",
Expand All @@ -142,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -186,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -196,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -214,33 +205,40 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.67it/s, task_loss=1.89, loss=1.89]\n",
"Epoch 1: 100%|██████████| 30/30 [00:01<00:00, 17.58it/s, task_loss=0.135, loss=0.135]\n",
"Epoch 2: 100%|██████████| 30/30 [00:01<00:00, 16.49it/s, task_loss=0.0769, loss=0.0769]\n",
"Epoch 3: 100%|██████████| 30/30 [00:01<00:00, 16.13it/s, task_loss=0.0711, loss=0.0711]\n",
"Epoch 4: 100%|██████████| 30/30 [00:01<00:00, 15.57it/s, task_loss=0.0677, loss=0.0677]\n",
"Epoch 5: 100%|██████████| 30/30 [00:01<00:00, 16.25it/s, task_loss=0.065, loss=0.065] \n",
"Epoch 6: 100%|██████████| 30/30 [00:02<00:00, 10.09it/s, task_loss=0.0626, loss=0.0626]\n",
"Epoch 7: 100%|██████████| 30/30 [00:02<00:00, 14.06it/s, task_loss=0.0619, loss=0.0619]\n",
"Epoch 8: 100%|██████████| 30/30 [00:03<00:00, 9.61it/s, task_loss=0.0571, loss=0.0571]\n",
"Epoch 9: 100%|██████████| 30/30 [00:01<00:00, 16.93it/s, task_loss=0.0527, loss=0.0527]\n",
"Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 11.15it/s, task_loss=0.0502, loss=0.0502]\n",
"Epoch 11: 100%|██████████| 30/30 [00:02<00:00, 12.74it/s, task_loss=0.0451, loss=0.0451]\n",
"Epoch 12: 100%|██████████| 30/30 [00:01<00:00, 16.91it/s, task_loss=0.0431, loss=0.0431]\n",
"Epoch 13: 100%|██████████| 30/30 [00:01<00:00, 16.54it/s, task_loss=0.0418, loss=0.0418]\n",
"Epoch 14: 100%|██████████| 30/30 [00:01<00:00, 15.43it/s, task_loss=0.0373, loss=0.0373]\n",
"Epoch 15: 100%|██████████| 30/30 [00:02<00:00, 14.84it/s, task_loss=0.0333, loss=0.0333]\n",
"Epoch 16: 100%|██████████| 30/30 [00:02<00:00, 12.37it/s, task_loss=0.0323, loss=0.0323]\n",
"Epoch 17: 100%|██████████| 30/30 [00:02<00:00, 14.64it/s, task_loss=0.0305, loss=0.0305]\n",
"Epoch 18: 100%|██████████| 30/30 [00:02<00:00, 13.47it/s, task_loss=0.0292, loss=0.0292]\n",
"Epoch 19: 100%|██████████| 30/30 [00:01<00:00, 16.29it/s, task_loss=0.0301, loss=0.0301]\n"
"Epoch 0: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.88it/s, task_loss=1.99, loss=1.99]\n",
"Epoch 1: 100%|██████████| 30/30 [00:02<00:00, 14.98it/s, task_loss=0.156, loss=0.156]\n",
"Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 12.26it/s, task_loss=0.0819, loss=0.0819]\n",
"Epoch 3: 100%|██████████| 30/30 [00:02<00:00, 10.37it/s, task_loss=0.0717, loss=0.0717]\n",
"Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 10.86it/s, task_loss=0.0692, loss=0.0692]\n",
"Epoch 5: 100%|██████████| 30/30 [00:02<00:00, 14.54it/s, task_loss=0.0663, loss=0.0663]\n",
"Epoch 6: 100%|██████████| 30/30 [00:03<00:00, 9.71it/s, task_loss=0.0616, loss=0.0616]\n",
"Epoch 7: 100%|██████████| 30/30 [00:01<00:00, 15.01it/s, task_loss=0.0614, loss=0.0614]\n",
"Epoch 8: 100%|██████████| 30/30 [00:02<00:00, 13.51it/s, task_loss=0.065, loss=0.065] \n",
"Epoch 9: 100%|██████████| 30/30 [00:02<00:00, 12.67it/s, task_loss=0.0527, loss=0.0527]\n",
"Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 14.25it/s, task_loss=0.0454, loss=0.0454]\n",
"Epoch 11: 100%|██████████| 30/30 [00:01<00:00, 15.25it/s, task_loss=0.0421, loss=0.0421]\n",
"Epoch 12: 100%|██████████| 30/30 [00:02<00:00, 13.32it/s, task_loss=0.0394, loss=0.0394]\n",
"Epoch 13: 100%|██████████| 30/30 [00:02<00:00, 11.99it/s, task_loss=0.0371, loss=0.0371]\n",
"Epoch 14: 100%|██████████| 30/30 [00:02<00:00, 13.31it/s, task_loss=0.0361, loss=0.0361]\n",
"Epoch 15: 100%|██████████| 30/30 [00:02<00:00, 11.45it/s, task_loss=0.034, loss=0.034] \n",
"Epoch 16: 100%|██████████| 30/30 [00:02<00:00, 14.67it/s, task_loss=0.0325, loss=0.0325]\n",
"Epoch 17: 100%|██████████| 30/30 [00:01<00:00, 15.99it/s, task_loss=0.0333, loss=0.0333]\n",
"Epoch 18: 100%|██████████| 30/30 [00:01<00:00, 15.80it/s, task_loss=0.0323, loss=0.0323]\n",
"Epoch 19: 100%|██████████| 30/30 [00:02<00:00, 14.05it/s, task_loss=0.0304, loss=0.0304]\n"
]
}
],
Expand Down Expand Up @@ -319,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -340,33 +338,33 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 30/30 [00:02<00:00, 14.93it/s, task_loss=1.63, loss=1.63]\n",
"Epoch 1: 100%|██████████| 30/30 [00:01<00:00, 16.78it/s, task_loss=0.132, loss=0.132]\n",
"Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 14.36it/s, task_loss=0.079, loss=0.079] \n",
"Epoch 3: 100%|██████████| 30/30 [00:01<00:00, 15.51it/s, task_loss=0.0703, loss=0.0703]\n",
"Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 10.34it/s, task_loss=0.0675, loss=0.0675]\n",
"Epoch 5: 100%|██████████| 30/30 [00:01<00:00, 17.05it/s, task_loss=0.065, loss=0.065] \n",
"Epoch 6: 100%|██████████| 30/30 [00:03<00:00, 9.92it/s, task_loss=0.0624, loss=0.0624]\n",
"Epoch 7: 100%|██████████| 30/30 [00:01<00:00, 16.50it/s, task_loss=0.0604, loss=0.0604]\n",
"Epoch 8: 100%|██████████| 30/30 [00:01<00:00, 16.10it/s, task_loss=0.0591, loss=0.0591]\n",
"Epoch 9: 100%|██████████| 30/30 [00:02<00:00, 14.37it/s, task_loss=0.0542, loss=0.0542]\n",
"Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 12.92it/s, task_loss=0.0467, loss=0.0467]\n",
"Epoch 11: 100%|██████████| 30/30 [00:02<00:00, 12.71it/s, task_loss=0.043, loss=0.043] \n",
"Epoch 12: 100%|██████████| 30/30 [00:02<00:00, 11.36it/s, task_loss=0.0435, loss=0.0435]\n",
"Epoch 13: 100%|██████████| 30/30 [00:02<00:00, 13.87it/s, task_loss=0.0362, loss=0.0362]\n",
"Epoch 14: 100%|██████████| 30/30 [00:02<00:00, 10.66it/s, task_loss=0.0342, loss=0.0342]\n",
"Epoch 15: 100%|██████████| 30/30 [00:02<00:00, 13.40it/s, task_loss=0.0328, loss=0.0328]\n",
"Epoch 16: 100%|██████████| 30/30 [00:01<00:00, 17.59it/s, task_loss=0.0315, loss=0.0315]\n",
"Epoch 17: 100%|██████████| 30/30 [00:02<00:00, 13.15it/s, task_loss=0.0303, loss=0.0303]\n",
"Epoch 18: 100%|██████████| 30/30 [00:03<00:00, 9.49it/s, task_loss=0.029, loss=0.029] \n",
"Epoch 19: 100%|██████████| 30/30 [00:02<00:00, 12.47it/s, task_loss=0.0285, loss=0.0285]\n"
"Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.27it/s, task_loss=1.82, loss=1.82]\n",
"Epoch 1: 100%|██████████| 30/30 [00:02<00:00, 14.59it/s, task_loss=0.128, loss=0.128]\n",
"Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 13.11it/s, task_loss=0.0774, loss=0.0774]\n",
"Epoch 3: 100%|██████████| 30/30 [00:02<00:00, 12.53it/s, task_loss=0.0698, loss=0.0698]\n",
"Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 14.04it/s, task_loss=0.0679, loss=0.0679]\n",
"Epoch 5: 100%|██████████| 30/30 [00:02<00:00, 12.82it/s, task_loss=0.0754, loss=0.0754]\n",
"Epoch 6: 100%|██████████| 30/30 [00:02<00:00, 13.56it/s, task_loss=0.0639, loss=0.0639]\n",
"Epoch 7: 100%|██████████| 30/30 [00:02<00:00, 12.07it/s, task_loss=0.0603, loss=0.0603]\n",
"Epoch 8: 100%|██████████| 30/30 [00:02<00:00, 13.41it/s, task_loss=0.0554, loss=0.0554]\n",
"Epoch 9: 100%|██████████| 30/30 [00:01<00:00, 15.00it/s, task_loss=0.0514, loss=0.0514]\n",
"Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 12.85it/s, task_loss=0.0472, loss=0.0472]\n",
"Epoch 11: 100%|██████████| 30/30 [00:02<00:00, 10.31it/s, task_loss=0.0429, loss=0.0429]\n",
"Epoch 12: 100%|██████████| 30/30 [00:02<00:00, 13.50it/s, task_loss=0.0432, loss=0.0432]\n",
"Epoch 13: 100%|██████████| 30/30 [00:02<00:00, 12.77it/s, task_loss=0.0424, loss=0.0424]\n",
"Epoch 14: 100%|██████████| 30/30 [00:01<00:00, 16.99it/s, task_loss=0.0357, loss=0.0357]\n",
"Epoch 15: 100%|██████████| 30/30 [00:02<00:00, 12.02it/s, task_loss=0.0327, loss=0.0327]\n",
"Epoch 16: 100%|██████████| 30/30 [00:02<00:00, 13.01it/s, task_loss=0.032, loss=0.032] \n",
"Epoch 17: 100%|██████████| 30/30 [00:02<00:00, 13.76it/s, task_loss=0.0308, loss=0.0308]\n",
"Epoch 18: 100%|██████████| 30/30 [00:01<00:00, 15.62it/s, task_loss=0.0298, loss=0.0298]\n",
"Epoch 19: 100%|██████████| 30/30 [00:01<00:00, 16.15it/s, task_loss=0.0302, loss=0.0302]\n"
]
}
],
Expand Down

0 comments on commit 985c7d4

Please sign in to comment.