added embedding norm and rat example
peach-lucien committed Nov 24, 2023
1 parent fc737e7 commit 6d8d587
Showing 6 changed files with 1,812 additions and 4 deletions.
1 change: 1 addition & 0 deletions MARBLE/default_params.yaml
@@ -19,6 +19,7 @@ out_channels: 3 # number of output channels (if null, then =hidden_channels)
 bias: True # learn bias parameters in MLP
 vec_norm: False
 batch_norm: False # batch normalisation
+emb_norm: False # spherical output
 
 # other params
 seed: 0 # seed for reproducibility
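The new emb_norm flag L2-normalises the MLP output so embeddings lie on the unit sphere (hence "spherical output"). A minimal illustrative sketch of the effect, not MARBLE's internals (those are in net.forward below):

import torch
import torch.nn.functional as F

emb = torch.randn(4, 3)                 # a batch of 3-d embeddings
emb = F.normalize(emb)                  # project each row onto the unit sphere
print(torch.linalg.norm(emb, dim=1))    # tensor([1., 1., 1., 1.])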
9 changes: 5 additions & 4 deletions MARBLE/main.py
@@ -42,6 +42,7 @@ class net(nn.Module):
     out_channels: number of output channels (if null, then =hidden_channels) (default=3)
     bias: learn bias parameters in MLP (default=True)
     vec_norm: normalise features to unit length (default=False)
+    emb_norm: normalise MLP output to unit length (default=False)
     batch_norm: batch normalisation (default=False)
     seed: seed for reproducibility (default=0)
     processes: number of cpus (default=1)
@@ -62,7 +63,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
         if loadpath is not None:
             if Path(loadpath).is_dir():
                 loadpath = max(glob.glob(f"{loadpath}/best_model*"))
-            self.params = torch.load(loadpath)["params"]
+            self.params = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))["params"]
         else:
             if params is not None:
                 if isinstance(params, str) and Path(params).exists():
@@ -282,8 +283,8 @@ def forward(self, data, n_id, adjs=None):
 
         emb = self.enc(out)
 
-        #if self.params['emb_norm']:
-        emb = F.normalize(emb)
+        if self.params['emb_norm']:  # spherical output
+            emb = F.normalize(emb)
 
         return emb, mask[: size[1]]

@@ -412,7 +413,7 @@ def load_model(self, loadpath):
         Args:
             loadpath: directory with models to load best model, or specific model path
         """
-        checkpoint = torch.load(loadpath)
+        checkpoint = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
         self._epoch = checkpoint["epoch"]
         self.load_state_dict(checkpoint["model_state_dict"])
         self.optimizer_state_dict = checkpoint["optimizer_state_dict"]
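Both torch.load calls now pass map_location, mapping checkpoint storages onto the GPU when one is available and onto the CPU otherwise; without it, a checkpoint saved on a CUDA machine raises an error when loaded on a CPU-only one. The same pattern in isolation (the checkpoint path is hypothetical):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('outputs/best_model.pth', map_location=device)  # hypothetical path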
745 changes: 745 additions & 0 deletions examples/rat_task/Demo_consistency.ipynb


966 changes: 966 additions & 0 deletions examples/rat_task/Demo_decoding.ipynb


Binary file added examples/rat_task/rat_data.pkl
95 changes: 95 additions & 0 deletions examples/rat_task/rat_utils.py
@@ -0,0 +1,95 @@
import sys
import numpy as np
import matplotlib.pyplot as plt

from elephant.kernels import GaussianKernel
from elephant.statistics import instantaneous_rate
from quantities import ms
import neo

from sklearn.decomposition import PCA
import sklearn

import MARBLE
import cebra

def prepare_marble(spikes, labels, pca=None, pca_n=10, skip=1):

    s_interval = 1

    # convolve each spike train with a Gaussian kernel to get smooth firing rates
    gk = GaussianKernel(10 * ms)
    rates = []
    for sp in spikes:
        sp_times = np.where(sp)[0]
        st = neo.SpikeTrain(sp_times, units="ms", t_stop=len(sp))
        r = instantaneous_rate(st, kernel=gk, sampling_period=s_interval * ms).magnitude
        rates.append(r.T)

    rates = np.vstack(rates)

    # fit PCA on training data; pass a fitted PCA to project held-out data
    if pca is None:
        pca = PCA(n_components=pca_n)
        rates_pca = pca.fit_transform(rates.T)
    else:
        rates_pca = pca.transform(rates.T)

    vel_rates_pca = np.diff(rates_pca, axis=0)  # velocities in PCA space
    print(pca.explained_variance_ratio_)

    rates_pca = rates_pca[:-1, :]  # skip last sample so positions align with velocities

    labels = labels[:rates_pca.shape[0]]

    data = MARBLE.construct_dataset(
        rates_pca,
        features=vel_rates_pca,
        k=15,
        stop_crit=0.0,
        delta=1.5,
        compute_laplacian=True,
        local_gauges=False,
    )

    return data, labels, pca
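A hypothetical usage sketch, assuming spikes_train/spikes_test are binary spike arrays (one row per neuron, one column per ms bin) and the labels are per-timestep arrays; the PCA fitted on the training split is reused so both splits share one embedding basis:

data_train, labels_tr, pca = prepare_marble(spikes_train, labels_train)
data_test, labels_te, _ = prepare_marble(spikes_test, labels_test, pca=pca)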


def find_sequences(vector):
    # return inclusive (start, end) index pairs for runs of constant value
    sequences = []
    start_index = 0

    for i in range(1, len(vector)):
        if vector[i] != vector[i - 1]:
            sequences.append((start_index, i - 1))
            start_index = i

    # add the last sequence
    sequences.append((start_index, len(vector) - 1))

    return sequences
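For example:

find_sequences([1, 1, 0, 0, 0, 1])  # -> [(0, 1), (2, 4), (5, 5)]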

# Decoding function with a kNN decoder. For this simple demo we use a fixed
# number of neighbours (36).
def decoding_pos_dir(embedding_train, embedding_test, label_train, label_test):
    pos_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")
    dir_decoder = cebra.KNNDecoder(n_neighbors=36, metric="cosine")

    pos_decoder.fit(embedding_train, label_train[:, 0])
    dir_decoder.fit(embedding_train, label_train[:, 1])

    pos_pred = pos_decoder.predict(embedding_test)
    dir_pred = dir_decoder.predict(embedding_test)

    prediction = np.stack([pos_pred, dir_pred], axis=1)

    test_score = sklearn.metrics.r2_score(label_test[:, :2], prediction)
    pos_test_err = np.median(abs(prediction[:, 0] - label_test[:, 0]))
    pos_test_score = sklearn.metrics.r2_score(label_test[:, 0], prediction[:, 0])

    prediction_error = abs(prediction[:, 0] - label_test[:, 0])

    # median prediction error per run of constant direction (each back-and-forth leg)
    sequences = find_sequences(label_test[:, 1])
    errors = []
    for seq in sequences:
        run = slice(seq[0], seq[1] + 1)  # runs are inclusive index pairs
        errors.append(np.median(abs(prediction[run, 0] - label_test[run, 0])))

    return test_score, pos_test_err, pos_test_score, prediction, prediction_error, np.array(errors)
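A hypothetical end-to-end call, assuming train/test embeddings from a trained model and label arrays whose first column is position and second is movement direction:

scores = decoding_pos_dir(emb_train, emb_test, labels_tr, labels_te)
test_score, pos_test_err, pos_test_score, prediction, prediction_error, seq_errors = scores
print(f"position R2: {pos_test_score:.2f}, median error: {pos_test_err:.2f}")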
