From 6183031dd1ff517d684f6f3b754c4a8ebbeaa8de Mon Sep 17 00:00:00 2001 From: agosztolai Date: Mon, 17 Jun 2024 11:05:07 +0200 Subject: [PATCH] fix corrupted notebook --- MARBLE/layers.py | 2 +- examples/RNN/RNN.ipynb | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/MARBLE/layers.py b/MARBLE/layers.py index d635ed75..e3ab7214 100644 --- a/MARBLE/layers.py +++ b/MARBLE/layers.py @@ -87,7 +87,7 @@ def __init__(self, C, D): self.O_mat = nn.ModuleList() for _ in range(C): - self.O_mat.append(nn.Linear(D, D, bias=True)) + self.O_mat.append(nn.Linear(D, D, bias=False)) self.reset_parameters() diff --git a/examples/RNN/RNN.ipynb b/examples/RNN/RNN.ipynb index 32834a9b..4b1cbc8b 100644 --- a/examples/RNN/RNN.ipynb +++ b/examples/RNN/RNN.ipynb @@ -1076,10 +1076,8 @@ " v = pos[j]\n", "\n", " n_samples = np.min([len(u), len(v)])\n", - " ind = np.random.choice(len(u), size=(n_samples,), replace=False)\n", - " u = u[np.sort(ind)]\n", - " ind = np.random.choice(len(v), size=(n_samples,), replace=False)\n", - " v = v[np.sort(ind)]\n", + " u = u[:n_samples]\n", + " v = v[:n_samples]\n", " \n", " u_score, v_score = cca.fit_transform(u, v)\n", "\n", @@ -1196,7 +1194,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" },