implement sketch-based acceleration of integration
Brian Hie committed Nov 18, 2019
1 parent 794d08d commit 048e5de
Showing 5 changed files with 176 additions and 79 deletions.
14 changes: 6 additions & 8 deletions bin/mouse_brain.py
@@ -28,25 +28,23 @@

if __name__ == '__main__':
process(data_names, min_trans=100)

datasets, genes_list, n_cells = load_names(data_names)

datasets, genes = merge_datasets(datasets, genes_list, ds_names=data_names)

datasets_dimred, genes = process_data(datasets, genes, verbose=True)

t0 = time()
datasets_dimred = assemble(
datasets_dimred, batch_size=BATCH_SIZE,
- geosketch=True, geosketch_max=6900
)
print('Integrated panoramas in {:.3f}s'.format(time() - t0))

t0 = time()
datasets_dimred, datasets, genes = correct(
datasets, genes_list, ds_names=data_names,
return_dimred=True, batch_size=BATCH_SIZE,
- geosketch=True, geosketch_max=6900
)
print('Integrated and batch corrected panoramas in {:.3f}s'
.format(time() - t0))
@@ -59,12 +57,12 @@
names.append(data_names[i])
curr_label += 1
labels = np.array(labels, dtype=int)

mouse_brain_genes = [
'Gja1', 'Flt1', 'Gabra6', 'Syt1', 'Gabrb2', 'Gabra1',
'Meg3', 'Mbp', 'Rgs5',
]

# Downsample for visualization purposes
for i in range(len(data_names)):
@@ -83,7 +81,7 @@
image_suffix='.png')
np.savetxt('data/{}_embedding.txt'.format(NAMESPACE),
embedding, delimiter='\t')

cell_labels = (
open('data/cell_labels/mouse_brain_cluster.txt')
.read().rstrip().split()
87 changes: 87 additions & 0 deletions bin/mouse_brain_sketched.py
@@ -0,0 +1,87 @@
import numpy as np
from scanorama import *
from scipy.sparse import vstack
from sklearn.preprocessing import normalize, LabelEncoder
import sys
from time import time

from benchmark import write_table
from process import load_names, process

np.random.seed(0)

NAMESPACE = 'mouse_brain_sketched'
BATCH_SIZE = 10000

data_names = [
'data/mouse_brain/nuclei',
'data/mouse_brain/dropviz/Cerebellum_ALT',
'data/mouse_brain/dropviz/Cortex_noRep5_FRONTALonly',
'data/mouse_brain/dropviz/Cortex_noRep5_POSTERIORonly',
'data/mouse_brain/dropviz/EntoPeduncular',
'data/mouse_brain/dropviz/GlobusPallidus',
'data/mouse_brain/dropviz/Hippocampus',
'data/mouse_brain/dropviz/Striatum',
'data/mouse_brain/dropviz/SubstantiaNigra',
'data/mouse_brain/dropviz/Thalamus',
]

if __name__ == '__main__':
process(data_names, min_trans=100)

datasets, genes_list, n_cells = load_names(data_names)

datasets, genes = merge_datasets(datasets, genes_list, ds_names=data_names)

datasets_dimred, genes = process_data(datasets, genes, verbose=True)

t0 = time()
datasets_dimred, genes = integrate(
datasets, genes_list, ds_names=data_names,
sketch=True, sketch_method='geosketch', sketch_max=2000,
)
print('Sketched and integrated panoramas in {:.3f}s'
.format(time() - t0))

labels = []
names = []
curr_label = 0
for i, a in enumerate(datasets_dimred):
labels += list(np.zeros(a.shape[0]) + curr_label)
names.append(data_names[i])
curr_label += 1
labels = np.array(labels, dtype=int)

mouse_brain_genes = [
'Gja1', 'Flt1', 'Gabra6', 'Syt1', 'Gabrb2', 'Gabra1',
'Meg3', 'Mbp', 'Rgs5',
]

# Downsample for visualization purposes
for i in range(len(data_names)):
ds = datasets_dimred[i]
rand_idx = np.random.choice(ds.shape[0], size=int(ds.shape[0]/10),
replace=False)
datasets_dimred[i] = ds[rand_idx, :]
datasets[i] = datasets[i][rand_idx, :]

embedding = visualize(datasets_dimred,
labels, NAMESPACE + '_ds', names,
gene_names=mouse_brain_genes, genes=genes,
gene_expr=vstack(datasets),
multicore_tsne=True,
image_suffix='.png')

cell_labels = (
open('data/cell_labels/mouse_brain_cluster.txt')
.read().rstrip().split()
)
le = LabelEncoder().fit(cell_labels)
cell_labels = le.transform(cell_labels)
cell_types = le.classes_

visualize(None,
cell_labels, NAMESPACE + '_type', cell_types,
embedding=embedding, image_suffix='.png')
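[Editor's note: the script above exercises the geosketch package that this commit adds as a dependency (see setup.py below). As a minimal standalone sketch of the sampler itself — `X_dimred` here is a hypothetical stand-in for real PCA coordinates, not data from the script:

import numpy as np
from geosketch import gs

X_dimred = np.random.randn(5000, 100)  # hypothetical PCA coordinates
# Indices of a 500-cell geometric sketch; gs(X, N, replace=False) is the
# same call integrate_sketch() makes in scanorama.py below.
sketch_idx = gs(X_dimred, 500, replace=False)
X_sketch = X_dimred[sketch_idx]
]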
2 changes: 2 additions & 0 deletions scanorama/__init__.py
@@ -1 +1,3 @@
from .scanorama import *

+ __version__ = '1.5'
147 changes: 78 additions & 69 deletions scanorama/scanorama.py
@@ -34,8 +34,7 @@
def correct(datasets_full, genes_list, return_dimred=False,
batch_size=BATCH_SIZE, verbose=VERBOSE, ds_names=None,
dimred=DIMRED, approx=APPROX, sigma=SIGMA, alpha=ALPHA, knn=KNN,
- return_dense=False, hvg=None, union=False,
- geosketch=False, geosketch_max=20000, seed=0):
+ return_dense=False, hvg=None, union=False, seed=0):
"""Integrate and batch correct a list of data sets.
Parameters
@@ -101,7 +100,6 @@ def correct(datasets_full, genes_list, return_dimred=False,
expr_datasets=datasets, # Modified in place.
verbose=verbose, knn=knn, sigma=sigma, approx=approx,
alpha=alpha, ds_names=ds_names, batch_size=batch_size,
- geosketch=geosketch, geosketch_max=geosketch_max,
)

if return_dense:
@@ -115,8 +113,8 @@
# Integrate a list of data sets.
def integrate(datasets_full, genes_list, batch_size=BATCH_SIZE,
verbose=VERBOSE, ds_names=None, dimred=DIMRED, approx=APPROX,
- sigma=SIGMA, alpha=ALPHA, knn=KNN, geosketch=False,
- geosketch_max=20000, n_iter=1, union=False, hvg=None, seed=0):
+ sigma=SIGMA, alpha=ALPHA, knn=KNN, union=False, hvg=None, seed=0,
+ sketch=False, sketch_method='geosketch', sketch_max=10000):
"""Integrate a list of data sets.
Parameters
@@ -147,6 +145,16 @@ def integrate(datasets_full, genes_list, batch_size=BATCH_SIZE,
Use this number of top highly variable genes based on dispersion.
seed: `int`, optional (default: 0)
Random seed to use.
+ sketch: `bool`, optional (default: False)
+ Apply sketching-based acceleration by first downsampling the datasets.
+ See Hie et al., Cell Systems (2019).
+ sketch_method: {'geosketch', 'uniform'}, optional (default: `geosketch`)
+ Apply the given sketching method to the data. Only used if
+ `sketch=True`.
+ sketch_max: `int`, optional (default: 10000)
+ If a dataset has more cells than `sketch_max`, downsample to
+ `sketch_max` using the method provided in `sketch_method`. Only used
+ if `sketch=True`.
Returns
-------
@@ -165,12 +173,21 @@ def integrate(datasets_full, genes_list, batch_size=BATCH_SIZE,
datasets_dimred, genes = process_data(datasets, genes, hvg=hvg,
dimred=dimred)

- for _ in range(n_iter):
- datasets_dimred = assemble(
- datasets_dimred, # Assemble in low dimensional space.
- verbose=verbose, knn=knn, sigma=sigma, approx=approx,
- alpha=alpha, ds_names=ds_names, batch_size=batch_size,
- geosketch=geosketch, geosketch_max=geosketch_max,
- )
+ if sketch:
+ datasets_dimred = integrate_sketch(
+ datasets_dimred, sketch_method=sketch_method, N=sketch_max,
+ integration_fn=assemble, integration_fn_args={
+ 'verbose': verbose, 'knn': knn, 'sigma': sigma,
+ 'approx': approx, 'alpha': alpha, 'ds_names': ds_names,
+ 'batch_size': batch_size,
+ },
+ )
+ else:
+ datasets_dimred = assemble(
+ datasets_dimred, # Assemble in low dimensional space.
+ verbose=verbose, knn=knn, sigma=sigma, approx=approx,
+ alpha=alpha, ds_names=ds_names, batch_size=batch_size,
+ )

return datasets_dimred, genes
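[Editor's note: for reference, a minimal usage sketch of the new code path. It assumes `datasets` and `genes_list` have already been loaded and merged (e.g. via `load_names` and `merge_datasets`, as in the example scripts above); only the keyword arguments are taken from the signature in this commit:

from scanorama import integrate

integrated, genes = integrate(
    datasets, genes_list,
    sketch=True,                # enable sketch-based acceleration
    sketch_method='geosketch',  # or 'uniform'
    sketch_max=10000,           # downsample datasets larger than this
)
]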
@@ -584,21 +601,10 @@ def fill_table(table, i, curr_ds, datasets, base_ds=0,

# Fill table of alignment scores.
def find_alignments_table(datasets, knn=KNN, approx=APPROX, verbose=VERBOSE,
- prenormalized=False, geosketch=False,
- geosketch_max=20000):
+ prenormalized=False):
if not prenormalized:
datasets = [ normalize(ds, axis=1) for ds in datasets ]

- if geosketch:
- # Only match cells in geometric sketches.
- from ample import gs, uniform
- global gs_idxs
- if gs_idxs is None:
- gs_idxs = [ uniform(X, geosketch_max, replace=False)
- if X.shape[0] > geosketch_max else range(X.shape[0])
- for X in datasets ]
- datasets = [ datasets[i][gs_idx, :] for i, gs_idx in enumerate(gs_idxs) ]

table = {}
for i in range(len(datasets)):
if len(datasets[:i]) > 0:
@@ -631,13 +637,6 @@ def find_alignments_table(datasets, knn=KNN, approx=APPROX, verbose=VERBOSE,
if verbose > 1:
table_print[i, j] += table1[(i, j)]

- if geosketch:
- # Translate matches within geometric sketches to original indices.
- matches_mnn = matches[(i, j)]
- matches[(i, j)] = [
- (gs_idxs[i][a], gs_idxs[j][b]) for a, b in matches_mnn
- ]

if verbose > 1:
print(table_print)
return table1, table_print, matches
@@ -646,12 +645,10 @@

# Find the matching pairs of cells between datasets.
def find_alignments(datasets, knn=KNN, approx=APPROX, verbose=VERBOSE,
- alpha=ALPHA, prenormalized=False,
- geosketch=False, geosketch_max=20000):
+ alpha=ALPHA, prenormalized=False):
table1, _, matches = find_alignments_table(
datasets, knn=knn, approx=approx, verbose=verbose,
prenormalized=prenormalized,
- geosketch=geosketch, geosketch_max=geosketch_max
)

alignments = [ (i, j) for (i, j), val in reversed(
@@ -771,14 +768,13 @@ def transform(curr_ds, curr_ref, ds_ind, ref_ind, sigma=SIGMA, cn=False,
def assemble(datasets, verbose=VERBOSE, view_match=False, knn=KNN,
sigma=SIGMA, approx=APPROX, alpha=ALPHA, expr_datasets=None,
ds_names=None, batch_size=None,
- geosketch=False, geosketch_max=20000, alignments=None, matches=None):
+ alignments=None, matches=None):
if len(datasets) == 1:
return datasets

if alignments is None and matches is None:
alignments, matches = find_alignments(
datasets, knn=knn, approx=approx, alpha=alpha, verbose=verbose,
- geosketch=geosketch, geosketch_max=geosketch_max
)

ds_assembled = {}
@@ -951,6 +947,57 @@ def assemble(datasets, verbose=VERBOSE, view_match=False, knn=KNN,

return datasets

+ # Sketch-based acceleration of integration.
+ def integrate_sketch(datasets_dimred, sketch_method='geosketch', N=10000,
+ integration_fn=assemble, integration_fn_args={}):

+ from geosketch import gs, uniform

+ if sketch_method.lower() in ('geosketch', 'gs'):
+ sampling_fn = gs
+ else:
+ sampling_fn = uniform

+ # Sketch each dataset.
+ sketch_idxs = [
+ sorted(set(sampling_fn(X, N, replace=False)))
+ if X.shape[0] > N else list(range(X.shape[0]))
+ for X in datasets_dimred
+ ]
+ datasets_sketch = [ X[idx] for X, idx in zip(datasets_dimred, sketch_idxs) ]

+ # Integrate the dataset sketches.
+ datasets_int = integration_fn(datasets_sketch[:], **integration_fn_args)

+ # Apply integrated coordinates back to the full data: for each sketch
+ # cell, find its nearest neighbors in the original coordinates and
+ # smooth the learned correction over all cells.
+ for i, (X_dimred, X_sketch) in enumerate(zip(datasets_dimred, datasets_sketch)):
+ X_int = datasets_int[i]

+ neigh = NearestNeighbors(n_neighbors=3).fit(X_dimred)
+ _, neigh_idx = neigh.kneighbors(X_sketch)

+ ds_idxs, ref_idxs = [], []
+ for ref_idx in range(neigh_idx.shape[0]):
+ for k_idx in range(neigh_idx.shape[1]):
+ ds_idxs.append(neigh_idx[ref_idx, k_idx])
+ ref_idxs.append(ref_idx)

+ bias = transform(X_dimred, X_int, ds_idxs, ref_idxs, 15, batch_size=1000)

+ datasets_int[i] = X_dimred + bias

+ return datasets_int
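[Editor's note: to make the propagation step concrete — integrating only the sketches yields a per-sketch-cell correction (X_int - X_sketch), which transform() then smooths onto every full-resolution cell with a Gaussian kernel (the positional 15 above is sigma). A simplified, self-contained approximation of that idea, not scanorama's exact transform(), and with the neighbor lookup inverted for clarity:

import numpy as np
from sklearn.neighbors import NearestNeighbors

def propagate_bias(X_full, X_sketch, X_int, k=3, sigma=15.0):
    # Correction learned for each sketch cell during integration.
    bias_sketch = X_int - X_sketch
    # For every full-resolution cell, find its k nearest sketch cells.
    nn = NearestNeighbors(n_neighbors=k).fit(X_sketch)
    dist, idx = nn.kneighbors(X_full)
    # Gaussian-weighted average of the neighboring corrections.
    w = np.exp(-dist ** 2 / sigma ** 2)
    w /= w.sum(axis=1, keepdims=True)
    return X_full + np.einsum('ij,ijk->ik', w, bias_sketch[idx])
]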

# Non-optimal dataset assembly. Simply accumulate datasets into a
# reference.
def assemble_accum(datasets, verbose=VERBOSE, knn=KNN, sigma=SIGMA,
@@ -976,41 +1023,3 @@ def assemble_accum(datasets, verbose=VERBOSE, knn=KNN, sigma=SIGMA,
datasets[j] = ds1 + bias

return datasets

- def interpret_alignments(datasets, expr_datasets, genes,
- verbose=VERBOSE, knn=KNN, approx=APPROX,
- alpha=ALPHA, n_permutations=None):
- if n_permutations is None:
- n_permutations = float(len(genes) * 30)
-
- alignments, matches = find_alignments(
- datasets, knn=knn, approx=approx, alpha=alpha, verbose=verbose
- )
-
- for i, j in alignments:
- # Compute average bias vector that aligns two datasets together.
- ds_i = expr_datasets[i]
- ds_j = expr_datasets[j]
- if i < j:
- match = matches[(i, j)]
- else:
- match = matches[(j, i)]
- i_ind = [ a for a, _ in match ]
- j_ind = [ b for _, b in match ]
- avg_bias = np.absolute(
- np.mean(ds_j[j_ind, :] - ds_i[i_ind, :], axis=0)
- )
-
- # Construct null distribution and compute p-value.
- null_bias = (
- ds_j[np.random.randint(ds_j.shape[0], size=n_permutations), :] -
- ds_i[np.random.randint(ds_i.shape[0], size=n_permutations), :]
- )
- p = ((np.sum(np.greater_equal(
- np.absolute(np.tile(avg_bias, (n_permutations, 1))),
- np.absolute(null_bias)
- ), axis=0, dtype=float) + 1) / (n_permutations + 1))
-
- print('>>>> Stats for alignment {}'.format((i, j)))
- for k in range(len(p)):
- print('{}\t{}'.format(genes[k], p[k]))
5 changes: 3 additions & 2 deletions setup.py
@@ -2,14 +2,15 @@

setup(
name='scanorama',
- version='1.4',
+ version='1.5',
description='Panoramic stitching of heterogeneous single cell transcriptomic data',
url='https://github.com/brianhie/scanorama',
- download_url='https://github.com/brianhie/scanorama/archive/v1.4.tar.gz',
+ download_url='https://github.com/brianhie/scanorama/archive/v1.5.tar.gz',
packages=find_packages(exclude=['bin', 'conf', 'data', 'target']),
install_requires=[
'annoy>=1.11.5',
'fbpca>=1.0',
+ 'geosketch>=1.0',
'intervaltree==2.1.0',
'matplotlib>=2.0.2',
'numpy>=1.12.0',