Skip to content

Commit

Permalink
Merge pull request #168 from geometric-intelligence/pirnn_refactor
Browse files Browse the repository at this point in the history
Pirnn refactor
  • Loading branch information
franciscoeacosta authored Aug 21, 2024
2 parents 5f1d3db + fe1dd2e commit 34087aa
Show file tree
Hide file tree
Showing 36 changed files with 78,351 additions and 5,099 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ neurometry/datasets/rnn_grid_cells/Dual agent path integration high res/*
neurometry/datasets/rnn_grid_cells/Single agent path integration high res/*
neurometry/curvature/grid-cells-curvature/models/xu_rnn/results/*
neurometry/curvature/grid-cells-curvature/multi-agent/*
neurometry/neuroai/piRNNs/models/results/*
neurometry/neuroai/piRNNs/multi-agent/*
notebooks/

*viewer*
Expand Down Expand Up @@ -36,6 +38,9 @@ neurometry/datasets/rnn_grid_cells/Single agent path integration/*

neurometry/curvature/grid-cells-curvature/models/xu_rnn/logs/*
neurometry/curvature/grid-cells-curvature/models/xu_rnn/wandb/*
neurometry/neuroai/piRNNs/models/logs/*
neurometry/neuroai/piRNNs/models/wandb/*
neurometry/neuroai/piRNNs/models/pretrained/*


# Result files
Expand Down
4 changes: 2 additions & 2 deletions neurometry/curvature/datasets/gridcells.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import numpy as np
import pandas as pd

import neurometry.curvature.datasets.structures as structures

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs

import neurometry.curvature.datasets.structures as structures


# TODO
def load_grid_cells_synthetic(
Expand Down
7 changes: 4 additions & 3 deletions neurometry/curvature/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@
import logging
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs
import numpy as np
import pandas as pd
import skimage
import torch
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from torch.distributions.multivariate_normal import MultivariateNormal

from neurometry.topology.persistent_homology import (
cohomological_circular_coordinates,
cohomological_toroidal_coordinates,
)

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs
from geomstats.geometry.special_orthogonal import SpecialOrthogonal


def load_projected_images(n_scalars=5, n_angles=1000, img_size=128):
"""Load a dataset of 2D images projected into 1D projections.
Expand Down
12 changes: 6 additions & 6 deletions neurometry/curvature/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
import numpy as np
import torch

from neurometry.curvature.datasets.synthetic import (
get_s1_synthetic_immersion,
get_s2_synthetic_immersion,
get_t2_synthetic_immersion,
)

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs
from geomstats.geometry.base import ImmersedSet
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.pullback_metric import PullbackMetric
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

from neurometry.curvature.datasets.synthetic import (
get_s1_synthetic_immersion,
get_s2_synthetic_immersion,
get_t2_synthetic_immersion,
)


class NeuralManifoldIntrinsic(ImmersedSet):
def __init__(self, dim, neural_embedding_dim, neural_immersion, equip=True):
Expand Down
201 changes: 0 additions & 201 deletions neurometry/curvature/grid-cells-curvature/models/xu_rnn/LICENSE

This file was deleted.

Loading

0 comments on commit 34087aa

Please sign in to comment.