Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/agosztolai/MARBLE into main
Browse files Browse the repository at this point in the history
  • Loading branch information
peach-lucien committed Oct 26, 2023
2 parents dc6e97f + c09c929 commit cf72d2a
Show file tree
Hide file tree
Showing 23 changed files with 521 additions and 302 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ myjob*
__pycache__/
*.egg
*.egg-info/
log/
2 changes: 1 addition & 1 deletion MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _compute_geometric_objects(
n_geodesic_nb=2.0,
var_explained=0.9,
return_spectrum=True,
local_gauges=True,
local_gauges=False,
compute_laplacian=False,
compute_connection_laplacian=False,
dim_man=None,
Expand Down
27 changes: 15 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,32 @@ The code is tested for both cpu and gpu (CUDA) machines running Linux or OSX. Al

We recommend you install the code in a fresh Anaconda virtual environment, as follows.

First, clone this repository,
- First, clone this repository,

```
git clone https://github.com/agosztolai/MARBLE
```

Then, create an new anaconda environment using the provided environment file that matches your system.
- Then, create an new anaconda environment using the provided environment file that matches your system.
- For Linux machines with CUDA:

For Linux machines with CUDA:
`conda env create -f environment.yml`
- For Intel-based Mac:

```
conda env create -f environment.yml
```

For Mac without CUDA:
`conda env create -f environment_osx_intel.yml`

```
conda env create -f environment_cpu_osx.yml
```
- For recent M1/M2/M3 Mac:
- Install cmake `brew install cmake` or using the installer on the [cmake website](https://cmake.org/download/)
- Create the environment

This will install all the requires dependencies. Finally, install by running inside the main folder
`conda env create -f environment_osx_arm.yml`
- Activate the environment `conda activate MARBLE`
- Install pytorch geometric
`pip install -r requirements_osx_arm.txt`
- All the required dependencies are now installed. Finally, activate the environment and install by running inside the main folder

```
conda activate MARBLE
pip install .
```

Expand Down
109 changes: 109 additions & 0 deletions environment_osx_arm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
name: MARBLE
channels:
- pytorch
- defaults
dependencies:
- blas=1.0=openblas
- bottleneck=1.3.5=py39heec5a64_0
- brotli=1.0.9=h1a28f6b_7
- brotli-bin=1.0.9=h1a28f6b_7
- brotlipy=0.7.0=py39h1a28f6b_1002
- bzip2=1.0.8=h620ffc9_4
- ca-certificates=2023.08.22=hca03da5_0
- certifi=2023.7.22=py39hca03da5_0
- cffi=1.15.1=py39h80987f9_3
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- contourpy=1.0.5=py39h525c30c_0
- cryptography=41.0.3=py39hd4332d6_0
- cycler=0.11.0=pyhd3eb1b0_0
- ffmpeg=4.2.2=h04105a8_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h1192e45_0
- gettext=0.21.0=h13f89a0_1
- giflib=5.2.1=h80987f9_3
- gmp=6.2.1=hc377ac9_3
- gnutls=3.6.15=h887c41c_0
- icu=73.1=h313beb8_0
- idna=3.4=py39hca03da5_0
- importlib_resources=5.2.0=pyhd3eb1b0_1
- jpeg=9e=h80987f9_1
- kiwisolver=1.4.4=py39h313beb8_0
- lame=3.100=h1a28f6b_0
- lcms2=2.12=hba8e193_0
- lerc=3.0=hc377ac9_0
- libbrotlicommon=1.0.9=h1a28f6b_7
- libbrotlidec=1.0.9=h1a28f6b_7
- libbrotlienc=1.0.9=h1a28f6b_7
- libcxx=14.0.6=h848a8c0_0
- libdeflate=1.17=h80987f9_1
- libffi=3.4.4=hca03da5_0
- libgfortran=5.0.0=11_3_0_hca03da5_28
- libgfortran5=11.3.0=h009349e_28
- libiconv=1.16=h1a28f6b_2
- libidn2=2.3.4=h80987f9_0
- libopenblas=0.3.21=h269037a_0
- libopus=1.3=h1a28f6b_1
- libpng=1.6.39=h80987f9_0
- libtasn1=4.19.0=h80987f9_0
- libtiff=4.5.1=h313beb8_0
- libunistring=0.9.10=h1a28f6b_0
- libvpx=1.10.0=hc377ac9_0
- libwebp=1.3.2=ha3663a8_0
- libwebp-base=1.3.2=h80987f9_0
- libxml2=2.10.4=h0dcf63f_1
- llvm-openmp=14.0.6=hc6e5704_0
- lz4-c=1.9.4=h313beb8_0
- matplotlib=3.7.2=py39hca03da5_0
- matplotlib-base=3.7.2=py39h46d7db6_0
- munkres=1.1.4=py_0
- ncurses=6.4=h313beb8_0
- nettle=3.7.3=h84b5d62_1
- networkx=3.1=py39hca03da5_0
- numexpr=2.8.7=py39hecc3335_0
- numpy=1.26.0=py39h3b2db8e_0
- numpy-base=1.26.0=py39ha9811e2_0
- openh264=1.8.0=h98b2900_0
- openjpeg=2.3.0=h7a6adac_2
- openssl=3.0.11=h1a28f6b_2
- packaging=23.1=py39hca03da5_0
- pandas=2.1.1=py39h46d7db6_0
- pillow=10.0.1=py39h3b245a6_0
- pip=23.3=py39hca03da5_0
- pybind11=2.10.4=py39h48ca7d4_0
- pybind11-global=2.10.4=py39h48ca7d4_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=23.2.0=py39hca03da5_0
- pyparsing=3.0.9=py39hca03da5_0
- pysocks=1.7.1=py39hca03da5_0
- python=3.9.18=hb885b13_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-tzdata=2023.3=pyhd3eb1b0_0
- pytorch=1.12.1=py3.9_0
- pytz=2023.3.post1=py39hca03da5_0
- pyyaml=6.0=py39h80987f9_1
- readline=8.2=h1a28f6b_0
- requests=2.31.0=py39hca03da5_0
- scipy=1.11.3=py39h20cbe94_0
- seaborn=0.12.2=py39hca03da5_0
- setuptools=68.0.0=py39hca03da5_0
- scikit-learn=1.3.0=py39h46d7db6_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.41.2=h80987f9_0
- tk=8.6.12=hb8d0fd4_0
- torchaudio=0.12.1=py39_cpu
- torchvision=0.13.1=py39_cpu
- tornado=6.3.3=py39h80987f9_0
- tqdm=4.65.0=py39h86d0a89_0
- typing_extensions=4.7.1=py39hca03da5_0
- tzdata=2023c=h04d1e81_0
- urllib3=1.26.16=py39hca03da5_0
- wheel=0.41.2=py39hca03da5_0
- x264=1!152.20180806=h1a28f6b_0
- xz=5.4.2=h80987f9_0
- yaml=0.2.5=h1a28f6b_0
- zipp=3.11.0=py39hca03da5_0
- zlib=1.2.13=h5a0b063_0
- zstd=1.5.5=hd90d995_0
- pip:
- ninja==1.11.1.1
- teaspoon==1.3.1
File renamed without changes.
545 changes: 322 additions & 223 deletions examples/RNN/RNN.ipynb

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions examples/RNN/RNN_scripts/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@ def load_network(f):
return z, net


def sample_network(net, f):
def sample_network(net, f, seed=0):

if os.path.exists(f):
print('Network found with same name. Loading...')
return torch.load(open(f, "rb"))

n_pops = 2
seed = 0
z, _ = clustering.gmm_fit(net, n_pops, algo="bayes", random_state=seed)
net_sampled = clustering.to_support_net(net, z)

Expand All @@ -68,7 +67,7 @@ def sample_network(net, f):


def generate_trajectories(
net, input=None, epochs=None, n_traj=None, fname="./outputs/RNN_trajectories.pkl"
net, input=None, epochs=None, n_traj=None, fname="./data/RNN_trajectories.pkl"
):

if fname is not None:
Expand Down Expand Up @@ -455,7 +454,7 @@ def plot_experiment(net, input, traj, epochs, rect=(-8, 8, -6, 6), traj_to_show=
fig.subplots_adjust(hspace=0.1, wspace=0.1)


def aggregate_data(traj, epochs, transient=10, only_stim=False):
def aggregate_data(traj, epochs, transient=10, only_stim=False, pca=True):

n_conds = len(traj)
n_epochs = len(epochs) - 1
Expand All @@ -468,9 +467,10 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for j in range(n_traj): # trajectories
pos.append(traj[i][j][k][transient:])

pca = PCA(n_components=3)
pca.fit(np.vstack(pos))
print("Explained variance: ", pca.explained_variance_ratio_)
if pca:
pca = PCA(n_components=3)
pca.fit(np.vstack(pos))
print("Explained variance: ", pca.explained_variance_ratio_)

# aggregate data under baseline condition (no input)
pos, vel = [], []
Expand All @@ -480,7 +480,8 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for k in [0, 2, 4]:
for j in range(n_traj): # trajectories
pos_proj = traj[i][j][k][transient:]
pos_proj = pca.transform(pos_proj)
if pca:
pos_proj = pca.transform(pos_proj)
pos_.append(pos_proj[:-1]) # stack trajectories
vel_.append(np.diff(pos_proj, axis=0)) # compute differences

Expand All @@ -494,7 +495,8 @@ def aggregate_data(traj, epochs, transient=10, only_stim=False):
for k in [1, 3]:
for j in range(n_traj): # trajectories
pos_proj = traj[i][j][k][transient:]
pos_proj = pca.transform(pos_proj)
if pca:
pos_proj = pca.transform(pos_proj)
pos_.append(pos_proj[:-1])
vel_.append(np.diff(pos_proj, axis=0))

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file not shown.
Loading

0 comments on commit cf72d2a

Please sign in to comment.