From f0cfad66e074a7db81d96d51c5230ac2c5913f7d Mon Sep 17 00:00:00 2001 From: falexwolf Date: Thu, 27 Jul 2017 12:23:38 +0200 Subject: [PATCH] cleaned up graph loading --- scanpy/api/tools.py | 11 +++----- scanpy/data_structs/data_graph.py | 42 +++++++++++++++++++++++++------ scanpy/tools/aga.py | 12 +++++---- scanpy/tools/dpt.py | 2 +- scanpy/tools/draw_graph.py | 30 ++++++++++++---------- scanpy/tools/louvain.py | 39 +++++++++++----------------- scanpy/tools/tsne.py | 8 +++--- 7 files changed, 80 insertions(+), 64 deletions(-) diff --git a/scanpy/api/tools.py b/scanpy/api/tools.py index 86733d3b8e..047d8ac917 100644 --- a/scanpy/api/tools.py +++ b/scanpy/api/tools.py @@ -4,7 +4,10 @@ """ # order alphabetically +from ..tools.aga import aga +from ..tools.aga import aga_contract_graph from ..tools.dbscan import dbscan +from ..tools.draw_graph import draw_graph from ..tools.diffmap import diffmap from ..tools.rank_genes_groups import rank_genes_groups from ..tools.dpt import dpt @@ -13,11 +16,3 @@ from ..tools.sim import sim from ..tools.spring import spring from ..tools.tsne import tsne - -try: - # development tools - from ..tools.draw_graph import draw_graph - from ..tools.aga import aga - from ..tools.aga import aga_contract_graph -except ImportError: - pass diff --git a/scanpy/data_structs/data_graph.py b/scanpy/data_structs/data_graph.py index f0170c6e8f..3e4238e745 100644 --- a/scanpy/data_structs/data_graph.py +++ b/scanpy/data_structs/data_graph.py @@ -14,6 +14,27 @@ from .. import utils +def add_graph_to_adata( + adata, + n_neighbors=30, + n_pcs=50, + recompute_pca=None, + recompute_graph=False, + n_jobs=None): + graph = DataGraph(adata, + k=n_neighbors, + n_pcs=n_pcs, + recompute_pca=recompute_pca, + recompute_graph=recompute_graph, + n_jobs=n_jobs) + graph.update_diffmap() + adata.add['distance'] = graph.Dsq + adata.add['Ktilde'] = graph.Ktilde + adata.smp['X_diffmap'] = graph.rbasis[:, 1:] + adata.smp['X_diffmap0'] = graph.rbasis[:, 0] + adata.add['diffmap_evals'] = graph.evals[1:] + + def get_neighbors(X, Y, k): Dsq = utils.comp_sqeuclidean_distance_using_matrix_mult(X, Y) chunk_range = np.arange(Dsq.shape[0])[:, None] @@ -159,7 +180,7 @@ def __init__(self, and adata.smp['X_diffmap'].shape[1] >= n_dcs-1): self.n_pcs = n_pcs self.n_dcs = n_dcs - self.iroot = None if 'iroot' not in adata.add else adata.add['iroot'] + self.init_iroot_directly(adata) self.X = adata.X # this is a hack, PCA? self.knn = issparse(adata.add['Ktilde']) self.Ktilde = adata.add['Ktilde'] @@ -177,7 +198,7 @@ def __init__(self, self.Dchosen = OnFlySymMatrix(self.get_Ddiff_row, shape=(self.X.shape[0], self.X.shape[0])) np.set_printoptions(precision=3) - logg.info('use stored data graph with `n_neighbors = {}` and ' + logg.info(' using stored data graph with n_neighbors = {} and ' 'spectrum\n {}' .format(self.k, str(self.evals).replace('\n', '\n '))) @@ -211,19 +232,24 @@ def __init__(self, .format(self.k)) self.Dsq = adata.add['distance'] - def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs): - # retrieve xroot - xroot = None - if 'xroot' in adata.add: xroot = adata.add['xroot'] - elif 'xroot' in adata.var: xroot = adata.var['xroot'] - # set iroot directly + def init_iroot_directly(self, adata): if 'iroot' in adata.add: if adata.add['iroot'] >= adata.n_smps: logg.warn('Root cell index {} does not exist for {} samples. ' 'Is ignored.' .format(adata.add['iroot'], adata.n_smps)) + self.iroot = None else: self.iroot = adata.add['iroot'] + + + def init_iroot_and_X_from_PCA(self, adata, recompute_pca, n_pcs): + # retrieve xroot + xroot = None + if 'xroot' in adata.add: xroot = adata.add['xroot'] + elif 'xroot' in adata.var: xroot = adata.var['xroot'] + # set iroot directly + self.init_iroot_directly(adata) # see whether we can set self.iroot using the full data matrix if xroot is not None and xroot.size == self.X.shape[1]: self.set_root(xroot) diff --git a/scanpy/tools/aga.py b/scanpy/tools/aga.py index 67825f4cb6..947acd7baa 100644 --- a/scanpy/tools/aga.py +++ b/scanpy/tools/aga.py @@ -108,12 +108,14 @@ def aga(adata, root_cell_was_passed = False logg.m('... no root cell found, no computation of pseudotime') msg = \ - '''To enable computation of pseudotime, pass the expression "xroot" of a root cell. - Either add + '''To enable computation of pseudotime, pass the index or expression vector + of a root cell. Either add + adata.add['iroot'] = root_cell_index + or (robust to subsampling) adata.var['xroot'] = adata.X[root_cell_index, :] - where `root_cell_index` is the integer index of the root cell, or + where "root_cell_index" is the integer index of the root cell, or adata.var['xroot'] = adata[root_cell_name, :].X - where `root_cell_name` is the name (a string) of the root cell.''' + where "root_cell_name" is the name (a string) of the root cell.''' logg.hint(msg) fresh_compute_louvain = False if ((node_groups == 'louvain' and 'louvain_groups' not in adata.smp_keys()) @@ -127,6 +129,7 @@ def aga(adata, fresh_compute_louvain = True clusters = node_groups if node_groups == 'louvain': clusters = 'louvain_groups' + logg.info('running Approximate Graph Abstraction (AGA)', r=True) aga = AGA(adata, clusters=clusters, n_neighbors=n_neighbors, @@ -149,7 +152,6 @@ def aga(adata, adata.add['diffmap_evals'] = aga.evals[1:] adata.add['distance'] = aga.Dsq adata.add['Ktilde'] = aga.Ktilde - logg.info('running Approximate Graph Abstraction (AGA)', r=True) if aga.iroot is not None: aga.set_pseudotime() # pseudotimes are random walk distances from root point adata.add['iroot'] = aga.iroot # update iroot, might have changed when subsampling, for example diff --git a/scanpy/tools/dpt.py b/scanpy/tools/dpt.py index 39ccee96e0..f5f7364d5c 100644 --- a/scanpy/tools/dpt.py +++ b/scanpy/tools/dpt.py @@ -104,7 +104,7 @@ def dpt(adata, n_branchings=0, n_neighbors=30, knn=True, n_pcs=50, n_dcs=10, where "root_cell_index" is the integer index of the root cell, or adata.var['xroot'] = adata[root_cell_name, :].X where "root_cell_name" is the name (a string) of the root cell.''' - logg.m(msg, v='hint') + logg.hint(msg) if n_branchings == 0: logg.m('set parameter `n_branchings` > 0 to detect branchings', v='hint') dpt = DPT(adata, n_neighbors=n_neighbors, knn=knn, n_pcs=n_pcs, n_dcs=n_dcs, diff --git a/scanpy/tools/draw_graph.py b/scanpy/tools/draw_graph.py index c4f4069cdf..2fb8258f93 100644 --- a/scanpy/tools/draw_graph.py +++ b/scanpy/tools/draw_graph.py @@ -8,16 +8,21 @@ transcriptomics: Weinreb et al., bioRxiv doi:10.1101/090332 (2016) """ +import numpy as np +from .. import utils +from ..data_structs.data_graph import add_graph_to_adata + def draw_graph(adata, layout='fr', + root=None, n_neighbors=30, n_pcs=50, - root=None, - n_jobs=None, random_state=0, + recompute_pca=None, recompute_graph=False, adjacency=None, + n_jobs=None, copy=False): """Visualize data using standard graph drawing algorithms. @@ -51,22 +56,21 @@ def draw_graph(adata, from .. import logging as logg logg.info('drawing single-cell graph using layout "{}"'.format(layout), r=True) - import numpy as np - from .. import data_structs - from .. import utils avail_layouts = {'fr', 'drl', 'kk', 'grid_fr', 'lgl', 'rt', 'rt_circular'} if layout not in avail_layouts: raise ValueError('Provide a valid layout, one of {}.'.format(avail_layouts)) adata = adata.copy() if copy else adata if 'Ktilde' not in adata.add or recompute_graph: - graph = data_structs.DataGraph(adata, - k=n_neighbors, - n_pcs=n_pcs, - n_jobs=n_jobs) - graph.compute_transition_matrix(recompute_distance=True) - adata.add['Ktilde'] = graph.Ktilde - elif n_neighbors is not None and not recompute_graph: - logg.warn('`n_neighbors={}` has no effect (set `recompute_graph=True` to enable it)' + add_graph_to_adata( + adata, + n_neighbors=n_neighbors, + n_pcs=n_pcs, + recompute_pca=recompute_pca, + recompute_graph=recompute_graph, + n_jobs=n_jobs) + else: + n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1 + logg.info(' using stored graph with n_neighbors = {}' .format(n_neighbors)) adjacency = adata.add['Ktilde'] g = utils.get_igraph_from_adjacency(adjacency) diff --git a/scanpy/tools/louvain.py b/scanpy/tools/louvain.py index 7424cdf79a..9ec413e597 100644 --- a/scanpy/tools/louvain.py +++ b/scanpy/tools/louvain.py @@ -7,7 +7,8 @@ import numpy as np from .. import utils from .. import logging as logg -from ..data_structs import DataGraph +from ..data_structs.data_graph import add_graph_to_adata + def louvain(adata, n_neighbors=30, @@ -43,28 +44,19 @@ def louvain(adata, - basic suggestion for single-cell: Levine et al., Cell 162, 184-197 (2015) - combination with "attachedness" matrix: Wolf et al., bioRxiv (2017) """ - logg.m('run Louvain clustering', r=True) + logg.m('running Louvain clustering', r=True) adata = adata.copy() if copy else adata if 'Ktilde' not in adata.add or recompute_graph: - graph = DataGraph(adata, - k=n_neighbors, - n_pcs=n_pcs, - recompute_pca=recompute_pca, - recompute_graph=recompute_graph, - n_jobs=n_jobs) - # compute diffmap for later use although it's not needed here - # it does not cost much - graph.update_diffmap() - adata.add['distance'] = graph.Dsq - adata.add['Ktilde'] = graph.Ktilde - adata.smp['X_diffmap'] = graph.rbasis[:, 1:] - adata.smp['X_diffmap0'] = graph.rbasis[:, 0] - adata.add['diffmap_evals'] = graph.evals[1:] + add_graph_to_adata( + adata, + n_neighbors=n_neighbors, + n_pcs=n_pcs, + recompute_pca=recompute_pca, + recompute_graph=recompute_graph, + n_jobs=n_jobs) else: - # do not use the undirected kernel Ktilde here, but the - # sparse distance matrix n_neighbors = adata.add['distance'][0].nonzero()[0].size + 1 - logg.info(' using precomputed graph with n_neighbors={}' + logg.info(' using stored graph with n_neighbors = {}' .format(n_neighbors)) adjacency = adata.add['Ktilde'] if flavor in {'vtraag', 'igraph'}: @@ -83,11 +75,10 @@ def louvain(adata, resolution_parameter=resolution) adata.add['louvain_quality'] = part.quality() except AttributeError: - logg.warn('Did not find louvain package >= 0.6 on your system, ' - 'the result will therefore not be 100% reproducible, but ' - 'is influenced by randomness in the community detection ' - 'algorithm. Still you get very meaningful results!\n' - 'If you want 100% reproducible results, but 0.6 is not yet ' + logg.warn('Did not find package louvain>=0.6, ' + 'the clustering result will therefore not be 100% reproducible, ' + 'but still meaningful! ' + 'If you want 100% reproducible results, but louvain 0.6 is not yet ' 'available via "pip install louvain", ' 'either get the latest (development) version from ' 'https://github.com/vtraag/louvain-igraph or use the option ' diff --git a/scanpy/tools/tsne.py b/scanpy/tools/tsne.py index 61130d3ebb..d15c14d464 100644 --- a/scanpy/tools/tsne.py +++ b/scanpy/tools/tsne.py @@ -99,17 +99,15 @@ def tsne(adata, random_state=0, n_pcs=50, perplexity=30, learning_rate=None, X_tsne = tsne.fit_transform(X.astype(np.float64)) except ImportError: multicore_failed = True - logg.hint('did not find package MulticoreTSNE: to speed up the computation, install it from\n' - ' https://github.com/DmitryUlyanov/Multicore-TSNE') if multicore_failed: from sklearn.manifold import TSNE # unfortunately, we cannot set a minimum number of iterations for barnes-hut params_sklearn['learning_rate'] = 1000 if learning_rate is None else learning_rate tsne = TSNE(**params_sklearn) - logg.warn('Consider installing the package MulticoreTSNE.\n' - ' https://github.com/DmitryUlyanov/Multicore-TSNE\n' - 'Even for `n_jobs=1` this speeds up the computation considerably.') logg.info(' using sklearn.manifold.TSNE') + logg.warn('Consider installing the package MulticoreTSNE ' + ' https://github.com/DmitryUlyanov/Multicore-TSNE.' + ' Even for `n_jobs=1` this speeds up the computation considerably and might yield better converged results.') X_tsne = tsne.fit_transform(X) # update AnnData instance adata.smp['X_tsne'] = X_tsne