simple examples
agosztolai committed Nov 27, 2023
1 parent 5717092 commit 3bef285
Showing 39 changed files with 106 additions and 134 deletions.
8 changes: 4 additions & 4 deletions MARBLE/default_params.yaml
@@ -1,12 +1,12 @@
 #training parameters
-epochs : 20, # optimisation epochs
+epochs : 100 # optimisation epochs
 batch_size : 64 # batch size
 lr: 0.01 # learning rate
 momentum: 0.9

 #manifold/signal parameters
 order: 2 # order to which to compute the directional derivatives
-inner_product_features: True
+inner_product_features: False
 diffusion: False
 frac_sampled_nb: -1 # fraction of neighbours to sample for gradient computation (if -1 then all neighbours)
 include_positions: False # include positions as features (warning: this is untested!)
@@ -18,9 +18,9 @@ hidden_channels: [16] # number of hidden channels
 out_channels: 3 # number of output channels (if null, then =hidden_channels)
 bias: True # learn bias parameters in MLP
 vec_norm: False
-batch_norm: False # batch normalisation
+batch_norm: True # batch normalisation
 emb_norm: False # spherical output
-skip_connections: True # use skips in MLP
+skip_connections: False # use skips in MLP

 # other params
 seed: 0 # seed for reproducibility
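These defaults are consumed by MARBLE.net, which (per the main.py change below) now accepts either a dict or a path to a YAML file. A minimal sketch of overriding a default, assuming PyYAML and that the file ships at the path shown:

import yaml

# load the shipped defaults and override one field before passing to MARBLE.net
with open("MARBLE/default_params.yaml", "rb") as f:
    params = yaml.safe_load(f)
params["epochs"] = 50  # hypothetical override; any key above works the same way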
23 changes: 16 additions & 7 deletions MARBLE/geometry.py
@@ -18,7 +18,6 @@
 from MARBLE.lib.cknn import cknneighbors_graph  # isort:skip
 from MARBLE import utils  # isort:skip
 
-
 def furthest_point_sampling(x, N=None, stop_crit=0.0, start_idx=0):
     """A greedy O(N^2) algorithm to do furthest points sampling
@@ -457,7 +456,9 @@ def compute_laplacian(data, normalization="rw"):
         num_nodes=data.num_nodes
     )
 
-    return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze()
+    # return PyGu.to_dense_adj(edge_index, edge_attr=edge_attr).squeeze()
+    n = len(data.x)
+    return sp.coo_array((edge_attr, (edge_index[0], edge_index[1])), shape=(n, n))
 
 
 def compute_connection_laplacian(data, R, normalization="rw"):
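For reference, a toy sketch of the sparse format now returned by compute_laplacian (assuming scipy >= 1.8, where coo_array was introduced; the 3-node edge list stands in for MARBLE's edge_index/edge_attr):

import numpy as np
import scipy.sparse as sp

edge_index = np.array([[0, 1, 2], [1, 2, 0]])  # COO row and column indices
edge_attr = np.array([0.5, 0.3, 0.2])          # edge weights
n = 3
A = sp.coo_array((edge_attr, (edge_index[0], edge_index[1])), shape=(n, n))
print(A.toarray())  # dense view, for inspection only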
@@ -670,11 +671,12 @@ def vector_diffusion(x, t, method="spectral", Lc=None):
         return out
 
 
-def compute_eigendecomposition(A, eps=1e-8):
+def compute_eigendecomposition(A, k=50, eps=1e-8):
     """Eigendecomposition of a square matrix A.
 
     Args:
         A: square matrix A
+        k: number of eigenvectors
         eps: small error term
 
     Returns:
@@ -683,15 +685,22 @@ def compute_eigendecomposition(A, eps=1e-8):
"""
if A is None:
return None

A = A.to_dense()

if k >= A.shape[0]:
k = None

# Compute the eigenbasis
failcount = 0
while True:
try:
evals, evecs = torch.linalg.eigh(A)
if k is None:
evals, evecs = torch.linalg.eigh(A)
else:
evals, evecs = sp.linalg.eigsh(A, k=k, which="SM")
evals, evecs = torch.tensor(evals), torch.tensor(evecs)

evals = torch.clamp(evals, min=0.0)
evecs *= np.sqrt(len(evecs))

break
except Exception as e: # pylint: disable=broad-exception-caught
Expand All @@ -702,4 +711,4 @@ def compute_eigendecomposition(A, eps=1e-8):
print("--- decomp failed; adding eps ===> count: " + str(failcount))
A += torch.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))

return evals, evecs
return evals, evecs
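The k-truncated branch above relies on scipy's Lanczos solver. A self-contained sketch of the same call pattern (assuming scipy and torch; the random symmetric matrix is a stand-in for the Laplacian):

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg  # exposes sp.linalg.eigsh
import torch

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 200))
A = (A + A.T) / 2  # symmetrise so eigsh applies
evals, evecs = sp.linalg.eigsh(A, k=50, which="SM")  # 50 smallest-magnitude pairs
evals, evecs = torch.tensor(evals), torch.tensor(evecs)
print(evals.shape, evecs.shape)  # torch.Size([50]), torch.Size([200, 50])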
27 changes: 17 additions & 10 deletions MARBLE/main.py
@@ -18,6 +18,7 @@
 from MARBLE import layers
 from MARBLE import utils
 
+import warnings
 
class net(nn.Module):
"""MARBLE neural network.
@@ -60,16 +61,17 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
"""
super().__init__()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if loadpath is not None:
if Path(loadpath).is_dir():
loadpath = max(glob.glob(f"{loadpath}/best_model*"))
self.params = torch.load(loadpath, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))["params"]
self.params = torch.load(loadpath, map_location=device)["params"]
else:
if params is not None:
if isinstance(params, str) and Path(params).exists():
with open(params, "rb") as f:
params = yaml.safe_load(f)
self.params = params
else:
self.params = {}

self._epoch = 0 # to resume optimisation
self.parse_parameters(data)
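The refactor above just hoists a repeated expression into device. The same pattern in isolation (assuming torch; 'checkpoint.pth' is a placeholder path, not a file in this repo):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('checkpoint.pth', map_location=device)  # to GPU if available, else CPU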
@@ -217,11 +219,7 @@ def setup_layers(self):
             bias=self.params["bias"],
         )
 
-
-
-
-
 
     def forward(self, data, n_id, adjs=None):
         """Forward pass.
         Messages are passed to a set of target nodes (current batch) from source
@@ -289,8 +287,12 @@
             emb = F.normalize(emb)
 
         return emb, mask[: size[1]]
+
+    def evaluate(self, data):
+        warnings.warn("MARBLE.evaluate() is deprecated. Use MARBLE.transform() instead.")
+        self.transform(data)
 
-    def evaluate(self, data):
+    def transform(self, data):
         """Forward pass @ evaluation (no minibatches)"""
         with torch.no_grad():
             size = (data.x.shape[0], data.x.shape[0])
@@ -345,8 +347,13 @@ def batch_loss(self, data, loader, train=False, verbose=False, optimizer=None):
         self.eval()
 
         return cum_loss / len(loader), optimizer
+
+    def run_training(self, data, outdir=None, verbose=False):
+        warnings.warn("MARBLE.run_training() is deprecated. Use MARBLE.fit() instead.")
+
+        self.fit(data, outdir=outdir, verbose=verbose)
 
-    def run_training(self, data, outdir=None, verbose=False):
+    def fit(self, data, outdir=None, verbose=False):
         """Network training.
 
         Args:
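Taken together, the two wrappers rename the training and inference entry points to sklearn-style fit/transform while keeping the old names as deprecation shims. A hypothetical call sequence (assuming net is imported from MARBLE.main as defined above, and that transform() returns the data object with embeddings attached):

from MARBLE.main import net

model = net(data)             # 'data' from MARBLE.preprocessing.construct_dataset
model.fit(data)               # was: model.run_training(data)
data = model.transform(data)  # was: model.evaluate(data)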
10 changes: 4 additions & 6 deletions MARBLE/plotting.py
@@ -131,8 +131,7 @@ def histograms(data, titles=None, col=2, figsize=(10, 10)):
"""
assert hasattr(
data, "clusters"
), "No clusters found. First, run \
geometry.cluster(data) or postprocessing(data)!"
), "No clusters found. First, run postprocessing.cluster(data)!"

labels, s = data.clusters["labels"], data.clusters["slices"]
n_slices = len(s) - 1
@@ -245,9 +244,9 @@ def embedding(
             ax.scatter(emb_[t, 0], emb_[t, 1], emb_[t, 2], c=cgrad, alpha=alpha, s=s, label=title)
     else:
         if dim == 2:
-            ax.scatter(emb_[:, 0], emb_[:, 1], c=c_, alpha=alpha, s=s, label=title)
+            ax.scatter(emb_[:, 0], emb_[:, 1], color=c_, alpha=alpha, s=s, label=title)
         elif dim == 3:
-            ax.scatter(emb_[:, 0], emb_[:, 1], emb_[:, 2], c=c, alpha=alpha, s=s, label=title)
+            ax.scatter(emb_[:, 0], emb_[:, 1], emb_[:, 2], color=c_, alpha=alpha, s=s, label=title)
 
         if dim == 2:
             if hasattr(data, "clusters") and clusters_visible:
@@ -296,8 +295,7 @@ def neighbourhoods(

     assert hasattr(
         data, "clusters"
-    ), "No clusters found. First, run \
-        geometry.cluster(data) or postprocessing(data)!"
+    ), "No clusters found. First, run postprocessing.cluster(data)!"
 
     vector = data.x.shape[1] > 1
     clusters = data.clusters
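The switch from c= to color= in the scatter calls above sidesteps a matplotlib ambiguity: with c=, a length-3 or length-4 sequence can be parsed as per-point values to be colormapped rather than as a single RGB/RGBA colour (the diff also fixes the 3-D branch to use the preprocessed c_ instead of the raw c). A minimal illustration, assuming only matplotlib:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
rgb = (0.2, 0.4, 0.6)
ax.scatter([0, 1, 2], [0, 1, 4], color=rgb)  # unambiguous: one RGB colour for all points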
27 changes: 17 additions & 10 deletions MARBLE/postprocessing.py
@@ -4,6 +4,21 @@
 from MARBLE import geometry as g
 
 
+def cluster(data, cluster_typ="kmeans", n_clusters=15, seed=0):
+
+    clusters = g.cluster(data.emb, cluster_typ, n_clusters, seed)
+    clusters = g.relabel_by_proximity(clusters)
+
+    clusters["slices"] = data._slice_dict["x"]  # pylint: disable=protected-access
+
+    if data.number_of_resamples > 1:
+        clusters["slices"] = clusters["slices"][:: data.number_of_resamples]
+
+    data.clusters = clusters
+
+    return data
+
+
 def distribution_distances(data, cluster_typ="kmeans", n_clusters=None, seed=0):
     """Return distance between datasets.
@@ -18,21 +33,13 @@ def distribution_distances(data, cluster_typ="kmeans", n_clusters=None, seed=0):

     if n_clusters is not None:
         # k-means cluster
-        clusters = g.cluster(emb, cluster_typ, n_clusters, seed)
-        clusters = g.relabel_by_proximity(clusters)
-
-        clusters["slices"] = data._slice_dict["x"]  # pylint: disable=protected-access
-
-        if data.number_of_resamples > 1:
-            clusters["slices"] = clusters["slices"][:: data.number_of_resamples]
+        data = cluster(data, cluster_typ, n_clusters, seed)
 
         # compute distances between clusters
         data.dist, data.gamma = g.compute_distribution_distances(
-            clusters=clusters, slices=clusters["slices"]
+            clusters=data.clusters, slices=data.clusters["slices"]
         )
 
-        data.clusters = clusters
-
     else:
         data.emb = emb
         data.dist, _ = g.compute_distribution_distances(
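With clustering factored out into its own helper, it can now be run on its own before (or instead of) the distance computation. A hypothetical call, assuming data already carries embeddings in data.emb (e.g. after model.transform()):

from MARBLE import postprocessing

data = postprocessing.cluster(data, cluster_typ="kmeans", n_clusters=15, seed=0)
print(data.clusters["labels"][:10])  # cluster assignment of the first 10 points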
67 changes: 18 additions & 49 deletions MARBLE/preprocessing.py
@@ -1,5 +1,4 @@
"""Prepare module."""
import numpy as np
"""Preprocessing module."""
import torch
from torch_geometric.data import Batch
from torch_geometric.data import Data
@@ -14,16 +13,12 @@ def construct_dataset(
     labels=None,
     mask=None,
     graph_type="cknn",
-    k=15,
-    n_geodesic_nb=10,
+    k=20,
+    frac_geodesic_nb=1.5,
     stop_crit=0.0,
     number_of_resamples=1,
-    compute_laplacian=False,
-    compute_connection_laplacian=False,
+    return_spectrum=True,
     var_explained=0.9,
     local_gauges=False,
-    dim_man=None,
+    delta=1.0,
 ):
     """Construct PyG dataset from node positions and features.
@@ -34,16 +29,13 @@ def construct_dataset(
         labels: any additional data labels used for plotting only
         graph_type: type of nearest-neighbours graph: cknn (default), knn or radius
         k: number of nearest-neighbours to construct the graph
-        n_geodesic_nb: number of geodesic neighbours to fit the gauges to
-            to map to tangent space
+        frac_geodesic_nb: fraction setting the number of geodesic neighbours
+            (k*frac_geodesic_nb) used to fit the gauges that map to the tangent space
         stop_crit: stopping criterion for furthest point sampling
         number_of_resamples: number of furthest point sampling runs to prevent bias (experimental)
-        compute_laplacian: set to True to compute laplacian
-        compute_connection_laplacian: set to True to compute the connection laplacian
         var_explained: fraction of variance explained by the local gauges
         local_gauges: if True, it will try to compute local gauges if it can (signal dim is > 2,
            embedding dimension is > 2 or dim embedding is not dim of manifold)
-        dim_man: if the manifold dimension is known, it can be set here or it will be estimated
+        delta: argument for cknn graph construction to decide the radius for each point.
     """

@@ -103,24 +95,15 @@ def construct_dataset(
     return _compute_geometric_objects(
         batch,
         local_gauges=local_gauges,
-        compute_laplacian=compute_laplacian,
-        compute_connection_laplacian=compute_connection_laplacian,
-        n_geodesic_nb=n_geodesic_nb,
+        frac_geodesic_nb=frac_geodesic_nb,
         var_explained=var_explained,
-        dim_man=dim_man,
+        return_spectrum=return_spectrum
     )


-def _compute_geometric_objects(
-    data,
-    n_geodesic_nb=2.0,
+def _compute_geometric_objects(data,
+    frac_geodesic_nb=2.0,
     var_explained=0.9,
+    return_spectrum=True,
     local_gauges=False,
-    compute_laplacian=False,
-    compute_connection_laplacian=False,
-    dim_man=None,
 ):
"""
Compute geometric objects used later: local gauges, Levi-Civita connections
@@ -157,7 +140,7 @@ def _compute_geometric_objects(data,

     if local_gauges:
         try:
-            gauges, Sigma = g.compute_gauges(data, n_geodesic_nb=n_geodesic_nb)
+            gauges, Sigma = g.compute_gauges(data, n_geodesic_nb=frac_geodesic_nb)
         except Exception as exc:
             raise Exception(
                 "\nCould not compute gauges (possibly data is too sparse or the \
@@ -166,44 +149,30 @@ def _compute_geometric_objects(data,
     else:
         gauges = torch.eye(dim_emb).repeat(n, 1, 1)
 
-    # Laplacian
-    if compute_laplacian:
-        L = g.compute_laplacian(data)
-    else:
-        L = None
+    L = g.compute_laplacian(data)
 
     if local_gauges:
-        if not dim_man:
-            dim_man = g.manifold_dimension(Sigma, frac_explained=var_explained)
-        data.dim_man = dim_man
-
-        print(f"\n---- Manifold dimension: {dim_man}")
+        data.dim_man = g.manifold_dimension(Sigma, frac_explained=var_explained)
+        print(f"\n---- Manifold dimension: {data.dim_man}")
 
-        gauges = gauges[:, :, :dim_man]
+        gauges = gauges[:, :, :data.dim_man]
         R = g.compute_connections(data, gauges)
 
         print("\n---- Computing kernels ... ", end="")
         kernels = g.gradient_op(data.pos, data.edge_index, gauges)
-        kernels = [utils.tile_tensor(K, dim_man) for K in kernels]
+        kernels = [utils.tile_tensor(K, data.dim_man) for K in kernels]
         kernels = [K * R for K in kernels]
         print("Done ")
 
-        if compute_connection_laplacian:
-            Lc = g.compute_connection_laplacian(data, R)
-        else:
-            Lc = None
+        Lc = g.compute_connection_laplacian(data, R)
 
     else:
         print("\n---- Computing kernels ... ", end="")
         kernels = g.gradient_op(data.pos, data.edge_index, gauges)
         print("Done ")
         Lc = None
 
-    if return_spectrum:
-        print("---- Computing eigendecomposition ... ", end="")
-        L = g.compute_eigendecomposition(L)
-        Lc = g.compute_eigendecomposition(Lc)
-        print("Done ")
+    print("---- Computing eigendecomposition ... ", end="")
+    L = g.compute_eigendecomposition(L)
+    Lc = g.compute_eigendecomposition(Lc)
 
     data.kernels = [
         utils.to_SparseTensor(K.coalesce().indices(), value=K.coalesce().values()) for K in kernels
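A hypothetical end-to-end call with the new keyword names (k, frac_geodesic_nb, return_spectrum, delta), assuming the leading positional arguments are node positions and features, as the docstring states; the random arrays are dummy data:

import numpy as np
from MARBLE import preprocessing

pos = [np.random.rand(100, 3)]  # one dataset of 100 points
x = [np.random.rand(100, 3)]    # one vector signal per point
data = preprocessing.construct_dataset(
    pos, x, graph_type="cknn", k=20, frac_geodesic_nb=1.5, delta=1.0
)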
4 files renamed without changes.