Skip to content

Commit

Permalink
cleaned up graph loading
Browse files Browse the repository at this point in the history
  • Loading branch information
falexwolf committed Jul 27, 2017
1 parent d35d714 commit f0cfad6
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 64 deletions.
11 changes: 3 additions & 8 deletions scanpy/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
"""

# order alphabetically
from ..tools.aga import aga
from ..tools.aga import aga_contract_graph
from ..tools.dbscan import dbscan
from ..tools.draw_graph import draw_graph
from ..tools.diffmap import diffmap
from ..tools.rank_genes_groups import rank_genes_groups
from ..tools.dpt import dpt
Expand All @@ -13,11 +16,3 @@
from ..tools.sim import sim
from ..tools.spring import spring
from ..tools.tsne import tsne

try:
# development tools
from ..tools.draw_graph import draw_graph
from ..tools.aga import aga
from ..tools.aga import aga_contract_graph
except ImportError:
pass
42 changes: 34 additions & 8 deletions scanpy/data_structs/data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,27 @@
from .. import utils


def add_graph_to_adata(
adata,
n_neighbors=30,
n_pcs=50,
recompute_pca=None,
recompute_graph=False,
n_jobs=None):
graph = DataGraph(adata,
k=n_neighbors,
n_pcs=n_pcs,
recompute_pca=recompute_pca,
recompute_graph=recompute_graph,
n_jobs=n_jobs)
graph.update_diffmap()
adata.add['distance'] = graph.Dsq
adata.add['Ktilde'] = graph.Ktilde
adata.smp['X_diffmap'] = graph.rbasis[:, 1:]
adata.smp['X_diffmap0'] = graph.rbasis[:, 0]
adata.add['diffmap_evals'] = graph.evals[1:]


def get_neighbors(X, Y, k):
Dsq = utils.comp_sqeuclidean_distance_using_matrix_mult(X, Y)
chunk_range = np.arange(Dsq.shape[0])[:, None]
Expand Down Expand Up @@ -159,7 +180,7 @@ def __init__(self,
and adata.smp['X_diffmap'].shape[1] >= n_dcs-1):
self.n_pcs = n_pcs
self.n_dcs = n_dcs
self.iroot = None if 'iroot' not in adata.add else adata.add['iroot']
self.init_iroot_directly(adata)
self.X = adata.X # this is a hack, PCA?
self.knn = issparse(adata.add['Ktilde'])
self.Ktilde = adata.add['Ktilde']
Expand All @@ -177,7 +198,7 @@ def __init__(self,
self.Dchosen = OnFlySymMatrix(self.get_Ddiff_row,
shape=(self.X.shape[0], self.X.shape[0]))
np.set_printoptions(precision=3)
logg.info('use stored data graph with `n_neighbors = {}` and '
logg.info(' using stored data graph with n_neighbors = {} and '
'spectrum\n {}'
.format(self.k,
str(self.evals).replace('\n', '\n ')))
Expand Down Expand Up @@ -211,19 +232,24 @@ def __init__(self,
.format(self.k))
self.Dsq = adata.add['distance']

def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs):
# retrieve xroot
xroot = None
if 'xroot' in adata.add: xroot = adata.add['xroot']
elif 'xroot' in adata.var: xroot = adata.var['xroot']
# set iroot directly
def init_iroot_directly(self, adata):
if 'iroot' in adata.add:
if adata.add['iroot'] >= adata.n_smps:
logg.warn('Root cell index {} does not exist for {} samples. '
'Is ignored.'
.format(adata.add['iroot'], adata.n_smps))
self.iroot = None
else:
self.iroot = adata.add['iroot']


def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs):
# retrieve xroot
xroot = None
if 'xroot' in adata.add: xroot = adata.add['xroot']
elif 'xroot' in adata.var: xroot = adata.var['xroot']
# set iroot directly
self.init_iroot_directly(adata)
# see whether we can set self.iroot using the full data matrix
if xroot is not None and xroot.size == self.X.shape[1]:
self.set_root(xroot)
Expand Down
12 changes: 7 additions & 5 deletions scanpy/tools/aga.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ def aga(adata,
root_cell_was_passed = False
logg.m('... no root cell found, no computation of pseudotime')
msg = \
'''To enable computation of pseudotime, pass the expression "xroot" of a root cell.
Either add
'''To enable computation of pseudotime, pass the index or expression vector
of a root cell. Either add
adata.add['iroot'] = root_cell_index
or (robust to subsampling)
adata.var['xroot'] = adata.X[root_cell_index, :]
where `root_cell_index` is the integer index of the root cell, or
where "root_cell_index" is the integer index of the root cell, or
adata.var['xroot'] = adata[root_cell_name, :].X
where `root_cell_name` is the name (a string) of the root cell.'''
where "root_cell_name" is the name (a string) of the root cell.'''
logg.hint(msg)
fresh_compute_louvain = False
if ((node_groups == 'louvain' and 'louvain_groups' not in adata.smp_keys())
Expand All @@ -127,6 +129,7 @@ def aga(adata,
fresh_compute_louvain = True
clusters = node_groups
if node_groups == 'louvain': clusters = 'louvain_groups'
logg.info('running Approximate Graph Abstraction (AGA)', r=True)
aga = AGA(adata,
clusters=clusters,
n_neighbors=n_neighbors,
Expand All @@ -149,7 +152,6 @@ def aga(adata,
adata.add['diffmap_evals'] = aga.evals[1:]
adata.add['distance'] = aga.Dsq
adata.add['Ktilde'] = aga.Ktilde
logg.info('running Approximate Graph Abstraction (AGA)', r=True)
if aga.iroot is not None:
aga.set_pseudotime() # pseudotimes are random walk distances from root point
adata.add['iroot'] = aga.iroot # update iroot, might have changed when subsampling, for example
Expand Down
2 changes: 1 addition & 1 deletion scanpy/tools/dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def dpt(adata, n_branchings=0, n_neighbors=30, knn=True, n_pcs=50, n_dcs=10,
where "root_cell_index" is the integer index of the root cell, or
adata.var['xroot'] = adata[root_cell_name, :].X
where "root_cell_name" is the name (a string) of the root cell.'''
logg.m(msg, v='hint')
logg.hint(msg)
if n_branchings == 0:
logg.m('set parameter `n_branchings` > 0 to detect branchings', v='hint')
dpt = DPT(adata, n_neighbors=n_neighbors, knn=knn, n_pcs=n_pcs, n_dcs=n_dcs,
Expand Down
30 changes: 17 additions & 13 deletions scanpy/tools/draw_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
transcriptomics: Weinreb et al., bioRxiv doi:10.1101/090332 (2016)
"""

