Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Umap2 #123

Merged
merged 23 commits into from
Apr 14, 2024
Merged

Umap2 #123

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(simplicial_set_intersect)
export(simplicial_set_union)
export(tumap)
export(umap)
export(umap2)
export(umap_transform)
export(unload_uwot)
import(Matrix)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ in the documentation, and also the rnndescent package's
[documentation](https://jlmelville.github.io/rnndescent/index.html) for details.
`rnndescent` is only a suggested package, not a requirement, so you need to
install it yourself (e.g. via `install.packages("rnndescent")`).
* New function: `umap2`, which acts like `umap` but with modified defaults,
reflecting my experience with UMAP and correcting some small mistakes. See the
[umap2 article](https://jlmelville.github.io/uwot/articles/umap2.html) for more
details.

## Bug fixes and minor improvements

Expand Down
6 changes: 5 additions & 1 deletion R/neighbors.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@ find_nn <- function(X, k, include_self = TRUE, method = "fnn",
n_threads = NULL,
grain_size = 1,
ret_index = FALSE,
sparse_is_distance = TRUE,
verbose = FALSE) {
if (is.null(n_threads)) {
n_threads <- default_num_threads()
}

if (inherits(X, "dist")) {
res <- dist_nn(X, k, include_self = include_self, verbose = verbose)
} else if (is_sparse_matrix(X)) {
} else if (sparse_is_distance && is_sparse_matrix(X)) {
# sparse distance matrix
if (Matrix::isTriangular(X)) {
res <- sparse_tri_nn(X, k, include_self = include_self, verbose = verbose)
} else {
res <- sparse_nn(X, k, include_self = include_self, verbose = verbose)
}
} else {
if (is_sparse_matrix(X) && method != "nndescent") {
stop("Sparse matrix input only supported for nndescent method.")
}
# normal matrix
switch(method,
"fnn" = {
Expand Down
3 changes: 2 additions & 1 deletion R/transform.R
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ umap_transform <- function(X = NULL, model = NULL,
n_vertices <- NULL
Xnames <- NULL
if (!is.null(X)) {
if (!(methods::is(X, "data.frame") || methods::is(X, "matrix"))) {
if (!(methods::is(X, "data.frame") ||
methods::is(X, "matrix") || is_sparse_matrix(X))) {
stop("Unknown input data format")
}
if (!is.null(norig_col) && ncol(X) != norig_col) {
Expand Down
820 changes: 820 additions & 0 deletions R/umap2.R

Large diffs are not rendered by default.

150 changes: 110 additions & 40 deletions R/uwot.R
Original file line number Diff line number Diff line change
Expand Up @@ -3057,7 +3057,8 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
dens_scale = NULL,
is_similarity_graph = FALSE,
seed = NULL,
nn_args = list()) {
nn_args = list(),
sparse_X_is_distance_matrix = TRUE) {
if (is.null(n_threads)) {
n_threads <- default_num_threads()
}
Expand Down Expand Up @@ -3209,7 +3210,7 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
n_vertices <- attr(X, "Size")
tsmessage("Read ", n_vertices, " rows")
Xnames <- labels(X)
} else if (is_sparse_matrix(X)) {
} else if (is_sparse_matrix(X) && sparse_X_is_distance_matrix) {
if (ret_model) {
stop("Can only create models with dense matrix or data frame input")
}
Expand All @@ -3223,7 +3224,7 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
} else {
cat_ids <- NULL
norig_col <- ncol(X)
if (methods::is(X, "data.frame") || methods::is(X, "matrix")) {
if (methods::is(X, "data.frame") || methods::is(X, "matrix") || is_sparse_matrix(X)) {
cat_res <- find_categoricals(metric)
metric <- cat_res$metrics
cat_ids <- cat_res$categoricals
Expand Down Expand Up @@ -3348,17 +3349,31 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
need_sigma <- FALSE
} else {
need_sigma <- ret_sigma || ret_localr || !is.null(dens_scale)
d2sr <- data2set(X, Xcat, n_neighbors, metrics, nn_method,
n_trees, search_k,
d2sr <- data2set(
X,
Xcat,
n_neighbors,
metrics,
nn_method,
n_trees,
search_k,
method,
set_op_mix_ratio, local_connectivity, bandwidth,
perplexity, kernel, need_sigma,
n_threads, grain_size,
set_op_mix_ratio,
local_connectivity,
bandwidth,
perplexity,
kernel,
need_sigma,
n_threads,
grain_size,
ret_model,
pca = pca, pca_center = pca_center, pca_method = pca_method,
pca = pca,
pca_center = pca_center,
pca_method = pca_method,
n_vertices = n_vertices,
nn_args = nn_args,
tmpdir = tmpdir,
sparse_is_distance = sparse_X_is_distance_matrix,
verbose = verbose
)
}
Expand Down Expand Up @@ -4279,17 +4294,30 @@ x2nv <- function(X) {
n_vertices
}

data2set <- function(X, Xcat, n_neighbors, metrics, nn_method,
n_trees, search_k,
data2set <- function(X,
Xcat,
n_neighbors,
metrics,
nn_method,
n_trees,
search_k,
method,
set_op_mix_ratio, local_connectivity, bandwidth,
perplexity, kernel, ret_sigma,
n_threads, grain_size,
set_op_mix_ratio,
local_connectivity,
bandwidth,
perplexity,
kernel,
ret_sigma,
n_threads,
grain_size,
ret_model,
n_vertices = x2nv(X),
tmpdir = tempdir(),
pca = NULL, pca_center = TRUE, pca_method = "irlba",
pca = NULL,
pca_center = TRUE,
pca_method = "irlba",
nn_args = list(),
sparse_is_distance = TRUE,
verbose = FALSE) {
V <- NULL
nns <- list()
Expand Down Expand Up @@ -4425,18 +4453,27 @@ data2set <- function(X, Xcat, n_neighbors, metrics, nn_method,
n_neighbors <- NULL
}

x2set_res <- x2set(Xsub, n_neighbors, metric,
x2set_res <- x2set(
Xsub,
n_neighbors,
metric,
nn_method = nn_sub,
n_trees, search_k,
n_trees,
search_k,
method,
set_op_mix_ratio, local_connectivity, bandwidth,
perplexity, kernel,
set_op_mix_ratio,
local_connectivity,
bandwidth,
perplexity,
kernel,
ret_sigma,
n_threads, grain_size,
n_threads,
grain_size,
ret_model,
n_vertices = n_vertices,
nn_args = nn_args,
tmpdir = tmpdir,
sparse_is_distance = sparse_is_distance,
verbose = verbose
)
Vblock <- x2set_res$V
Expand Down Expand Up @@ -4472,13 +4509,19 @@ data2set <- function(X, Xcat, n_neighbors, metrics, nn_method,
res
}

x2nn <- function(X, n_neighbors, metric, nn_method,
n_trees, search_k,
x2nn <- function(X,
n_neighbors,
metric,
nn_method,
n_trees,
search_k,
tmpdir = tempdir(),
n_threads, grain_size,
n_threads,
grain_size,
ret_model,
n_vertices = x2nv(X),
nn_args = list(),
sparse_is_distance = TRUE,
verbose = FALSE) {
if (is.list(nn_method)) {
validate_nn(nn_method, n_vertices)
Expand All @@ -4496,13 +4539,20 @@ x2nn <- function(X, n_neighbors, metric, nn_method,
if (nn_method == "fnn" && ret_model) {
stop("nn_method = 'FNN' is incompatible with ret_model = TRUE")
}
nn <- find_nn(X, n_neighbors,
method = nn_method, metric = metric,
n_trees = n_trees, search_k = search_k,
nn <- find_nn(
X,
n_neighbors,
method = nn_method,
metric = metric,
n_trees = n_trees,
search_k = search_k,
nn_args = nn_args,
tmpdir = tmpdir,
n_threads = n_threads, grain_size = grain_size,
ret_index = ret_model, verbose = verbose
n_threads = n_threads,
grain_size = grain_size,
ret_index = ret_model,
sparse_is_distance = sparse_is_distance,
verbose = verbose
)
}
nn
Expand Down Expand Up @@ -4580,17 +4630,26 @@ nn2set <- function(method, nn,
res
}

x2set <- function(X, n_neighbors, metric, nn_method,
n_trees, search_k,
x2set <- function(X,
n_neighbors,
metric,
nn_method,
n_trees,
search_k,
method,
set_op_mix_ratio, local_connectivity, bandwidth,
perplexity, kernel,
set_op_mix_ratio,
local_connectivity,
bandwidth,
perplexity,
kernel,
ret_sigma,
n_threads, grain_size,
n_threads,
grain_size,
ret_model,
n_vertices = x2nv(X),
tmpdir = tempdir(),
nn_args = list(),
sparse_is_distance = TRUE,
verbose = FALSE) {
if (is_sparse_matrix(nn_method)) {
nn <- nn_method
Expand All @@ -4601,16 +4660,20 @@ x2set <- function(X, n_neighbors, metric, nn_method,
stop("Sparse distance matrix must have same dimensions as input data")
}
} else {
nn <- x2nn(X,
nn <- x2nn(
X,
n_neighbors = n_neighbors,
metric = metric,
nn_method = nn_method,
n_trees = n_trees, search_k = search_k,
n_trees = n_trees,
search_k = search_k,
tmpdir = tmpdir,
n_threads = n_threads, grain_size = grain_size,
n_threads = n_threads,
grain_size = grain_size,
ret_model = ret_model,
nn_args = nn_args,
n_vertices = n_vertices,
sparse_is_distance = sparse_is_distance,
verbose = verbose
)
if (any(is.infinite(nn$dist))) {
Expand All @@ -4619,10 +4682,17 @@ x2set <- function(X, n_neighbors, metric, nn_method,
}
gc()

nn2set_res <- nn2set(method, nn,
set_op_mix_ratio, local_connectivity, bandwidth,
perplexity, kernel, ret_sigma,
n_threads, grain_size,
nn2set_res <- nn2set(
method,
nn,
set_op_mix_ratio,
local_connectivity,
bandwidth,
perplexity,
kernel,
ret_sigma,
n_threads,
grain_size,
verbose = verbose
)
V <- nn2set_res$V
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ devtools::install_github("jlmelville/uwot")
```R
library(uwot)

# umap2 is a version of the umap() function with better defaults
iris_umap <- umap2(iris)

# but you can still use the umap function (which most of the existing
# documentation does)
iris_umap <- umap(iris)

# Load mnist from somewhere, e.g.
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ articles:
- title: Articles
desc: More details on some of what `uwot` can do.
contents:
- articles/umap2
- articles/mixed-data-types
- articles/fast-sgd
- articles/init
Expand Down
Loading
Loading