import numpy as np
from .. import utils
from ..data_structs.data_graph import add_graph_to_adata


def draw_graph(adata,
layout='fr',
root=None,
n_neighbors=30,
n_pcs=50,
root=None,
n_jobs=None,
random_state=0,
recompute_pca=None,
recompute_graph=False,
adjacency=None,
n_jobs=None,
copy=False):
"""Visualize data using standard graph drawing algorithms.
Expand Down Expand Up @@ -51,22 +56,21 @@ def draw_graph(adata,
from .. import logging as logg
logg.info('drawing single-cell graph using layout "{}"'.format(layout),
r=True)
import numpy as np
from .. import data_structs
from .. import utils
avail_layouts = {'fr', 'drl', 'kk', 'grid_fr', 'lgl', 'rt', 'rt_circular'}
if layout not in avail_layouts:
raise ValueError('Provide a valid layout, one of {}.'.format(avail_layouts))
adata = adata.copy() if copy else adata
if 'Ktilde' not in adata.add or recompute_graph:
graph = data_structs.DataGraph(adata,
k=n_neighbors,
n_pcs=n_pcs,
n_jobs=n_jobs)
graph.compute_transition_matrix(recompute_distance=True)
adata.add['Ktilde'] = graph.Ktilde
elif n_neighbors is not None and not recompute_graph:
logg.warn('`n_neighbors={}` has no effect (set `recompute_graph=True` to enable it)'
add_graph_to_adata(
adata,
n_neighbors=n_neighbors,
n_pcs=n_pcs,
recompute_pca=recompute_pca,
recompute_graph=recompute_graph,
n_jobs=n_jobs)
else:
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
logg.info(' using stored graph with n_neighbors = {}'
.format(n_neighbors))
adjacency = adata.add['Ktilde']
g = utils.get_igraph_from_adjacency(adjacency)
Expand Down
39 changes: 15 additions & 24 deletions scanpy/tools/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy as np
from .. import utils
from .. import logging as logg
from ..data_structs import DataGraph
from ..data_structs.data_graph import add_graph_to_adata


def louvain(adata,
n_neighbors=30,
Expand Down Expand Up @@ -43,28 +44,19 @@ def louvain(adata,
- basic suggestion for single-cell: Levine et al., Cell 162, 184-197 (2015)
- combination with "attachedness" matrix: Wolf et al., bioRxiv (2017)
"""
logg.m('run Louvain clustering', r=True)
logg.m('running Louvain clustering', r=True)
adata = adata.copy() if copy else adata
if 'Ktilde' not in adata.add or recompute_graph:
graph = DataGraph(adata,
k=n_neighbors,
n_pcs=n_pcs,
recompute_pca=recompute_pca,
recompute_graph=recompute_graph,
n_jobs=n_jobs)
# compute diffmap for later use although it's not needed here
# it does not cost much
graph.update_diffmap()
adata.add['distance'] = graph.Dsq
adata.add['Ktilde'] = graph.Ktilde
adata.smp['X_diffmap'] = graph.rbasis[:, 1:]
adata.smp['X_diffmap0'] = graph.rbasis[:, 0]
adata.add['diffmap_evals'] = graph.evals[1:]
add_graph_to_adata(
adata,
n_neighbors=n_neighbors,
n_pcs=n_pcs,
recompute_pca=recompute_pca,
recompute_graph=recompute_graph,
n_jobs=n_jobs)
else:
# do not use the undirected kernel Ktilde here, but the
# sparse distance matrix
n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1
logg.info(' using precomputed graph with n_neighbors={}'
logg.info(' using stored graph with n_neighbors = {}'
.format(n_neighbors))
adjacency = adata.add['Ktilde']
if flavor in {'vtraag', 'igraph'}:
Expand All @@ -83,11 +75,10 @@ def louvain(adata,
resolution_parameter=resolution)
adata.add['louvain_quality'] = part.quality()
except AttributeError:
logg.warn('Did not find louvain package >= 0.6 on your system, '
'the result will therefore not be 100% reproducible, but '
'is influenced by randomness in the community detection '
'algorithm. Still you get very meaningful results!\n'
'If you want 100% reproducible results, but 0.6 is not yet '
logg.warn('Did not find package louvain>=0.6, '
'the clustering result will therefore not be 100% reproducible, '
'but still meaningful! '
'If you want 100% reproducible results, but louvain 0.6 is not yet '
'available via "pip install louvain", '
'either get the latest (development) version from '
'https://github.com/vtraag/louvain-igraph or use the option '
Expand Down
8 changes: 3 additions & 5 deletions scanpy/tools/tsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,15 @@ def tsne(adata, random_state=0, n_pcs=50, perplexity=30, learning_rate=None,
X_tsne = tsne.fit_transform(X.astype(np.float64))
except ImportError:
multicore_failed = True
logg.hint('did not find package MulticoreTSNE: to speed up the computation, install it from\n'
' https://github.com/DmitryUlyanov/Multicore-TSNE')
if multicore_failed:
from sklearn.manifold import TSNE
# unfortunately, we cannot set a minimum number of iterations for barnes-hut
params_sklearn['learning_rate'] = 1000 if learning_rate is None else learning_rate
tsne = TSNE(**params_sklearn)
logg.warn('Consider installing the package MulticoreTSNE.\n'
' https://github.com/DmitryUlyanov/Multicore-TSNE\n'
'Even for `n_jobs=1` this speeds up the computation considerably.')
logg.info(' using sklearn.manifold.TSNE')
logg.warn('Consider installing the package MulticoreTSNE '
' https://github.com/DmitryUlyanov/Multicore-TSNE.'
' Even for `n_jobs=1` this speeds up the computation considerably and might yield better converged results.')
X_tsne = tsne.fit_transform(X)
# update AnnData instance
adata.smp['X_tsne'] = X_tsne
Expand Down

0 comments on commit f0cfad6

Please sign in to comment.